8.1.2. Example: MNIST on MN-Core 2

A sample program demonstrating training and inference operations on the MNIST dataset using MN-Core 2.

Training results are saved to a checkpoint.pt file inside the directory specified by the --outdir flag (the default directory is /tmp/mlsdk_mnist, so the file is written to /tmp/mlsdk_mnist/checkpoint.pt).

Execution Method

$ cd /opt/pfn/pfcomp/codegen/MLSDK/examples/
$ ./exec_with_env.sh python3 mnist.py

Expected Output

  • Training log output

epoch 0, iter    0, loss 2.3125
epoch 0, iter  100, loss 0.6226431969368814
...
epoch 9, iter  900, loss 0.10909322893182918
epoch 9, loss 0.11064393848594248
  • Inference results

Correct: 9609 / 10000. Accuracy: 0.9609

Related Links

Sample Program

Listing 8.2 /opt/pfn/pfcomp/codegen/MLSDK/examples/mnist.py
  1import argparse
  2import random
  3from pathlib import Path
  4from typing import Mapping, Optional
  5
  6import numpy as np
  7import torch
  8from mlsdk import (
  9    Context,
 10    MNCoreSGD,
 11    MNDevice,
 12    set_buffer_name_in_optimizer,
 13    set_tensor_name_in_module,
 14    storage,
 15)
 16from mnist_common import MNCoreClassifier, mnist_loaders
 17
 18torch.manual_seed(0)
 19random.seed(0)
 20np.random.seed(0)
 21
 22
 23def main(outdir: str, option_json_path: Optional[Path], device_str: str) -> None:
 24    batch_size = 64
 25    eval_batch_size = 125
 26
 27    device = MNDevice(device_str)
 28    context = Context(device)
 29    Context.switch_context(context)
 30
 31    train_loader, eval_loader = mnist_loaders(batch_size, eval_batch_size)
 32
 33    model_with_loss_fn = MNCoreClassifier()
 34    model_with_loss_fn.train()
 35    set_tensor_name_in_module(model_with_loss_fn, "model_with_loss_fn")
 36    for p in model_with_loss_fn.parameters():
 37        context.register_param(p)
 38    for b in model_with_loss_fn.buffers():
 39        context.register_buffer(b)
 40
 41    optimizer = MNCoreSGD(model_with_loss_fn.parameters(), 0.1, 0.9, 0.0)
 42    set_buffer_name_in_optimizer(optimizer, "optimizer")
 43    context.register_optimizer_buffers(optimizer)
 44
 45    def train_step(inp: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
 46        x = inp["x"]
 47        t = inp["t"]
 48        optimizer.zero_grad()
 49        output = model_with_loss_fn(x, t)
 50        loss = output["loss"]
 51        loss.backward()
 52        optimizer.step()
 53        return {"loss": loss}
 54
 55    compile_options = {}
 56    if option_json_path is not None:
 57        compile_options["option_json"] = str(option_json_path)
 58
 59    sample = next(iter(train_loader))
 60    compiled_train_step = context.compile(
 61        train_step,
 62        sample,
 63        storage.path(outdir) / "train_step",
 64        options=compile_options,
 65    )
 66
 67    for epoch in range(10):
 68        loss = 0.0
 69        for i, sample in enumerate(train_loader):
 70            curr_loss = compiled_train_step(sample)["loss"].item()
 71            loss += (curr_loss - loss) / (i + 1)
 72            if i % 100 == 0:
 73                print(f"epoch {epoch}, iter {i:4}, loss {loss}")
 74        print(f"epoch {epoch}, loss {loss}")
 75
 76    context.synchronize()
 77
 78    torch.save(
 79        {
 80            "model_state_dict": model_with_loss_fn.state_dict(),
 81            "optim_state_dict": optimizer.state_dict(),
 82        },
 83        storage.path(outdir) / "checkpoint.pt",
 84    )
 85
 86    model_with_loss_fn.eval()
 87
 88    def eval_step(inp: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
 89        x = inp["x"]
 90        t = inp["t"]
 91        output = model_with_loss_fn(x, t)
 92        y = output["y"]
 93        _, predicted = torch.max(y, 1)
 94        correct = (predicted == t).sum()
 95        return {"correct": correct}
 96
 97    sample = next(iter(eval_loader))
 98    compiled_eval_step = context.compile(
 99        eval_step,
100        sample,
101        storage.path(outdir) / "eval_step",
102        options=compile_options,
103    )
104    correct = 0
105    for sample in eval_loader:
106        correct += compiled_eval_step(sample)["correct"].item()
107    print(
108        f"Correct: {correct} / {len(eval_loader.dataset)}. "
109        f"Accuracy: {correct / len(eval_loader.dataset)}"
110    )
111    assert 0.94 < correct / len(eval_loader.dataset)
112
113
114if __name__ == "__main__":
115    parser = argparse.ArgumentParser(
116        description="A standalone MNIST training and inference script."
117    )
118    parser.add_argument(
119        "--outdir",
120        type=str,
121        default="/tmp/mlsdk_mnist",
122        help="Path to store compiled and trained results",
123    )
124    parser.add_argument(
125        "--option_json",
126        type=Path,
127        default=None,
128        help="""
129        Path to a JSON file specifying compilation configs,
130        e.g. /opt/pfn/pfcomp/codegen/preset_options/O1.json
131        """,
132    )
133    parser.add_argument(
134        "--device", type=str, default="mncore2:auto", help="device_name for MNDevice"
135    )
136    args = parser.parse_args()
137    main(args.outdir, args.option_json, args.device)