7.1.10. Example: Training MNIST

A sample program that removes the MLSDK API calls from mnist.py and performs training with plain PyTorch (used in the Porting Tutorial).

Similar to Example: MNIST on MN-Core 2, except that it only writes checkpoint.pt to the directory specified by --outdir (default: /tmp/mlsdk_mnist_train, so the checkpoint is saved as /tmp/mlsdk_mnist_train/checkpoint.pt).

Execution Method

$ cd /opt/pfn/pfcomp/codegen/MLSDK/examples/
$ ./exec_with_env.sh python3 mnist_train.py

Expected Output

epoch 0, iter    0, loss 2.29758358001709
epoch 0, iter  100, loss 0.6065061688423157
...
epoch 9, iter  900, loss 0.12388602644205093
epoch 9, loss 0.12544165551662445
  • Checkpoint file (checkpoint.pt)

    • Whether the training ran properly can be verified by running mnist_infer.py with this checkpoint

    • The accuracy metric reported there should be greater than 0.94
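mnist_infer.py is expected to restore this checkpoint before evaluating. A minimal sketch of the save/load round trip, using a hypothetical stand-in torch.nn.Linear instead of MNCoreClassifier and a temporary directory instead of --outdir, but with the same {"model_state_dict", "optim_state_dict"} layout that mnist_train.py saves:

```python
import tempfile
from pathlib import Path

import torch

# Stand-in model and optimizer (the real script uses MNCoreClassifier and SGD).
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Save in the same dict layout as mnist_train.py.
ckpt_path = Path(tempfile.mkdtemp()) / "checkpoint.pt"
torch.save(
    {
        "model_state_dict": model.state_dict(),
        "optim_state_dict": optimizer.state_dict(),
    },
    ckpt_path,
)

# Restoring mirrors what an inference script would do with the real classifier.
ckpt = torch.load(ckpt_path)
model.load_state_dict(ckpt["model_state_dict"])
optimizer.load_state_dict(ckpt["optim_state_dict"])
print(sorted(ckpt.keys()))
```

Restoring the optimizer state is only needed when resuming training; inference only requires "model_state_dict".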

Related Links

  • Porting Tutorial

    • This material serves as a reference for gradually introducing the MLSDK API.

Sample Program

Listing 7.10 /opt/pfn/pfcomp/codegen/MLSDK/examples/mnist_train.py
import argparse
import os
import random
from pathlib import Path
from typing import Mapping, Optional

import numpy as np
import torch
from mlsdk import storage
from mnist_common import MNCoreClassifier, mnist_loaders

torch.manual_seed(0)
random.seed(0)
np.random.seed(0)


def main(outdir: str, option_json_path: Optional[Path], device_str: str) -> None:
    batch_size = 64
    eval_batch_size = 125

    train_loader, _ = mnist_loaders(batch_size, eval_batch_size)

    model_with_loss_fn = MNCoreClassifier()
    model_with_loss_fn.train()

    optimizer = torch.optim.SGD(model_with_loss_fn.parameters(), 0.1, 0.9, 0.0)

    def train_step(inp: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
        x = inp["x"]
        t = inp["t"]
        optimizer.zero_grad()
        output = model_with_loss_fn(x, t)
        loss = output["loss"]
        loss.backward()
        optimizer.step()
        return {"loss": loss}

    for epoch in range(10):
        loss = 0.0
        for i, sample in enumerate(train_loader):
            curr_loss = train_step(sample)["loss"]
            loss += (curr_loss - loss) / (i + 1)
            if i % 100 == 0:
                print(f"epoch {epoch}, iter {i:4}, loss {loss}")
        print(f"epoch {epoch}, loss {loss}")

    os.makedirs(outdir, exist_ok=True)
    torch.save(
        {
            "model_state_dict": model_with_loss_fn.state_dict(),
            "optim_state_dict": optimizer.state_dict(),
        },
        storage.path(outdir) / "checkpoint.pt",
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="""
        A script designed to be used with mnist_infer.py,
        specifically for running MNIST training operations.
        """
    )
    parser.add_argument("--outdir", type=str, default="/tmp/mlsdk_mnist_train")
    parser.add_argument("--option_json", type=Path, default=None)
    parser.add_argument("--device", type=str, default="mncore2:auto")
    args = parser.parse_args()
    main(args.outdir, args.option_json, args.device)
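Note that the per-epoch loss printed by the training loop is not the last batch's loss but a running mean: the update loss += (curr_loss - loss) / (i + 1) keeps loss equal to the average of all batch losses seen so far, without storing a sum. A minimal pure-Python sketch of the same update, with hypothetical loss values:

```python
def incremental_mean(values):
    """Maintain the running mean of values[:i+1] at step i, as in the training loop."""
    mean = 0.0
    for i, v in enumerate(values):
        mean += (v - mean) / (i + 1)  # same update as: loss += (curr_loss - loss) / (i + 1)
    return mean


batch_losses = [2.3, 1.1, 0.6, 0.2]  # made-up per-batch losses
print(incremental_mean(batch_losses))  # approximately sum(batch_losses) / 4 = 1.05
```

This is numerically safer than accumulating a raw sum over many iterations and dividing at the end, and it lets the loop print an up-to-date average at any iteration.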