8.1.2. Example: MNIST on MN-Core 2
MNIST データセットを対象に、MN-Core 2 上で学習と推論を行うサンプルプログラムです。
学習結果は --outdir で指定したディレクトリ内の checkpoint.pt ファイルに保存されます (デフォルトでは /tmp/mlsdk_mnist/checkpoint.pt)。コンパイル結果 (train_step / eval_step) も同じディレクトリ以下に保存されます。
実行方法
$ cd /opt/pfn/pfcomp/codegen/MLSDK/examples/
$ ./exec_with_env.sh python3 mnist.py
想定出力
学習中のログ
epoch 0, iter    0, loss 2.3125
epoch 0, iter  100, loss 0.6226431969368814
...
epoch 9, iter  900, loss 0.10909322893182918
epoch 9, loss 0.11064393848594248
推論結果
Correct: 9609 / 10000. Accuracy: 0.9609
関連リンク
サンプルプログラム
1import argparse
2import random
3from pathlib import Path
4from typing import Mapping, Optional
5
6import numpy as np
7import torch
8from mlsdk import (
9 Context,
10 MNCoreSGD,
11 MNDevice,
12 set_buffer_name_in_optimizer,
13 set_tensor_name_in_module,
14 storage,
15)
16from mnist_common import MNCoreClassifier, mnist_loaders
17
# Seed the NumPy, Python `random`, and PyTorch random number generators
# with 0 at import time. The three calls are independent of one another.
np.random.seed(0)
random.seed(0)
torch.manual_seed(0)
21
22
def main(outdir: str, option_json_path: Optional[Path], device_str: str) -> None:
    """Train an MNIST classifier on an MN-Core device, save it, then evaluate it.

    The training step and the evaluation step are each compiled once with
    ``context.compile`` and then invoked for every batch of the corresponding
    data loader.

    Args:
        outdir: Output directory (resolved via ``storage.path``). The compiled
            ``train_step`` and ``eval_step`` artifacts and ``checkpoint.pt``
            (model and optimizer state dicts) are written under it.
        option_json_path: Optional path to a JSON file of compilation options.
            When given, it is forwarded to ``context.compile`` as
            ``options["option_json"]``.
        device_str: Device name passed to ``MNDevice`` (the command-line
            default is ``"mncore2:auto"``).

    Raises:
        AssertionError: If the final evaluation accuracy is not above
            ``min_accuracy`` (0.94).
    """
    batch_size = 64
    eval_batch_size = 125
    num_epochs = 10
    # Sanity threshold: this example is expected to exceed this accuracy.
    min_accuracy = 0.94

    device = MNDevice(device_str)
    context = Context(device)
    Context.switch_context(context)

    train_loader, eval_loader = mnist_loaders(batch_size, eval_batch_size)

    model_with_loss_fn = MNCoreClassifier()
    model_with_loss_fn.train()
    # Name the module's tensors and register every parameter and buffer with
    # the context before anything is compiled.
    set_tensor_name_in_module(model_with_loss_fn, "model_with_loss_fn")
    for p in model_with_loss_fn.parameters():
        context.register_param(p)
    for b in model_with_loss_fn.buffers():
        context.register_buffer(b)

    # NOTE(review): positional args presumably are lr=0.1, momentum=0.9,
    # weight_decay=0.0 -- confirm against MNCoreSGD's signature.
    optimizer = MNCoreSGD(model_with_loss_fn.parameters(), 0.1, 0.9, 0.0)
    set_buffer_name_in_optimizer(optimizer, "optimizer")
    context.register_optimizer_buffers(optimizer)

    def train_step(inp: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
        # One SGD update on a batch dict with keys "x" (inputs), "t" (targets).
        x = inp["x"]
        t = inp["t"]
        optimizer.zero_grad()
        output = model_with_loss_fn(x, t)
        loss = output["loss"]
        loss.backward()
        optimizer.step()
        return {"loss": loss}

    compile_options = {}
    if option_json_path is not None:
        compile_options["option_json"] = str(option_json_path)

    # A single sample batch is passed to `context.compile` alongside the step.
    sample = next(iter(train_loader))
    compiled_train_step = context.compile(
        train_step,
        sample,
        storage.path(outdir) / "train_step",
        options=compile_options,
    )

    for epoch in range(num_epochs):
        # Incremental running mean of the per-iteration loss in this epoch.
        loss = 0.0
        for i, sample in enumerate(train_loader):
            curr_loss = compiled_train_step(sample)["loss"].item()
            loss += (curr_loss - loss) / (i + 1)
            if i % 100 == 0:
                print(f"epoch {epoch}, iter {i:4}, loss {loss}")
        print(f"epoch {epoch}, loss {loss}")

    # Presumably waits for outstanding device work so the state dicts saved
    # below reflect all training steps -- confirm against Context docs.
    context.synchronize()

    torch.save(
        {
            "model_state_dict": model_with_loss_fn.state_dict(),
            "optim_state_dict": optimizer.state_dict(),
        },
        storage.path(outdir) / "checkpoint.pt",
    )

    model_with_loss_fn.eval()

    def eval_step(inp: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
        # Count predictions (argmax over dim 1 of output["y"]) equal to "t".
        x = inp["x"]
        t = inp["t"]
        output = model_with_loss_fn(x, t)
        y = output["y"]
        _, predicted = torch.max(y, 1)
        correct = (predicted == t).sum()
        return {"correct": correct}

    sample = next(iter(eval_loader))
    compiled_eval_step = context.compile(
        eval_step,
        sample,
        storage.path(outdir) / "eval_step",
        options=compile_options,
    )

    num_eval_samples = len(eval_loader.dataset)
    correct = 0
    for sample in eval_loader:
        correct += compiled_eval_step(sample)["correct"].item()
    accuracy = correct / num_eval_samples
    print(
        f"Correct: {correct} / {num_eval_samples}. "
        f"Accuracy: {accuracy}"
    )
    # Explicit check instead of a bare `assert`, so the sanity check still
    # runs under `python -O` (which strips assert statements).
    if not min_accuracy < accuracy:
        raise AssertionError(
            f"Accuracy {accuracy} is not above the expected minimum {min_accuracy}"
        )
112
113
if __name__ == "__main__":
    # Command-line entry point: read the options, then train and evaluate.
    cli = argparse.ArgumentParser(
        description="A standalone MNIST training and inference script."
    )
    cli.add_argument(
        "--outdir",
        default="/tmp/mlsdk_mnist",
        type=str,
        help="Path to store compiled and trained results",
    )
    cli.add_argument(
        "--option_json",
        default=None,
        type=Path,
        help="""
        Path to a JSON file specifying compilation configs,
        e.g. /opt/pfn/pfcomp/codegen/preset_options/O1.json
        """,
    )
    cli.add_argument(
        "--device",
        default="mncore2:auto",
        type=str,
        help="device_name for MNDevice",
    )
    parsed = cli.parse_args()
    main(
        outdir=parsed.outdir,
        option_json_path=parsed.option_json,
        device_str=parsed.device,
    )