8.1.2. Example: MNIST on MN-Core 2
A sample program demonstrating training and inference operations on the MNIST dataset using MN-Core 2.
Training results are saved as checkpoint.pt in the directory specified by the --outdir flag; with the default --outdir of /tmp/mlsdk_mnist, the checkpoint is written to /tmp/mlsdk_mnist/checkpoint.pt.
Execution Method
$ cd /opt/pfn/pfcomp/codegen/MLSDK/examples/
$ ./exec_with_env.sh python3 mnist.py
Expected Output
Training log output
epoch 0, iter 0, loss 2.3125
epoch 0, iter 100, loss 0.6226431969368814
...
epoch 9, iter 900, loss 0.10909322893182918
epoch 9, loss 0.11064393848594248
Inference results
Correct: 9609 / 10000. Accuracy: 0.9609
Related Links
Sample Program
1import argparse
2import random
3from pathlib import Path
4from typing import Mapping, Optional
5
6import numpy as np
7import torch
8from mlsdk import (
9 Context,
10 MNCoreSGD,
11 MNDevice,
12 set_buffer_name_in_optimizer,
13 set_tensor_name_in_module,
14 storage,
15)
16from mnist_common import MNCoreClassifier, mnist_loaders
17
# Fix all RNG seeds (torch, stdlib random, NumPy) so that data shuffling,
# weight initialization, and any augmentation are reproducible across runs.
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)
21
22
def main(outdir: str, option_json_path: Optional[Path], device_str: str) -> None:
    """Train an MNIST classifier on MN-Core 2, save a checkpoint, then evaluate.

    Args:
        outdir: Directory receiving compiled artifacts and checkpoint.pt.
        option_json_path: Optional JSON file with extra compilation options.
        device_str: Device name handed to MNDevice (e.g. "mncore2:auto").
    """
    train_batch = 64
    eval_batch = 125

    # Make the MN-Core device the active compilation/execution context.
    device = MNDevice(device_str)
    context = Context(device)
    Context.switch_context(context)

    train_loader, eval_loader = mnist_loaders(train_batch, eval_batch)

    # The model produces both predictions and the loss; register its
    # parameters and buffers with the context so they live on the device.
    net = MNCoreClassifier()
    net.train()
    set_tensor_name_in_module(net, "model_with_loss_fn")
    for param in net.parameters():
        context.register_param(param)
    for buf in net.buffers():
        context.register_buffer(buf)

    # Positional args 0.1, 0.9, 0.0 — presumably lr, momentum, weight_decay
    # following the torch SGD convention; confirm against MNCoreSGD docs.
    sgd = MNCoreSGD(net.parameters(), 0.1, 0.9, 0.0)
    set_buffer_name_in_optimizer(sgd, "optimizer")
    context.register_optimizer_buffers(sgd)

    def train_step(inp: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
        # One optimization step: zero grads, forward, backward, update.
        sgd.zero_grad()
        result = net(inp["x"], inp["t"])
        step_loss = result["loss"]
        step_loss.backward()
        sgd.step()
        return {"loss": step_loss}

    compile_options = (
        {} if option_json_path is None else {"option_json": str(option_json_path)}
    )

    # Compile the training step once, using one sample batch to fix shapes.
    sample = next(iter(train_loader))
    compiled_train_step = context.compile(
        train_step,
        sample,
        storage.path(outdir) / "train_step",
        options=compile_options,
    )

    for epoch in range(10):
        running = 0.0
        for step, batch in enumerate(train_loader):
            step_loss = compiled_train_step(batch)["loss"].item()
            # Incremental running mean of the loss over the epoch.
            running += (step_loss - running) / (step + 1)
            if step % 100 == 0:
                print(f"epoch {epoch}, iter {step:4}, loss {running}")
        print(f"epoch {epoch}, loss {running}")

    # Drain all queued device work before reading state back for the checkpoint.
    context.synchronize()

    torch.save(
        {
            "model_state_dict": net.state_dict(),
            "optim_state_dict": sgd.state_dict(),
        },
        storage.path(outdir) / "checkpoint.pt",
    )

    net.eval()

    def eval_step(inp: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
        # Count correct predictions in one evaluation batch.
        result = net(inp["x"], inp["t"])
        predicted = torch.max(result["y"], 1)[1]
        return {"correct": (predicted == inp["t"]).sum()}

    sample = next(iter(eval_loader))
    compiled_eval_step = context.compile(
        eval_step,
        sample,
        storage.path(outdir) / "eval_step",
        options=compile_options,
    )

    correct = 0
    for batch in eval_loader:
        correct += compiled_eval_step(batch)["correct"].item()
    total = len(eval_loader.dataset)
    print(f"Correct: {correct} / {total}. Accuracy: {correct / total}")
    assert correct / total > 0.94
113
if __name__ == "__main__":
    # Command-line entry point: parse the flags, then run training + inference.
    arg_parser = argparse.ArgumentParser(
        description="A standalone MNIST training and inference script."
    )
    arg_parser.add_argument(
        "--outdir",
        type=str,
        default="/tmp/mlsdk_mnist",
        help="Path to store compiled and trained results",
    )
    arg_parser.add_argument(
        "--option_json",
        type=Path,
        default=None,
        help="""
        Path to a JSON file specifying compilation configs,
        e.g. /opt/pfn/pfcomp/codegen/preset_options/O1.json
        """,
    )
    arg_parser.add_argument(
        "--device", type=str, default="mncore2:auto", help="device_name for MNDevice"
    )
    parsed = arg_parser.parse_args()
    main(parsed.outdir, parsed.option_json, parsed.device)