8.1.11. Example: Training MNIST
mnist.py から MLSDK API を取り除き、PyTorch で学習を行うサンプルプログラム (移行作業チュートリアル で使用します)
Example: MNIST on MN-Core 2 と同様ですが、単に checkpoint.pt を --output で指定された場所に保存するだけです。 (デフォルトでは /tmp/mlsdk_mnist_train/checkpoint.pt)
実行方法
$ cd /opt/pfn/pfcomp/codegen/MLSDK/examples/
$ ./exec_with_env.sh python3 mnist_train.py
想定出力
学習中のログ
Loss curve が Example: MNIST on MN-Core 2 のものと異なることがありますが、これは異なるバックエンドが使用されているためです。
epoch 0, iter 0, loss 2.29758358001709
epoch 0, iter 100, loss 0.6065061688423157
...
epoch 9, iter 900, loss 0.12388602644205093
epoch 9, loss 0.12544165551662445
チェックポイント (
checkpoint.pt)学習が正常に完了したかは
mnist_infer.pyを使ってチェックします。Accuracy指標が 0.94 よりも大きければ良いです。
関連リンク
-
MLSDK API を段階的に導入する際の参考資料です。
サンプルプログラム
1import argparse
2import os
3import random
4from pathlib import Path
5from typing import Mapping, Optional
6
7import numpy as np
8import torch
9from mlsdk import storage
10from mnist_common import MNCoreClassifier, mnist_loaders
11
12torch.manual_seed(0)
13random.seed(0)
14np.random.seed(0)
15
16
17def main(outdir: str, option_json_path: Optional[Path], device_str: str) -> None:
18 batch_size = 64
19 eval_batch_size = 125
20
21 train_loader, _ = mnist_loaders(batch_size, eval_batch_size)
22
23 model_with_loss_fn = MNCoreClassifier()
24 model_with_loss_fn.train()
25
26 optimizer = torch.optim.SGD(model_with_loss_fn.parameters(), 0.1, 0.9, 0.0)
27
28 def train_step(inp: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
29 x = inp["x"]
30 t = inp["t"]
31 optimizer.zero_grad()
32 output = model_with_loss_fn(x, t)
33 loss = output["loss"]
34 loss.backward()
35 optimizer.step()
36 return {"loss": loss}
37
38 for epoch in range(10):
39 loss = 0.0
40 for i, sample in enumerate(train_loader):
41 curr_loss = train_step(sample)["loss"]
42 loss += (curr_loss - loss) / (i + 1)
43 if i % 100 == 0:
44 print(f"epoch {epoch}, iter {i:4}, loss {loss}")
45 print(f"epoch {epoch}, loss {loss}")
46
47 os.makedirs(outdir, exist_ok=True)
48 torch.save(
49 {
50 "model_state_dict": model_with_loss_fn.state_dict(),
51 "optim_state_dict": optimizer.state_dict(),
52 },
53 storage.path(outdir) / "checkpoint.pt",
54 )
55
56
57if __name__ == "__main__":
58 parser = argparse.ArgumentParser(
59 description="""
60 A script designed to be used with mnist_infer.py,
61 specifically for running MNIST training operations.
62 """
63 )
64 parser.add_argument("--outdir", type=str, default="/tmp/mlsdk_mnist_train")
65 parser.add_argument("--option_json", type=Path, default=None)
66 parser.add_argument("--device", type=str, default="mncore2:auto")
67 args = parser.parse_args()
68 main(args.outdir, args.option_json, args.device)