7.2.5. Example: Image Segmentation using DETR

Sample program demonstrating inference using the DETR trained model on the COCO dataset

Note

The trained model and some of the source code used in this example are taken from facebookresearch/detr, in some cases with modifications. All of these components are licensed under the Apache License, Version 2.0.

The COCO dataset is licensed under CC BY 4.0, and we exclusively use images licensed under CC BY 2.0.

Execution Method

The first execution performs the following downloads (subsequent runs will skip these steps). By default, the download location is /tmp/mlsdk_detr_inference/.

$ cd /opt/pfn/pfcomp/codegen/MLSDK/examples/detr_inference
$ ./run_detr_inference.sh /tmp/mlsdk_detr_inference/coco/val2017/000000000785.jpg

Expected Output

A segmentation result will be saved in the output directory (/tmp/mlsdk_detr_inference/out/ by default, configurable with --outdir).

  • Segmentation result (/tmp/mlsdk_detr_inference/out/000000000785.png)

Segmentation result using DETR on MN-Core 2

Fig. 7.5 Segmentation result using DETR on MN-Core 2

Scripts

Listing 7.23 /opt/pfn/pfcomp/codegen/MLSDK/examples/detr_inference/run_detr_inference.sh
 1#! /bin/bash
 2
 3set -eux -o pipefail
 4
 5EXAMPLE_NAME="mlsdk_detr_inference"
 6VENV_DIR=${VENV_DIR:-"/tmp/${EXAMPLE_NAME}/venv"}
 7EXTERNAL_DIR=${EXTERNAL_DIR:-"/tmp/${EXAMPLE_NAME}/external"}
 8COCO_DIR=${COCO_DIR:-"/tmp/${EXAMPLE_NAME}/coco"}
 9OUT_DIR=${OUT_DIR:-"/tmp/${EXAMPLE_NAME}/out"}
10
11CURRENT_DIR=$(realpath $(dirname $0))
12CODEGEN_DIR=$(realpath ${CURRENT_DIR}/../../../)
13BUILD_DIR="${CODEGEN_DIR}/build"
14
15### Prepare and source venv/
16
17if [[ ! -d ${VENV_DIR} ]]; then
18    python3 -m venv --system-site-packages ${VENV_DIR}
19    source ${VENV_DIR}/bin/activate
20    pip3 install -r ${CURRENT_DIR}/requirements.txt
21else
22    source ${VENV_DIR}/bin/activate
23fi
24
25### Prepare external/ items
26
27mkdir -p ${EXTERNAL_DIR}
28pushd ${EXTERNAL_DIR}
29if [[ ! -d detr ]]; then
30    git clone https://github.com/facebookresearch/detr.git --depth 1
31fi
32popd
33
34TARGET_FILES=(
35    "models/detr.py"
36    "models/matcher.py"
37)
38
39for REL_PATH in "${TARGET_FILES[@]}"; do
40    BASE_NAME=$(basename "$REL_PATH" .py)
41    PATCH_TARGET="${EXTERNAL_DIR}/detr/${REL_PATH}"
42    PATCH_FILE="${CURRENT_DIR}/patches/${BASE_NAME}.patch"
43    patch --forward --backup -i "$PATCH_FILE" "$PATCH_TARGET" || [ $? -eq 1 ]
44done
45
46cp ${CURRENT_DIR}/lsa.py ${EXTERNAL_DIR}/detr/models/
47
48### Run detr_inference.py
49
50export PYTHONPATH="${EXTERNAL_DIR}/detr${PYTHONPATH:+:${PYTHONPATH}}"
51echo PYTHONPATH
52
53source "${BUILD_DIR}/codegen_pythonpath.sh"
54
55export MNCORE_USE_EXTERNAL_DATA_FORMAT=1
56
57python3 ${CURRENT_DIR}/detr_inference.py ${@}
Listing 7.24 /opt/pfn/pfcomp/codegen/MLSDK/examples/detr_inference/detr_inference.py
 1import argparse
 2from pathlib import Path
 3from typing import Any
 4
 5import torch
 6import torchvision.transforms as T
 7from detr_eval import DETRSegmToBBoxAttn, MaskHeadPart, run_eval
 8from mlsdk import Context, MNDevice
 9from PIL import Image
10from utility import apply_toml_defaults
11
12
def prepare_task_components(args: argparse.Namespace) -> dict[str, Any]:
    """Load the pretrained panoptic DETR model and preprocess the input image.

    Returns a dict with keys:
        "model"            -- pretrained DETR ResNet-50 panoptic model
        "postprocessors"   -- matching post-processor from torch.hub
        "orig_target_size" -- original (H, W) of the image, shape (1, 2)
        "image"            -- normalized input tensor with a batch dimension
    """
    # Fetch the pretrained model and its post-processor from torch.hub.
    model, postprocessors = torch.hub.load(
        "facebookresearch/detr",
        "detr_resnet50_panoptic",
        pretrained=True,
        return_postprocessor=True,
    )

    # Standard PyTorch mean-std input image normalization.
    preprocess = T.Compose(
        [
            T.Resize((800, 800)),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )

    source_image = Image.open(args.img_path)
    # Keep the pre-resize size so predictions can be mapped back later.
    original_hw = torch.as_tensor(
        T.functional.to_tensor(source_image).shape[-2:]
    ).unsqueeze(0)

    return {
        "model": model,
        "postprocessors": postprocessors,
        "orig_target_size": original_hw,
        "image": preprocess(source_image).unsqueeze(0),
    }
39
40
def main(args: argparse.Namespace) -> None:
    """Compile and run split DETR inference on the configured device."""
    # Register the target device with a fresh MLSDK context.
    context = Context(MNDevice(args.device_name))
    Context.switch_context(context)

    components = prepare_task_components(args)
    # Split the model in two: the mask head is compiled and run separately
    # (see detr_eval.py); the remainder runs up to the bbox attention.
    components["mask_head"] = MaskHeadPart(
        components["model"], num_split=args.num_split
    )
    components["model"] = DETRSegmToBBoxAttn(components["model"])

    run_eval(args, components, context)
54
55
def _build_parser() -> argparse.ArgumentParser:
    """Construct the CLI parser, including defaults taken from configs.toml."""
    parser = argparse.ArgumentParser()

    parser.add_argument("img_path", type=Path, help="Path to input image")
    parser.add_argument("--device", type=str, default="mncore2:auto")
    parser.add_argument("--outdir", type=str, default="/tmp/mlsdk_detr_inference/out")
    parser.add_argument(
        "--option_json",
        type=Path,
        default="/opt/pfn/pfcomp/codegen/preset_options/O1.json",
    )

    # load configs from toml file
    apply_toml_defaults(Path(__file__).resolve().parent / "configs.toml", parser)
    return parser


if __name__ == "__main__":
    args = _build_parser().parse_args()

    # The detr modules read args.device and must run on CPU, so keep the
    # MN-Core device string separately under args.device_name.
    args.device, args.device_name = "cpu", args.device

    main(args)
Listing 7.25 /opt/pfn/pfcomp/codegen/MLSDK/examples/detr_inference/detr_eval.py
  1import io
  2import os
  3from argparse import Namespace
  4from pathlib import Path
  5from typing import Any
  6
  7import numpy
  8import torch
  9from detectron2.data import MetadataCatalog
 10from detectron2.utils.visualizer import Visualizer
 11from mlsdk import CompiledFunction, Context
 12from models.segmentation import DETRsegm
 13from panopticapi.utils import rgb2id
 14from PIL import Image
 15from util.misc import nested_tensor_from_tensor_list
 16from utility import compile_fn
 17
 18
 19class DETRSegmToBBoxAttn(torch.nn.Module):
 20    def __init__(self, detr_segm: DETRsegm) -> None:
 21        super().__init__()
 22        self.detr = detr_segm.detr
 23        self.bbox_attention = detr_segm.bbox_attention
 24
 25    def forward(self, sample: torch.Tensor) -> dict[str, torch.Tensor]:
 26        out = {}
 27
 28        sample = nested_tensor_from_tensor_list(sample)
 29        features, pos = self.detr.backbone(sample)
 30
 31        src, mask = features[-1].decompose()
 32        src_proj = self.detr.input_proj(src)
 33        hs, memory = self.detr.transformer(
 34            src_proj, mask, self.detr.query_embed.weight, pos[-1]
 35        )
 36
 37        outputs_class = self.detr.class_embed(hs)
 38        outputs_coord = self.detr.bbox_embed(hs).sigmoid()
 39
 40        # FIXME h_boxes takes the last one computed, keep this in mind
 41        bbox_mask = self.bbox_attention(hs[-1], memory, mask=mask)
 42
 43        out.update(
 44            feat2=features[2].tensors,
 45            feat1=features[1].tensors,
 46            feat0=features[0].tensors,
 47            src_proj=src_proj,
 48            pred_logits=outputs_class[-1],
 49            pred_boxes=outputs_coord[-1],
 50            bbox_mask=bbox_mask,
 51        )
 52
 53        return out
 54
 55
 56class MaskHeadPart(torch.nn.Module):
 57    def __init__(self, detr_segm: DETRsegm, num_split: int = 100) -> None:
 58        super().__init__()
 59
 60        self.mask_head = detr_segm.mask_head
 61        self.num_queries = detr_segm.detr.num_queries  # 100
 62        self.num_split = num_split
 63
 64    def forward(
 65        self, x: torch.Tensor, bbox_mask: torch.Tensor, fpns: list[torch.Tensor]
 66    ) -> torch.Tensor:
 67        seg_mask = self.mask_head(x, bbox_mask, fpns)
 68        output_seg_mask = seg_mask.view(
 69            x.shape[0],
 70            self.num_queries // self.num_split,
 71            seg_mask.shape[-2],
 72            seg_mask.shape[-1],
 73        )
 74
 75        return output_seg_mask
 76
 77
 78def visualize_prediction(
 79    args: Namespace,
 80    task_components: dict[str, Any],
 81    outputs: dict[str, torch.Tensor],
 82) -> None:
 83
 84    image = numpy.array(Image.open(args.img_path))[:, :, ::-1]
 85    result = task_components["postprocessors"](outputs, [image.shape[:2]])[0]
 86
 87    # Panoptic predictions are stored in a special format png
 88    panoptic_seg = Image.open(io.BytesIO(result["png_string"]))
 89
 90    # We convert the png into an segment id map
 91    panoptic_seg = numpy.array(panoptic_seg, dtype=numpy.uint8)
 92    panoptic_seg = torch.from_numpy(rgb2id(panoptic_seg))
 93
 94    # Detectron2 uses a different numbering of coco classes,
 95    # here we convert the class ids accordingly
 96    meta = MetadataCatalog.get("coco_2017_val_panoptic_separated")
 97    for info in result["segments_info"]:
 98        c = info["category_id"]
 99        if info["isthing"]:
100            info["category_id"] = meta.thing_dataset_id_to_contiguous_id[c]
101        else:
102            info["category_id"] = meta.stuff_dataset_id_to_contiguous_id[c]
103
104    # Finally we visualize the prediction
105    v = Visualizer(image, meta)
106    v._default_font_size = 20
107    v = v.draw_panoptic_seg_predictions(panoptic_seg, result["segments_info"])
108    v.save(os.path.join(args.outdir, f"{Path(args.img_path).stem}.png"))
109
110
def compile_eval_fn(
    args: Namespace,
    task_components: dict[str, Any],
    context: Context,
) -> tuple[CompiledFunction, CompiledFunction]:
    """Compile both halves of the split DETR model.

    On the mncore2 backend the DETR model is separated into two halves to
    avoid LM out-of-memory and reshape errors: the first half calculates
    the bbox_masks, and the second half calculates the remaining MaskHead
    part.

    Returns:
        (eval_fn, mask_fn): compiled first half and compiled mask head.

    BUG FIX: the return annotation previously declared a 3-tuple including
    a dict of tensors, although only two values are returned.
    """

    def eval_to_bbox_or_full(
        sample: dict[str, torch.Tensor],
    ) -> dict[str, torch.Tensor]:  # eval_all_or_bbox
        with torch.no_grad():
            return task_components["model"](sample["image"])

    def eval_mask_head(sample: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        with torch.no_grad():
            pred_mask = task_components["mask_head"](
                sample["src_proj"],
                sample["bbox_mask"],
                [sample["feat2"], sample["feat1"], sample["feat0"]],
            )

        return {"pred_mask": pred_mask}

    # Compile the first half against the real preprocessed image.
    sample = {"image": task_components["image"]}

    eval_fn = compile_fn(
        context,
        eval_to_bbox_or_full,
        task_components["model"],
        sample,
        os.path.join(args.outdir, "to_bbox"),
        model_name="detr_to_bbox",
        option_json=args.option_json,
    )

    # Compile the mask head against dummy tensors matching its inputs;
    # bbox_mask is chunked along the query dimension by num_split.
    sample.update(
        src_proj=torch.randn(1, 256, 25, 25),
        bbox_mask=torch.rand(
            1, task_components["mask_head"].num_queries // args.num_split, 8, 25, 25
        ),  # originally, 1, 100, 8, 25, 25
        feat2=torch.randn(1, 1024, 50, 50),
        feat1=torch.randn(1, 512, 100, 100),
        feat0=torch.randn(1, 256, 200, 200),
    )

    mask_fn = compile_fn(
        context,
        eval_mask_head,
        task_components["mask_head"],
        sample,
        os.path.join(args.outdir, "mask_head"),
        model_name="detr_mask_head",
        option_json=args.option_json,
    )

    return eval_fn, mask_fn
174
175
@torch.no_grad()
def evaluate(
    args: Namespace,
    eval_fn: CompiledFunction,
    mask_fn: CompiledFunction,
    task_components: dict[str, Any],
) -> dict[str, torch.Tensor]:
    """Run both compiled halves and assemble the final prediction dict."""

    sample = {"image": task_components["image"]}
    first_half = eval_fn(sample)

    # Feed the bbox attention maps to the mask head in args.num_split
    # chunks along the query dimension (dim=1).
    num_queries = first_half["bbox_mask"].shape[1]
    chunks = first_half["bbox_mask"].split(num_queries // args.num_split, dim=1)
    sample.update(
        src_proj=first_half["src_proj"],
        feat2=first_half["feat2"],
        feat1=first_half["feat1"],
        feat0=first_half["feat0"],
    )

    mask_chunks = []
    for chunk in chunks:
        sample["bbox_mask"] = chunk
        mask_chunks.append(mask_fn(sample)["pred_mask"].cpu())
    all_masks = torch.cat(mask_chunks, dim=0)

    # Recombine the per-chunk masks into one (1, num_queries, H, W) tensor.
    return {
        "pred_logits": first_half["pred_logits"],
        "pred_boxes": first_half["pred_boxes"],
        "pred_masks": all_masks.view(
            1, num_queries, all_masks.shape[-2], all_masks.shape[-1]
        ),
    }
212
213
def run_eval(
    args: Namespace, task_components: dict[str, Any], context: Context
) -> None:
    """End-to-end: compile both model halves, run inference, save the PNG."""

    # Inference only: put every module into eval mode.
    for key in ("model", "mask_head", "postprocessors"):
        task_components[key].eval()

    eval_fn, mask_fn = compile_eval_fn(args, task_components, context)
    predictions = evaluate(args, eval_fn, mask_fn, task_components)
    visualize_prediction(args, task_components, predictions)
Listing 7.26 /opt/pfn/pfcomp/codegen/MLSDK/examples/detr_inference/utility.py
  1import argparse
  2import os
  3import sys
  4from collections.abc import Callable
  5from pathlib import Path
  6
  7import tomllib
  8import torch
  9from mlsdk import (
 10    CompiledFunction,
 11    Context,
 12    get_tensor_name,
 13    set_tensor_name_in_module,
 14    storage,
 15)
 16
 17
 18def register_model(
 19    context: Context,
 20    name: str,
 21    model: torch.nn.Module,
 22) -> None:
 23    if (
 24        get_tensor_name(next(model.parameters())) is None
 25    ):  # in case the model obj isn't registered to the context
 26        set_tensor_name_in_module(model, name)
 27        for p in model.parameters():
 28            context.register_param(p)
 29        for b in model.buffers():
 30            context.register_buffer(b)
 31
 32
 33def compile_fn(  # noqa: CFQ002
 34    context: Context,
 35    target_fn: Callable[
 36        [
 37            dict[str, torch.Tensor],
 38        ],
 39        dict[str, torch.Tensor],
 40    ],  # compiled fn
 41    model: torch.nn.Module,
 42    sample_input: dict[str, torch.Tensor],
 43    outdir: str = "/tmp/example_output",
 44    model_name: str = "example",
 45    option_json: Path | None = None,
 46) -> CompiledFunction:
 47
 48    if option_json is None:
 49        option_json = Path("/opt/pfn/pfcomp/codegen/preset_options/O1.json")
 50
 51    compile_options = {"option_json": str(option_json)}
 52
 53    compile_args = {
 54        "function": target_fn,
 55        "inputs": sample_input,
 56        "options": compile_options,
 57    }
 58
 59    codegen_base_dir = storage.path(outdir)
 60    compile_args["codegen_dir"] = codegen_base_dir / model_name
 61
 62    register_model(context, "model", model)
 63
 64    return context.compile(**compile_args)
 65
 66
 67# for type hint of the configs from toml
 68class TomlValue:
 69    str | int | float | bool | list["TomlValue"] | dict[str, "TomlValue"]
 70
 71
 72class TomlDict:
 73    dict[str, TomlValue]
 74
 75
 76def read_configs_from_toml(
 77    toml_path: str,
 78) -> TomlDict:
 79
 80    configs_dict = None
 81    with open(toml_path, mode="rb") as f:
 82        configs_dict = tomllib.load(f)
 83
 84    return configs_dict
 85
 86
 87def apply_toml_defaults(
 88    configs: TomlDict | str | os.PathLike,
 89    parser: argparse.ArgumentParser,
 90) -> None:
 91
 92    if isinstance(configs, dict):
 93        for k, v in configs.items():
 94            if isinstance(v, dict):  # in case v is (nested) dict
 95                apply_toml_defaults(v, parser)
 96            else:
 97                parser.add_argument(f"--{k}", default=v, type=type(v))
 98    elif isinstance(configs, str) or isinstance(configs, os.PathLike):
 99        configs_dict = read_configs_from_toml(configs)
100
101        apply_toml_defaults(configs_dict, parser)
102    else:
103        sys.exit("")
Listing 7.27 /opt/pfn/pfcomp/codegen/MLSDK/examples/detr_inference/requirements.txt
1Cython==3.1.3
2matplotlib==3.10.5
3pycocotools==2.0.10
4scipy==1.16.1
5git+https://github.com/cocodataset/panopticapi.git
6git+https://github.com/facebookresearch/detectron2.git
Listing 7.28 /opt/pfn/pfcomp/codegen/MLSDK/examples/detr_inference/configs.toml
 1title = "detr_inference"
 2
 3[model.backbone]
 4backbone = "resnet50"
 5dilation = false
 6
 7[model.transformer]
 8enc_layers      = 6
 9dec_layers      = 6
10dim_feedforward = 2048
11hidden_dim      = 256
12dropout         = 0.1
13nheads          = 8
14num_queries     = 100
15pre_norm        = false
16
17[model.loss.matcher]
18set_cost_class = 1
19set_cost_bbox  = 5
20set_cost_giou  = 2
21
22[model.loss.loss_coefficients]
23eos_coef       = 0.1
24mask_loss_coef = 1
25dice_loss_coef = 1
26bbox_loss_coef = 5
27giou_loss_coef = 2
28
29
30[data]
31max_num_segms      = 90                # max num of segments in the ground truth labels
32num_workers        = 2
33
34
35[training]
36lr             = 1e-4
37lr_backbone    = 1e-5
38weight_decay   = 1e-4
39epochs         = 300
40lr_drop        = 200   # epoch at which lr drops
41clip_max_norm  = 0.1   # gradient clipping max norm
42frozen_weights = ""    # path to the pretrained model. if set, only the mask head will be trained
43resume         = ""    # checkpoint to resume from
44masks          = true  # create DETRSegm obj
45aux_loss       = false
46
47
48[evaluation]
49num_split          = 100
50position_embedding = "sine"
51dataset_file       = "coco_panoptic"
52
53
54[misc]
55batch_size = 1