8.1.5. Example: Using non-dict inputs and outputs with compile_automap
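compile_automap can accept inputs and return outputs of types other than dictionaries. This sample program first compiles an add function using example inputs created with torch.randn(3, 4). It then calls the compiled function with torch.ones(3, 4) passed through several container types: a tuple, a NamedTuple, a list, and a dataclass. The four tensors are added on the MN-Core 2 processor, and the sum is returned directly as a single torch.Tensor. compile_automap supports these types, including nested combinations of them, as long as their elements are torch.Tensor objects. In the sample program, the NamedTuple and dataclass types are decorated with register_pytree_namedtuple and register_pytree_dataclass.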
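The register_pytree_* decorators in the sample program suggest a pytree-style approach, in which a nested container is flattened into a list of tensor leaves and later rebuilt with the same structure. The following is a minimal pure-Python sketch of that idea under stated assumptions, not MLSDK's implementation: _REGISTRY, register_namedtuple, register_dataclass, and flatten are hypothetical names, and plain floats stand in for torch.Tensor leaves. It also shows why a NamedTuple or dataclass has to be registered: an unregistered instance is treated as a single opaque leaf.

```python
from dataclasses import dataclass, fields
from typing import Any, Callable, Dict, List, NamedTuple, Tuple

# Hypothetical registry: container type -> (get_children, build_from_children).
_REGISTRY: Dict[type, Tuple[Callable[[Any], List[Any]], Callable[[List[Any]], Any]]] = {}


def register_namedtuple(cls):
    """Hypothetical stand-in for a NamedTuple registration decorator."""
    _REGISTRY[cls] = (lambda obj: list(obj), lambda children: cls(*children))
    return cls


def register_dataclass(cls):
    """Hypothetical stand-in for a dataclass registration decorator."""
    names = [f.name for f in fields(cls)]
    _REGISTRY[cls] = (
        lambda obj: [getattr(obj, n) for n in names],
        lambda children: cls(**dict(zip(names, children))),
    )
    return cls


def flatten(tree: Any) -> Tuple[List[Any], Callable[[List[Any]], Any]]:
    """Return the leaves of a nested structure and a function that rebuilds it."""
    if type(tree) in _REGISTRY:
        get_children, build = _REGISTRY[type(tree)]
        children = get_children(tree)
    elif type(tree) in (tuple, list):
        children, build = list(tree), type(tree)
    elif type(tree) is dict:
        keys = list(tree)
        children = [tree[k] for k in keys]
        build = lambda values: dict(zip(keys, values))
    else:
        # Anything unregistered is a leaf (a torch.Tensor in the real use case).
        return [tree], lambda leaves: leaves[0]

    # Flatten each child, remembering how many leaves it contributed.
    leaves: List[Any] = []
    rebuilders: List[Tuple[Callable[[List[Any]], Any], int]] = []
    for child in children:
        child_leaves, rebuild_child = flatten(child)
        leaves.extend(child_leaves)
        rebuilders.append((rebuild_child, len(child_leaves)))

    def unflatten(new_leaves: List[Any]) -> Any:
        # Hand each child its slice of the flat leaf list, then rebuild this node.
        rebuilt, pos = [], 0
        for rebuild_child, count in rebuilders:
            rebuilt.append(rebuild_child(new_leaves[pos:pos + count]))
            pos += count
        return build(rebuilt)

    return leaves, unflatten


@register_namedtuple
class NT(NamedTuple):
    x: Any


@register_dataclass
@dataclass
class DC:
    x: Any


@dataclass
class Unregistered:
    x: Any


args = ((1.0,), NT(x=2.0), [3.0], DC(x=4.0))
leaves, unflatten = flatten(args)
print(leaves)  # [1.0, 2.0, 3.0, 4.0]
print(unflatten([v * 10 for v in leaves]))
print(flatten(Unregistered(x=5.0))[0])  # [Unregistered(x=5.0)]
```

Recording the leaf count of each subtree is what allows unflatten to rebuild the original nesting from a flat list of new leaves, such as results computed on a device.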
Execution Method
$ cd /opt/pfn/pfcomp/codegen/MLSDK/examples/
$ ./exec_with_env.sh python3 add_automap.py
Expected Output
- The codegen_dir containing the compiled results (/tmp/add_many_tensors)
- The computed results:
result_on_cpu=tensor([[4., 4., 4., 4.],
        [4., 4., 4., 4.],
        [4., 4., 4., 4.]])
Sample Program
from dataclasses import dataclass
from typing import List, NamedTuple, Tuple

import torch
from mlsdk import (
    Context,
    MNDevice,
    register_pytree_dataclass,
    register_pytree_namedtuple,
    storage,
)


# Register the NamedTuple type so that compile_automap can handle it as an input.
@register_pytree_namedtuple
class NT(NamedTuple):
    x: torch.Tensor


# Register the dataclass type in the same way.
@register_pytree_dataclass
@dataclass
class DC:
    x: torch.Tensor


def run_add():
    # Select an MN-Core 2 device and make its context the current one.
    device = MNDevice("mncore2:auto")
    context = Context(device)
    Context.switch_context(context)

    # Inputs arrive as a tuple, a NamedTuple, a list, and a dataclass;
    # the output is a single torch.Tensor.
    def add(
        arg1: Tuple[torch.Tensor],
        arg2: NT,
        *,
        kwarg_ls: List[torch.Tensor],
        kwarg_dc: DC,
    ) -> torch.Tensor:
        return arg1[0] + arg2.x + kwarg_ls[0] + kwarg_dc.x

    # Example inputs for compilation: random values, same structure and shapes
    # as the inputs used in the call below.
    arg1 = (torch.randn(3, 4),)
    arg2 = NT(x=torch.randn(3, 4))
    kwarg_ls = [torch.randn(3, 4)]
    kwarg_dc = DC(x=torch.randn(3, 4))

    # Pass the function, the positional example inputs, the keyword example
    # inputs, and the directory that receives the compiled results.
    compiled_add = context.compile_automap(
        add,
        (arg1, arg2),
        {"kwarg_ls": kwarg_ls, "kwarg_dc": kwarg_dc},
        storage.path("/tmp/add_many_tensors"),
        options={"float_dtype": "float"},
    )
    # Call the compiled function with all-ones inputs of the same structure.
    result = compiled_add(
        (torch.ones(3, 4),),
        NT(x=torch.ones(3, 4)),
        kwarg_ls=[torch.ones(3, 4)],
        kwarg_dc=DC(x=torch.ones(3, 4)),
    )
    # Copy the result to the CPU and check that every element is 1 + 1 + 1 + 1 = 4.
    result_on_cpu = result.cpu()
    print(f"{result_on_cpu=}")
    assert torch.allclose(result_on_cpu, torch.ones(3, 4) * 4)


if __name__ == "__main__":
    run_add()
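To see why the expected output is a 3x4 tensor of 4s without MN-Core 2 hardware or torch, the add function from the sample program can be replayed on plain Python data. This is a sketch under stated assumptions and is not part of MLSDK: nested lists replace torch.Tensor, and ones and madd are hypothetical helpers that mimic torch.ones and element-wise tensor addition.

```python
from dataclasses import dataclass
from typing import List, NamedTuple, Tuple

# Assumption: a nested list stands in for a 3x4 torch.Tensor.
Matrix = List[List[float]]


def ones(rows: int, cols: int) -> Matrix:
    """Hypothetical helper mirroring torch.ones(rows, cols)."""
    return [[1.0] * cols for _ in range(rows)]


def madd(a: Matrix, b: Matrix) -> Matrix:
    """Hypothetical helper: element-wise sum of two equally shaped matrices."""
    return [[x + y for x, y in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]


class NT(NamedTuple):
    x: Matrix


@dataclass
class DC:
    x: Matrix


def add(arg1: Tuple[Matrix], arg2: NT, *, kwarg_ls: List[Matrix], kwarg_dc: DC) -> Matrix:
    # Same access pattern as the sample's add: tuple index, NamedTuple field,
    # list index, dataclass field.
    total = madd(arg1[0], arg2.x)
    total = madd(total, kwarg_ls[0])
    return madd(total, kwarg_dc.x)


result = add(
    (ones(3, 4),),
    NT(x=ones(3, 4)),
    kwarg_ls=[ones(3, 4)],
    kwarg_dc=DC(x=ones(3, 4)),
)
for row in result:
    print(row)  # prints [4.0, 4.0, 4.0, 4.0] three times
```

Each element is the sum of four ones, which matches the result_on_cpu values in the expected output.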