7.1.2. Example: MNIST on MN-Core 2
A sample program demonstrating training and inference operations on the MNIST dataset using MN-Core 2.
Training results are saved to a checkpoint.pt file inside the directory specified by the --outdir flag; with the default directory /tmp/mlsdk_mnist, the checkpoint is written to /tmp/mlsdk_mnist/checkpoint.pt.
Execution Method
$ cd /opt/pfn/pfcomp/codegen/MLSDK/examples/
$ ./exec_with_env.sh python3 mnist.py
Expected Output
Training log output
epoch 0, iter 0, loss 2.3125
epoch 0, iter 100, loss 0.6226431969368814
...
epoch 9, iter 900, loss 0.10909322893182918
epoch 9, loss 0.11064393848594248
Inference results
Correct: 9609 / 10000. Accuracy: 0.9609
Related Links
Sample Program
1import argparse
2import random
3from pathlib import Path
4from typing import Mapping, Optional
5
6import numpy as np
7import torch
8from mlsdk import (
9 Context,
10 MNCoreSGD,
11 MNDevice,
12 set_buffer_name_in_optimizer,
13 set_tensor_name_in_module,
14 storage,
15)
16from mnist_common import MNCoreClassifier, mnist_loaders
17
# Fix every RNG the script touches (torch, stdlib random, numpy) to seed 0
# so training and evaluation runs are reproducible.
for _seed in (torch.manual_seed, random.seed, np.random.seed):
    _seed(0)
21
22
def _compile_options_from(option_json_path: Optional[Path]) -> dict:
    """Build the options dict forwarded to ``Context.compile``."""
    options = {}
    if option_json_path is not None:
        options["option_json"] = str(option_json_path)
    return options


def _train(context, model_with_loss_fn, optimizer, train_loader, outdir, compile_options) -> None:
    """Compile the training step and run 10 epochs, printing a running-mean loss."""

    def train_step(inp: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
        x = inp["x"]
        t = inp["t"]
        optimizer.zero_grad()
        output = model_with_loss_fn(x, t)
        loss = output["loss"]
        loss.backward()
        optimizer.step()
        return {"loss": loss}

    # Trace/compile once on a representative batch, then reuse the compiled step.
    sample = next(iter(train_loader))
    compiled_train_step = context.compile(
        train_step,
        sample,
        storage.path(outdir) / "train_step",
        options=compile_options,
    )

    for epoch in range(10):
        loss = 0.0
        for i, sample in enumerate(train_loader):
            curr_loss = compiled_train_step(sample)["loss"].item()
            # Incremental running mean: mean_k = mean_{k-1} + (x_k - mean_{k-1}) / k
            loss += (curr_loss - loss) / (i + 1)
            if i % 100 == 0:
                print(f"epoch {epoch}, iter {i:4}, loss {loss}")
        print(f"epoch {epoch}, loss {loss}")


def _evaluate(context, model_with_loss_fn, eval_loader, outdir, compile_options) -> int:
    """Compile the eval step, print the accuracy report, return the correct count."""

    def eval_step(inp: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
        x = inp["x"]
        t = inp["t"]
        output = model_with_loss_fn(x, t)
        y = output["y"]
        _, predicted = torch.max(y, 1)
        correct = (predicted == t).sum()
        return {"correct": correct}

    sample = next(iter(eval_loader))
    compiled_eval_step = context.compile(
        eval_step,
        sample,
        storage.path(outdir) / "eval_step",
        options=compile_options,
    )
    correct = 0
    for sample in eval_loader:
        correct += compiled_eval_step(sample)["correct"].item()
    print(
        f"Correct: {correct} / {len(eval_loader.dataset)}. "
        f"Accuracy: {correct / len(eval_loader.dataset)}"
    )
    return correct


def main(outdir: str, option_json_path: Optional[Path], device_str: str) -> None:
    """Train an MNIST classifier on an MN-Core device, checkpoint it, then evaluate.

    Args:
        outdir: Directory for compiled step artifacts and ``checkpoint.pt``.
        option_json_path: Optional path to a compiler option JSON, forwarded to
            ``Context.compile``; ``None`` means no extra options.
        device_str: Device specification string, e.g. ``"mncore2:auto"``.

    Raises:
        RuntimeError: If evaluation accuracy does not exceed 0.94.
    """
    batch_size = 64
    eval_batch_size = 125

    device = MNDevice(device_str)
    context = Context(device)
    Context.switch_context(context)

    train_loader, eval_loader = mnist_loaders(batch_size, eval_batch_size)

    model_with_loss_fn = MNCoreClassifier()
    model_with_loss_fn.train()
    set_tensor_name_in_module(model_with_loss_fn, "model_with_loss_fn")
    for p in model_with_loss_fn.parameters():
        context.register_param(p)

    # Positional args 0.1, 0.9, 0.0 — presumably lr, momentum, weight_decay,
    # mirroring torch.optim.SGD; TODO confirm against the MNCoreSGD signature.
    optimizer = MNCoreSGD(model_with_loss_fn.parameters(), 0.1, 0.9, 0.0)
    set_buffer_name_in_optimizer(optimizer, "optimizer")
    context.register_optimizer_buffers(optimizer)

    compile_options = _compile_options_from(option_json_path)

    _train(context, model_with_loss_fn, optimizer, train_loader, outdir, compile_options)

    # Ensure all device-side work has finished before reading parameters back.
    context.synchronize()

    torch.save(
        {
            "model_state_dict": model_with_loss_fn.state_dict(),
            "optim_state_dict": optimizer.state_dict(),
        },
        storage.path(outdir) / "checkpoint.pt",
    )

    model_with_loss_fn.eval()
    correct = _evaluate(context, model_with_loss_fn, eval_loader, outdir, compile_options)

    accuracy = correct / len(eval_loader.dataset)
    # Explicit check instead of `assert`, which is silently stripped under
    # `python -O` and would let a regression pass unnoticed.
    if not 0.94 < accuracy:
        raise RuntimeError(f"accuracy {accuracy} is not above the 0.94 threshold")
110
111
if __name__ == "__main__":
    # CLI entry point: collect the output directory, optional compiler option
    # JSON, and target device, then run training + evaluation.
    arg_parser = argparse.ArgumentParser(
        description="A standalone MNIST training and inference script."
    )
    arg_parser.add_argument("--outdir", type=str, default="/tmp/mlsdk_mnist")
    arg_parser.add_argument("--option_json", type=Path, default=None)
    arg_parser.add_argument("--device", type=str, default="mncore2:auto")
    parsed = arg_parser.parse_args()
    main(parsed.outdir, parsed.option_json, parsed.device)