7.1.2. Example: MNIST on MN-Core 2

A sample program demonstrating training and inference operations on the MNIST dataset using MN-Core 2.

Training results are saved to checkpoint.pt in the directory specified by the --outdir flag (by default, /tmp/mlsdk_mnist/checkpoint.pt).

Execution Method

$ cd /opt/pfn/pfcomp/codegen/MLSDK/examples/
$ ./exec_with_env.sh python3 mnist.py

Expected Output

  • Training log output

epoch 0, iter    0, loss 2.3125
epoch 0, iter  100, loss 0.6226431969368814
...
epoch 9, iter  900, loss 0.10909322893182918
epoch 9, loss 0.11064393848594248
  • Inference results

Correct: 9609 / 10000. Accuracy: 0.9609

Related Links

Sample Program

Listing 7.2 /opt/pfn/pfcomp/codegen/MLSDK/examples/mnist.py
  1import argparse
  2import random
  3from pathlib import Path
  4from typing import Mapping, Optional
  5
  6import numpy as np
  7import torch
  8from mlsdk import (
  9    Context,
 10    MNCoreSGD,
 11    MNDevice,
 12    set_buffer_name_in_optimizer,
 13    set_tensor_name_in_module,
 14    storage,
 15)
 16from mnist_common import MNCoreClassifier, mnist_loaders
 17
# Seed every RNG source (PyTorch, Python, NumPy) so example runs are reproducible.
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)
 21
 22
def main(outdir: str, option_json_path: Optional[Path], device_str: str) -> None:
    """Train an MNIST classifier on MN-Core 2, save a checkpoint, then evaluate.

    Args:
        outdir: Directory where compiled step artifacts and checkpoint.pt are written.
        option_json_path: Optional path to a compiler option JSON file, forwarded
            to context.compile as the "option_json" option.
        device_str: Device specifier passed to MNDevice (e.g. "mncore2:auto").
    """
    batch_size = 64
    eval_batch_size = 125

    # Create a device-backed context and make it current; subsequent parameter /
    # optimizer-buffer registration and compile() calls all go through it.
    device = MNDevice(device_str)
    context = Context(device)
    Context.switch_context(context)

    train_loader, eval_loader = mnist_loaders(batch_size, eval_batch_size)

    model_with_loss_fn = MNCoreClassifier()
    model_with_loss_fn.train()
    # Assign stable names to the module's tensors so they are identifiable
    # by the SDK (e.g. in compiled artifacts).
    set_tensor_name_in_module(model_with_loss_fn, "model_with_loss_fn")
    for p in model_with_loss_fn.parameters():
        context.register_param(p)

    # NOTE(review): positional args presumably mean lr=0.1, momentum=0.9,
    # weight_decay=0.0 by analogy with torch SGD — confirm against MNCoreSGD API.
    optimizer = MNCoreSGD(model_with_loss_fn.parameters(), 0.1, 0.9, 0.0)
    set_buffer_name_in_optimizer(optimizer, "optimizer")
    context.register_optimizer_buffers(optimizer)

    def train_step(inp: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
        # One optimizer step on a single mini-batch; returns the batch loss.
        x = inp["x"]
        t = inp["t"]
        optimizer.zero_grad()
        output = model_with_loss_fn(x, t)
        loss = output["loss"]
        loss.backward()
        optimizer.step()
        return {"loss": loss}

    compile_options = {}
    if option_json_path is not None:
        compile_options["option_json"] = str(option_json_path)

    # Compile the training step once against a sample batch; compiled
    # artifacts are stored under <outdir>/train_step.
    sample = next(iter(train_loader))
    compiled_train_step = context.compile(
        train_step,
        sample,
        storage.path(outdir) / "train_step",
        options=compile_options,
    )

    for epoch in range(10):
        loss = 0.0
        for i, sample in enumerate(train_loader):
            curr_loss = compiled_train_step(sample)["loss"].item()
            # Incremental running mean of the loss over the epoch.
            loss += (curr_loss - loss) / (i + 1)
            if i % 100 == 0:
                print(f"epoch {epoch}, iter {i:4}, loss {loss}")
        print(f"epoch {epoch}, loss {loss}")

    # Ensure all outstanding device work has completed before reading
    # model/optimizer state for checkpointing.
    context.synchronize()

    torch.save(
        {
            "model_state_dict": model_with_loss_fn.state_dict(),
            "optim_state_dict": optimizer.state_dict(),
        },
        storage.path(outdir) / "checkpoint.pt",
    )

    model_with_loss_fn.eval()

    def eval_step(inp: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
        # Count correct top-1 predictions in one evaluation batch.
        x = inp["x"]
        t = inp["t"]
        output = model_with_loss_fn(x, t)
        y = output["y"]
        _, predicted = torch.max(y, 1)
        correct = (predicted == t).sum()
        return {"correct": correct}

    # Compile the evaluation step separately (eval-mode graph, eval batch size).
    sample = next(iter(eval_loader))
    compiled_eval_step = context.compile(
        eval_step,
        sample,
        storage.path(outdir) / "eval_step",
        options=compile_options,
    )
    correct = 0
    for sample in eval_loader:
        correct += compiled_eval_step(sample)["correct"].item()
    print(
        f"Correct: {correct} / {len(eval_loader.dataset)}. "
        f"Accuracy: {correct / len(eval_loader.dataset)}"
    )
    # Sanity check for the example; note this is stripped under `python -O`.
    assert 0.94 < correct / len(eval_loader.dataset)
110
111
if __name__ == "__main__":
    # CLI entry point: parse the flags and launch training + evaluation.
    cli = argparse.ArgumentParser(
        description="A standalone MNIST training and inference script."
    )
    # (flag, value type, default) for each supported option.
    for flag, kind, default in (
        ("--outdir", str, "/tmp/mlsdk_mnist"),
        ("--option_json", Path, None),
        ("--device", str, "mncore2:auto"),
    ):
        cli.add_argument(flag, type=kind, default=default)
    ns = cli.parse_args()
    main(ns.outdir, ns.option_json, ns.device)