8.1.5. Example: Using non-dict inputs and outputs with compile_automap

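compile_automap can use input and output types other than dictionaries. In this sample program, the compiled function receives torch.ones(3, 4) wrapped in four different containers: a tuple and a NamedTuple passed as positional arguments, and a list and a dataclass passed as keyword-only arguments. It adds the four tensors on the MN-Core 2 processor and returns the sum directly as a torch.Tensor. compile_automap supports these types, including nested structures, as long as their elements are torch.Tensor objects. The custom NamedTuple and dataclass types (NT and DC) are decorated with register_pytree_namedtuple and register_pytree_dataclass, respectively. The randomly initialized tensors passed to compile_automap are sample inputs with the same structure and shapes as the tensors used in the actual call.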
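This section does not describe how compile_automap traverses these containers internally. A common approach in PyTorch and JAX tooling is the "pytree" pattern: a nested structure is flattened into an ordered list of leaf tensors plus a recipe for rebuilding it. The sketch below is a generic, hypothetical illustration of that pattern, not the mlsdk implementation. The helper name `flatten` is made up, plain floats stand in for torch.Tensor leaves so it runs without torch, and for brevity it recognizes namedtuples and dataclasses automatically, whereas the sample program registers them explicitly.

```python
import dataclasses
from typing import Any, Callable, List, NamedTuple, Tuple


def flatten(tree: Any) -> Tuple[List[Any], Callable[[List[Any]], Any]]:
    """Hypothetical helper: return the leaves of `tree` in order, plus a
    function that rebuilds the same structure from a new list of leaves."""
    leaves: List[Any] = []

    def walk(node: Any) -> Callable[[Any], Any]:
        # Check NamedTuples before plain tuples: a NamedTuple is a tuple subclass.
        if isinstance(node, tuple) and hasattr(node, "_fields"):
            builders = [walk(child) for child in node]
            cls = type(node)
            return lambda it: cls(*[b(it) for b in builders])
        if isinstance(node, (tuple, list)):
            builders = [walk(child) for child in node]
            cls = type(node)
            return lambda it: cls([b(it) for b in builders])
        if isinstance(node, dict):
            keys = list(node)
            builders = [walk(node[k]) for k in keys]
            return lambda it: {k: b(it) for k, b in zip(keys, builders)}
        if dataclasses.is_dataclass(node) and not isinstance(node, type):
            # Assumes every dataclass field is an __init__ argument.
            names = [f.name for f in dataclasses.fields(node)]
            builders = [walk(getattr(node, n)) for n in names]
            cls = type(node)
            return lambda it: cls(**{n: b(it) for n, b in zip(names, builders)})
        # Anything else is a leaf (a torch.Tensor in the sample program).
        leaves.append(node)
        return lambda it: next(it)

    build = walk(tree)
    return leaves, lambda new_leaves: build(iter(new_leaves))


class NT(NamedTuple):
    x: Any


@dataclasses.dataclass
class DC:
    x: Any


# Same nesting as the (args, kwargs) pair passed to compile_automap;
# floats stand in for tensors.
args = (((1.0,), NT(x=2.0)), {"kwarg_ls": [3.0], "kwarg_dc": DC(x=4.0)})
leaves, unflatten = flatten(args)
print(leaves)  # [1.0, 2.0, 3.0, 4.0]
print(unflatten([10.0, 20.0, 30.0, 40.0]))
```

Because leaves are collected and consumed in the same traversal order, any list of replacement values with the same length can be poured back into the original shape, which is what allows new tensors of the same structure to be fed to an already compiled function.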

Execution Method

$ cd /opt/pfn/pfcomp/codegen/MLSDK/examples/
$ ./exec_with_env.sh python3 add_automap.py

Expected Output

  • The codegen_dir containing the compiled results (/tmp/add_many_tensors)

  • The computed result (every element is 1 + 1 + 1 + 1 = 4, because all four inputs are torch.ones(3, 4))

result_on_cpu=tensor([[4., 4., 4., 4.],
        [4., 4., 4., 4.],
        [4., 4., 4., 4.]])


Sample Program

Listing 8.5 /opt/pfn/pfcomp/codegen/MLSDK/examples/add_automap.py
from dataclasses import dataclass
from typing import List, NamedTuple, Tuple

import torch
from mlsdk import (
    Context,
    MNDevice,
    register_pytree_dataclass,
    register_pytree_namedtuple,
    storage,
)


# Register the custom NamedTuple type as a pytree node.
@register_pytree_namedtuple
class NT(NamedTuple):
    x: torch.Tensor


# Register the custom dataclass type as a pytree node.
@register_pytree_dataclass
@dataclass
class DC:
    x: torch.Tensor


def run_add():
    # Select an MN-Core 2 device and make a Context on it the current context.
    device = MNDevice("mncore2:auto")
    context = Context(device)
    Context.switch_context(context)

    # Positional args: a tuple and a NamedTuple.
    # Keyword-only args: a list and a dataclass.
    def add(
        arg1: Tuple[torch.Tensor],
        arg2: NT,
        *,
        kwarg_ls: List[torch.Tensor],
        kwarg_dc: DC,
    ) -> torch.Tensor:
        return arg1[0] + arg2.x + kwarg_ls[0] + kwarg_dc.x

    # Sample inputs for compilation, with the same structure and shapes
    # as the inputs used in the actual call below.
    arg1 = (torch.randn(3, 4),)
    arg2 = NT(x=torch.randn(3, 4))
    kwarg_ls = [torch.randn(3, 4)]
    kwarg_dc = DC(x=torch.randn(3, 4))

    # Compile `add`: positional args are passed as a tuple, keyword args as a dict,
    # and the compiled results are stored under /tmp/add_many_tensors.
    compiled_add = context.compile_automap(
        add,
        (arg1, arg2),
        {"kwarg_ls": kwarg_ls, "kwarg_dc": kwarg_dc},
        storage.path("/tmp/add_many_tensors"),
        options={"float_dtype": "float"},
    )
    # Call the compiled function with torch.ones(3, 4) in the same structure.
    result = compiled_add(
        (torch.ones(3, 4),),
        NT(x=torch.ones(3, 4)),
        kwarg_ls=[torch.ones(3, 4)],
        kwarg_dc=DC(x=torch.ones(3, 4)),
    )
    # Move the result to the CPU and check that every element is 1 + 1 + 1 + 1 = 4.
    result_on_cpu = result.cpu()
    print(f"{result_on_cpu=}")
    assert torch.allclose(result_on_cpu, torch.ones(3, 4) * 4)


if __name__ == "__main__":
    run_add()
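The sample program decorates NT and DC with register_pytree_namedtuple and register_pytree_dataclass before using them as inputs. This section does not document what those decorators do internally; a plausible reading of their names is that they add the class to a registry of container types whose children can be extracted and reassembled. The sketch below shows such a registry under that assumption. The decorators register_namedtuple and register_dataclass and the helper children_of are hypothetical names, not part of mlsdk, and floats again stand in for tensors.

```python
import dataclasses
from typing import Any, Callable, Dict, List, NamedTuple, Tuple

# Hypothetical registry: container class -> (get_children, rebuild_from_children).
_REGISTRY: Dict[type, Tuple[Callable[[Any], List[Any]], Callable[[List[Any]], Any]]] = {}


def register_namedtuple(cls):
    """Hypothetical stand-in for a NamedTuple registration decorator."""
    _REGISTRY[cls] = (lambda obj: list(obj), lambda children: cls(*children))
    return cls


def register_dataclass(cls):
    """Hypothetical stand-in for a dataclass registration decorator.
    Must be applied on top of @dataclass so that dataclasses.fields works."""
    names = [f.name for f in dataclasses.fields(cls)]
    _REGISTRY[cls] = (
        lambda obj: [getattr(obj, n) for n in names],
        lambda children: cls(**dict(zip(names, children))),
    )
    return cls


def children_of(obj: Any):
    """Return (children, rebuild) for a registered container, or None for a leaf."""
    entry = _REGISTRY.get(type(obj))
    if entry is None:
        return None  # unregistered types are treated as opaque leaves
    get_children, rebuild = entry
    return get_children(obj), rebuild


@register_namedtuple
class NT(NamedTuple):
    x: Any


@register_dataclass
@dataclasses.dataclass
class DC:
    x: Any


@dataclasses.dataclass
class Unregistered:
    x: Any


children, rebuild = children_of(DC(x=1.0))
print(children)  # [1.0]
print(rebuild([2.0]))  # DC(x=2.0)
print(children_of(Unregistered(x=1.0)))  # None
```

With explicit registration, an unregistered class is treated as a single opaque value rather than being traversed by accident; the decorator order in this sketch mirrors the sample program, where @dataclass is applied first.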