7.2.5. Example: Image Segmentation using DETR
Sample program demonstrating inference using the DETR trained model on the COCO dataset
Note
The trained model and some of the source code used in this example are taken from facebookresearch/detr, either unmodified or with partial modifications. All of these components are licensed under the Apache License, Version 2.0.
The COCO dataset is licensed under CC BY 4.0, and we exclusively use images licensed under CC BY 2.0.
Execution Method
The first execution performs the following downloads (subsequent runs will skip these steps).
By default, the download location is /tmp/mlsdk_detr_inference/.
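- Python packages (installed from requirements.txt into a virtual environment)
- The facebookresearch/detr repository (cloned with git, then patched)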
$ cd /opt/pfn/pfcomp/codegen/MLSDK/examples/detr_inference
$ ./run_detr_inference.sh /tmp/mlsdk_detr_inference/coco/val2017/000000000785.jpg
Expected Output
A segmentation result will be saved in the output directory, which is /tmp/mlsdk_detr_inference/out by default and can be changed with the --outdir option.
Segmentation result (./000000000785.png)
Fig. 7.5 Segmentation result using DETR on MN-Core 2
Scripts
#! /bin/bash
#
# Launcher for the DETR segmentation example.
#
# Steps:
#   1. create (on first run) and activate a Python virtual environment,
#   2. clone facebookresearch/detr (on first run) and apply the local patches,
#   3. run detr_inference.py, forwarding every command-line argument to it.
#
# The locations below can be overridden through environment variables of the
# same name.

set -eux -o pipefail

EXAMPLE_NAME="mlsdk_detr_inference"
VENV_DIR=${VENV_DIR:-"/tmp/${EXAMPLE_NAME}/venv"}
EXTERNAL_DIR=${EXTERNAL_DIR:-"/tmp/${EXAMPLE_NAME}/external"}
# NOTE: COCO_DIR and OUT_DIR are defined here but not referenced anywhere else
# in this script.
COCO_DIR=${COCO_DIR:-"/tmp/${EXAMPLE_NAME}/coco"}
OUT_DIR=${OUT_DIR:-"/tmp/${EXAMPLE_NAME}/out"}

CURRENT_DIR=$(realpath "$(dirname "$0")")
CODEGEN_DIR=$(realpath "${CURRENT_DIR}/../../../")
BUILD_DIR="${CODEGEN_DIR}/build"

### Prepare and source venv/

# NOTE: the requirements are installed only when the venv directory does not
# exist yet. If a previous run was interrupted during installation, remove
# ${VENV_DIR} manually before re-running.
if [[ ! -d "${VENV_DIR}" ]]; then
  python3 -m venv --system-site-packages "${VENV_DIR}"
  source "${VENV_DIR}/bin/activate"
  pip3 install -r "${CURRENT_DIR}/requirements.txt"
else
  source "${VENV_DIR}/bin/activate"
fi

### Prepare external/ items

mkdir -p "${EXTERNAL_DIR}"
pushd "${EXTERNAL_DIR}"
if [[ ! -d detr ]]; then
  git clone https://github.com/facebookresearch/detr.git --depth 1
fi
popd

TARGET_FILES=(
  "models/detr.py"
  "models/matcher.py"
)

for REL_PATH in "${TARGET_FILES[@]}"; do
  BASE_NAME=$(basename "$REL_PATH" .py)
  PATCH_TARGET="${EXTERNAL_DIR}/detr/${REL_PATH}"
  PATCH_FILE="${CURRENT_DIR}/patches/${BASE_NAME}.patch"
  # Exit status 1 from patch (some hunks not applied, e.g. because the patch
  # was already applied by an earlier run) is tolerated; any other failure
  # aborts the script through `set -e`.
  patch --forward --backup -i "$PATCH_FILE" "$PATCH_TARGET" || [ $? -eq 1 ]
done

cp "${CURRENT_DIR}/lsa.py" "${EXTERNAL_DIR}/detr/models/"

### Run detr_inference.py

# Put the patched detr checkout first on the module search path, keeping any
# PYTHONPATH that was already set.
export PYTHONPATH="${EXTERNAL_DIR}/detr${PYTHONPATH:+:${PYTHONPATH}}"
# Print the resulting value (the original printed the literal word).
echo "${PYTHONPATH}"

source "${BUILD_DIR}/codegen_pythonpath.sh"

export MNCORE_USE_EXTERNAL_DATA_FORMAT=1

# "$@" keeps each argument intact even if it contains whitespace.
python3 "${CURRENT_DIR}/detr_inference.py" "$@"
1import argparse
2from pathlib import Path
3from typing import Any
4
5import torch
6import torchvision.transforms as T
7from detr_eval import DETRSegmToBBoxAttn, MaskHeadPart, run_eval
8from mlsdk import Context, MNDevice
9from PIL import Image
10from utility import apply_toml_defaults
11
12
def prepare_task_components(args: argparse.Namespace) -> dict[str, Any]:
    """Load the pretrained model and prepare the input image.

    Args:
        args: Parsed command-line arguments; only ``args.img_path`` is read.

    Returns:
        A dict with the keys ``"model"`` and ``"postprocessors"`` (both
        returned by ``torch.hub.load``), ``"orig_target_size"`` (a tensor
        holding the original image height and width, with a leading axis of
        size 1) and ``"image"`` (the transformed image with a leading axis of
        size 1).
    """
    task_components: dict[str, Any] = {}

    # Create model and post processor objs
    model, postprocessors = torch.hub.load(
        "facebookresearch/detr",
        "detr_resnet50_panoptic",
        pretrained=True,
        return_postprocessor=True,
    )
    task_components["model"] = model
    task_components["postprocessors"] = postprocessors

    # standard PyTorch mean-std input image normalization
    transform = T.Compose(
        [
            T.Resize((800, 800)),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )

    # Force three channels: the normalization above is defined for exactly
    # three channels, so grayscale or RGBA files would otherwise fail.
    # The context manager closes the underlying file once the pixel data has
    # been copied by convert().
    with Image.open(args.img_path) as raw_image:
        im = raw_image.convert("RGB")

    # PIL reports (width, height); store (height, width), which is what the
    # last two axes of the tensor representation used previously contained.
    width, height = im.size
    task_components["orig_target_size"] = torch.as_tensor(
        [height, width]
    ).unsqueeze(0)
    task_components["image"] = transform(im).unsqueeze(0)

    return task_components
39
40
def main(args: argparse.Namespace) -> None:
    """Set up the device context, build the two model halves and run them.

    Args:
        args: Parsed command-line arguments. ``args.device_name`` selects the
            device, ``args.num_split`` is forwarded to ``MaskHeadPart``.
    """
    # Pass device info to the Context obj
    context = Context(MNDevice(args.device_name))
    Context.switch_context(context)

    components = prepare_task_components(args)

    # Both wrappers are built from the same loaded segmentation model; the
    # "model" entry is then replaced by the wrapper that stops at the bbox
    # attention maps.
    segm_model = components["model"]
    components["mask_head"] = MaskHeadPart(segm_model, num_split=args.num_split)
    components["model"] = DETRSegmToBBoxAttn(segm_model)

    run_eval(args, components, context)
54
55
if __name__ == "__main__":
    # Command-line interface: one positional image path plus a few options.
    cli = argparse.ArgumentParser()
    cli.add_argument("img_path", type=Path, help="Path to input image")
    for flag, value_type, default_value in (
        ("--device", str, "mncore2:auto"),
        ("--outdir", str, "/tmp/mlsdk_detr_inference/out"),
        ("--option_json", Path, "/opt/pfn/pfcomp/codegen/preset_options/O1.json"),
    ):
        cli.add_argument(flag, type=value_type, default=default_value)

    # load configs from toml file
    toml_path = Path(__file__).resolve().parent / "configs.toml"
    apply_toml_defaults(toml_path, cli)

    args = cli.parse_args()

    # Set "device" attribute for detr module: the requested device string is
    # kept under "device_name" and "device" itself is forced to "cpu".
    args.device_name = args.device
    args.device = "cpu"

    main(args)
1import io
2import os
3from argparse import Namespace
4from pathlib import Path
5from typing import Any
6
7import numpy
8import torch
9from detectron2.data import MetadataCatalog
10from detectron2.utils.visualizer import Visualizer
11from mlsdk import CompiledFunction, Context
12from models.segmentation import DETRsegm
13from panopticapi.utils import rgb2id
14from PIL import Image
15from util.misc import nested_tensor_from_tensor_list
16from utility import compile_fn
17
18
class DETRSegmToBBoxAttn(torch.nn.Module):
    """First half of the split model: everything up to the bbox attention.

    Reuses the ``detr`` and ``bbox_attention`` sub-modules of the given
    segmentation model and returns the intermediate tensors that the mask
    head needs, together with the class and box predictions.
    """

    def __init__(self, detr_segm: DETRsegm) -> None:
        super().__init__()
        self.detr = detr_segm.detr
        self.bbox_attention = detr_segm.bbox_attention

    def forward(self, sample: torch.Tensor) -> dict[str, torch.Tensor]:
        """Run backbone, transformer, prediction heads and bbox attention.

        Returns:
            A dict with the keys ``feat2``, ``feat1``, ``feat0``,
            ``src_proj``, ``pred_logits``, ``pred_boxes`` and ``bbox_mask``.
        """
        nested = nested_tensor_from_tensor_list(sample)
        features, pos = self.detr.backbone(nested)

        last_src, padding_mask = features[-1].decompose()
        projected = self.detr.input_proj(last_src)
        hs, memory = self.detr.transformer(
            projected, padding_mask, self.detr.query_embed.weight, pos[-1]
        )

        class_logits = self.detr.class_embed(hs)
        box_coords = self.detr.bbox_embed(hs).sigmoid()

        # FIXME h_boxes takes the last one computed, keep this in mind
        attention_maps = self.bbox_attention(hs[-1], memory, mask=padding_mask)

        # Only the last entry along the first axis of the head outputs is
        # returned.
        return {
            "feat2": features[2].tensors,
            "feat1": features[1].tensors,
            "feat0": features[0].tensors,
            "src_proj": projected,
            "pred_logits": class_logits[-1],
            "pred_boxes": box_coords[-1],
            "bbox_mask": attention_maps,
        }
54
55
class MaskHeadPart(torch.nn.Module):
    """Second half of the split model: the mask head only.

    Each call handles ``num_queries // num_split`` queries, so that the mask
    head can be evaluated on a slice of the queries at a time.
    """

    def __init__(self, detr_segm: DETRsegm, num_split: int = 100) -> None:
        super().__init__()

        self.num_split = num_split
        self.num_queries = detr_segm.detr.num_queries  # 100
        self.mask_head = detr_segm.mask_head

    def forward(
        self, x: torch.Tensor, bbox_mask: torch.Tensor, fpns: list[torch.Tensor]
    ) -> torch.Tensor:
        """Apply the mask head and reshape its output per batch element.

        Returns:
            The mask head output viewed as
            ``(x.shape[0], num_queries // num_split, H, W)`` where ``H`` and
            ``W`` are the last two axes of the mask head output.
        """
        seg_mask = self.mask_head(x, bbox_mask, fpns)
        queries_per_call = self.num_queries // self.num_split
        mask_height = seg_mask.shape[-2]
        mask_width = seg_mask.shape[-1]
        return seg_mask.view(x.shape[0], queries_per_call, mask_height, mask_width)
76
77
def visualize_prediction(
    args: Namespace,
    task_components: dict[str, Any],
    outputs: dict[str, torch.Tensor],
) -> None:
    """Post-process the model outputs and save the rendered segmentation.

    The result is written to ``<args.outdir>/<image file stem>.png``.

    Args:
        args: Parsed arguments; ``img_path`` and ``outdir`` are read.
        task_components: Must contain the ``"postprocessors"`` callable.
        outputs: Model outputs handed to the post-processor.
    """
    # Force three channels so that the channel-axis indexing below also works
    # for grayscale or RGBA files; the file is closed once the data is copied.
    with Image.open(args.img_path) as raw_image:
        # The channel order is reversed here, as in the original code.
        # NOTE(review): confirm this is the order the Visualizer expects.
        image = numpy.array(raw_image.convert("RGB"))[:, :, ::-1]

    result = task_components["postprocessors"](outputs, [image.shape[:2]])[0]

    # Panoptic predictions are stored in a special format png
    panoptic_seg = Image.open(io.BytesIO(result["png_string"]))

    # We convert the png into an segment id map
    panoptic_seg = numpy.array(panoptic_seg, dtype=numpy.uint8)
    panoptic_seg = torch.from_numpy(rgb2id(panoptic_seg))

    # Detectron2 uses a different numbering of coco classes,
    # here we convert the class ids accordingly
    meta = MetadataCatalog.get("coco_2017_val_panoptic_separated")
    for info in result["segments_info"]:
        c = info["category_id"]
        if info["isthing"]:
            info["category_id"] = meta.thing_dataset_id_to_contiguous_id[c]
        else:
            info["category_id"] = meta.stuff_dataset_id_to_contiguous_id[c]

    # Finally we visualize the prediction
    v = Visualizer(image, meta)
    v._default_font_size = 20
    v = v.draw_panoptic_seg_predictions(panoptic_seg, result["segments_info"])

    # Make sure the destination exists before writing the figure.
    os.makedirs(args.outdir, exist_ok=True)
    v.save(os.path.join(args.outdir, f"{Path(args.img_path).stem}.png"))
109
110
def compile_eval_fn(
    args: Namespace,
    task_components: dict[str, Any],
    context: Context,
) -> tuple[CompiledFunction, CompiledFunction]:
    """Compile the two halves of the split model.

    Args:
        args: Parsed arguments; ``outdir``, ``option_json`` and ``num_split``
            are read.
        task_components: Must contain ``"model"``, ``"mask_head"`` and
            ``"image"``.
        context: Context used for compilation.

    Returns:
        ``(eval_fn, mask_fn)``: the compiled first half (up to the bbox
        attention maps) and the compiled mask head.
    """

    # in case using a mncore2 backend, the DETR model is separated
    # into two halves to avoid the LM oom and the reshape errors:
    # the first half calculates the bbox_masks,
    # and the second half calculates the remaining MaskHead part
    def eval_to_bbox_or_full(
        sample: dict[str, torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        with torch.no_grad():
            return task_components["model"](sample["image"])

    def eval_mask_head(sample: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        with torch.no_grad():
            pred_mask = task_components["mask_head"](
                sample["src_proj"],
                sample["bbox_mask"],
                [sample["feat2"], sample["feat1"], sample["feat0"]],
            )

        return {"pred_mask": pred_mask}

    # For compilation for MLSDK/MN-Core 2
    sample = {"image": task_components["image"]}

    eval_fn = compile_fn(
        context,
        eval_to_bbox_or_full,
        task_components["model"],
        sample,
        os.path.join(args.outdir, "to_bbox"),
        model_name="detr_to_bbox",
        option_json=args.option_json,
    )

    # Random placeholder inputs for the mask head. The same dict (which still
    # holds "image") is extended in place, exactly as before.
    # NOTE(review): the fixed sizes below presumably correspond to the
    # intermediate tensors produced for the prepared input image; confirm
    # before changing the input resolution.
    queries_per_call = task_components["mask_head"].num_queries // args.num_split
    sample.update(
        src_proj=torch.randn(1, 256, 25, 25),
        bbox_mask=torch.rand(
            1, queries_per_call, 8, 25, 25
        ),  # originally, 1, 100, 8, 25, 25
        feat2=torch.randn(1, 1024, 50, 50),
        feat1=torch.randn(1, 512, 100, 100),
        feat0=torch.randn(1, 256, 200, 200),
    )

    mask_fn = compile_fn(
        context,
        eval_mask_head,
        task_components["mask_head"],
        sample,
        os.path.join(args.outdir, "mask_head"),
        model_name="detr_mask_head",
        option_json=args.option_json,
    )

    return eval_fn, mask_fn
174
175
@torch.no_grad()
def evaluate(
    args: Namespace,
    eval_fn: CompiledFunction,
    mask_fn: CompiledFunction,
    task_components: dict[str, Any],
) -> dict[str, torch.Tensor]:
    """Run both model halves and assemble the final outputs.

    ``eval_fn`` is called once; its ``bbox_mask`` output is cut along axis 1
    into chunks of ``num_queries // args.num_split`` queries and ``mask_fn``
    is called once per chunk.

    Args:
        args: Parsed arguments; only ``num_split`` is read.
        eval_fn: Callable for the first half of the model.
        mask_fn: Callable for the mask head.
        task_components: Must contain the prepared ``"image"`` tensor.

    Returns:
        A dict with ``pred_logits``, ``pred_boxes`` and ``pred_masks``.

    Raises:
        ValueError: If ``args.num_split`` is not positive, exceeds the number
            of queries, or leads to a chunk size that does not evenly divide
            the number of queries (the chunks would then differ in size).
    """
    sample = {"image": task_components["image"]}

    outputs = eval_fn(sample)

    num_queries = outputs["bbox_mask"].shape[1]

    # Validate the chunking up front instead of failing later with an opaque
    # division, split or shape error.
    if args.num_split <= 0:
        raise ValueError(f"num_split must be positive, got {args.num_split}")
    chunk_size = num_queries // args.num_split
    if chunk_size == 0 or num_queries % chunk_size != 0:
        raise ValueError(
            f"num_split={args.num_split} does not split {num_queries} queries "
            "into equally sized chunks"
        )

    bbox_masks = outputs["bbox_mask"].split(chunk_size, dim=1)
    sample.update(
        src_proj=outputs["src_proj"],
        feat2=outputs["feat2"],
        feat1=outputs["feat1"],
        feat0=outputs["feat0"],
    )

    pred_masks = []
    for bbox_mask in bbox_masks:
        sample.update(bbox_mask=bbox_mask)
        pred_masks.append(mask_fn(sample)["pred_mask"].cpu())

    # The per-chunk results are stacked along axis 0 and then viewed so that
    # all queries sit on axis 1 again.
    pred_masks = torch.cat(pred_masks, dim=0)
    outputs = {
        "pred_logits": outputs["pred_logits"],
        "pred_boxes": outputs["pred_boxes"],
        "pred_masks": pred_masks.view(
            1, num_queries, pred_masks.shape[-2], pred_masks.shape[-1]
        ),
    }

    return outputs
212
213
def run_eval(
    args: Namespace, task_components: dict[str, Any], context: Context
) -> None:
    """Compile the model halves, run inference and save the visualization.

    Args:
        args: Parsed command-line arguments, forwarded to every stage.
        task_components: Must contain ``"model"``, ``"mask_head"`` and
            ``"postprocessors"`` (plus whatever the stages below read).
        context: Context used for compilation.
    """
    # Switch every module to evaluation mode before compiling.
    for key in ("model", "mask_head", "postprocessors"):
        task_components[key].eval()

    compiled = compile_eval_fn(args, task_components, context)
    eval_fn, mask_fn = compiled

    predictions = evaluate(args, eval_fn, mask_fn, task_components)
    visualize_prediction(args, task_components, predictions)
1import argparse
2import os
3import sys
4from collections.abc import Callable
5from pathlib import Path
6
7import tomllib
8import torch
9from mlsdk import (
10 CompiledFunction,
11 Context,
12 get_tensor_name,
13 set_tensor_name_in_module,
14 storage,
15)
16
17
def register_model(
    context: Context,
    name: str,
    model: torch.nn.Module,
) -> None:
    """Name the tensors of ``model`` and register them to ``context``.

    Nothing is done when the first parameter of ``model`` already has a
    tensor name, which is taken as the sign that the model was registered
    before.

    Args:
        context: Context receiving the parameters and buffers.
        name: Name passed to ``set_tensor_name_in_module``.
        model: Module whose parameters and buffers are registered.
    """
    first_param = next(model.parameters())
    if get_tensor_name(first_param) is not None:
        # Already named: the model obj is registered to the context.
        return

    set_tensor_name_in_module(model, name)
    for param in model.parameters():
        context.register_param(param)
    for buffer in model.buffers():
        context.register_buffer(buffer)
31
32
def compile_fn(  # noqa: CFQ002
    context: Context,
    target_fn: Callable[
        [
            dict[str, torch.Tensor],
        ],
        dict[str, torch.Tensor],
    ],  # compiled fn
    model: torch.nn.Module,
    sample_input: dict[str, torch.Tensor],
    outdir: str = "/tmp/example_output",
    model_name: str = "example",
    option_json: Path | None = None,
) -> CompiledFunction:
    """Register ``model`` and compile ``target_fn`` with ``context``.

    Args:
        context: Context performing the compilation.
        target_fn: Function to compile; takes and returns a dict of tensors.
        model: Module used by ``target_fn``; registered under the name
            ``"model"`` before compiling.
        sample_input: Example input handed to the compiler.
        outdir: Base output directory; the code is generated in
            ``<outdir>/<model_name>``.
        model_name: Sub-directory name for the generated code.
        option_json: Compiler option preset; the O1 preset is used when
            ``None``.

    Returns:
        The value returned by ``context.compile``.
    """
    preset = (
        option_json
        if option_json is not None
        else Path("/opt/pfn/pfcomp/codegen/preset_options/O1.json")
    )
    compile_options = {"option_json": str(preset)}

    codegen_dir = storage.path(outdir) / model_name

    register_model(context, "model", model)

    return context.compile(
        function=target_fn,
        inputs=sample_input,
        options=compile_options,
        codegen_dir=codegen_dir,
    )
65
66
67# for type hint of the configs from toml
68class TomlValue:
69 str | int | float | bool | list["TomlValue"] | dict[str, "TomlValue"]
70
71
72class TomlDict:
73 dict[str, TomlValue]
74
75
76def read_configs_from_toml(
77 toml_path: str,
78) -> TomlDict:
79
80 configs_dict = None
81 with open(toml_path, mode="rb") as f:
82 configs_dict = tomllib.load(f)
83
84 return configs_dict
85
86
87def apply_toml_defaults(
88 configs: TomlDict | str | os.PathLike,
89 parser: argparse.ArgumentParser,
90) -> None:
91
92 if isinstance(configs, dict):
93 for k, v in configs.items():
94 if isinstance(v, dict): # in case v is (nested) dict
95 apply_toml_defaults(v, parser)
96 else:
97 parser.add_argument(f"--{k}", default=v, type=type(v))
98 elif isinstance(configs, str) or isinstance(configs, os.PathLike):
99 configs_dict = read_configs_from_toml(configs)
100
101 apply_toml_defaults(configs_dict, parser)
102 else:
103 sys.exit("")
1Cython==3.1.3
2matplotlib==3.10.5
3pycocotools==2.0.10
4scipy==1.16.1
5git+https://github.com/cocodataset/panopticapi.git
6git+https://github.com/facebookresearch/detectron2.git
1title = "detr_inference"
2
3[model.backbone]
4backbone = "resnet50"
5dilation = false
6
7[model.transformer]
8enc_layers = 6
9dec_layers = 6
10dim_feedforward = 2048
11hidden_dim = 256
12dropout = 0.1
13nheads = 8
14num_queries = 100
15pre_norm = false
16
17[model.loss.matcher]
18set_cost_class = 1
19set_cost_bbox = 5
20set_cost_giou = 2
21
22[model.loss.loss_coefficients]
23eos_coef = 0.1
24mask_loss_coef = 1
25dice_loss_coef = 1
26bbox_loss_coef = 5
27giou_loss_coef = 2
28
29
30[data]
31max_num_segms = 90 # max num of segments in the ground truth labels
32num_workers = 2
33
34
35[training]
36lr = 1e-4
37lr_backbone = 1e-5
38weight_decay = 1e-4
39epochs = 300
40lr_drop = 200 # epoch at which lr drops
41clip_max_norm = 0.1 # gradient clipping max norm
42frozen_weights = "" # path to the pretrained model. if set, only the mask head will be trained
43resume = "" # checkpoint to resume from
44masks = true # create DETRSegm obj
45aux_loss = false
46
47
48[evaluation]
49num_split = 100
50position_embedding = "sine"
51dataset_file = "coco_panoptic"
52
53
54[misc]
55batch_size = 1