7.1.9. Example: Inference MNIST
A sample program that removes the MLSDK API calls from mnist.py and performs inference with plain PyTorch (used in the Porting Tutorial).
This example assumes you have already run Example: MNIST on MN-Core 2 at least once, so that checkpoint.pt exists (by default at /tmp/mlsdk_mnist/checkpoint.pt).
Execution Method
$ cd /opt/pfn/pfcomp/codegen/MLSDK/examples/
$ ./exec_with_env.sh python3 mnist_infer.py /tmp/mlsdk_mnist/checkpoint.pt
Expected Output
The inference results should be identical to those from Example: MNIST on MN-Core 2:
Correct: 9609 / 10000. Accuracy: 0.9609
Related Links
-
This material serves as a reference for gradually introducing the MLSDK API.
Sample Program
1import argparse
2import random
3from pathlib import Path
4from typing import Mapping, Optional
5
6import numpy as np
7import torch
8from mnist_common import MNCoreClassifier, mnist_loaders
9
# Seed every random number generator the script may touch (PyTorch,
# Python's `random`, NumPy) with 0 so repeated runs behave identically.
for _seed_rng in (torch.manual_seed, random.seed, np.random.seed):
    _seed_rng(0)
del _seed_rng
13
14
def main(
    checkpoint_path: str, outdir: str, option_json_path: Optional[Path], device_str: str
) -> None:
    """Evaluate a trained MNIST checkpoint with plain PyTorch and print its accuracy.

    Loads the model weights from ``checkpoint_path`` into an
    ``MNCoreClassifier``, counts correct predictions over the evaluation
    loader, prints the count and accuracy, and checks that the accuracy
    exceeds 0.94.

    Args:
        checkpoint_path: Path to a checkpoint file containing a
            ``"model_state_dict"`` entry.
        outdir: Not used by this function; presumably kept so the signature
            matches the MLSDK version of the script.
        option_json_path: Not used by this function (see ``outdir``).
        device_str: Not used by this function (see ``outdir``).

    Raises:
        AssertionError: If the measured accuracy is not above 0.94.
    """
    batch_size = 64
    eval_batch_size = 125

    # Only the evaluation loader is needed; the training loader is discarded.
    _, eval_loader = mnist_loaders(batch_size, eval_batch_size)

    # map_location="cpu" lets a checkpoint whose tensors were saved on another
    # device be loaded on a host-only machine.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    model_with_loss_fn = MNCoreClassifier()
    model_with_loss_fn.load_state_dict(checkpoint["model_state_dict"])
    model_with_loss_fn.eval()

    def eval_step(inp: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
        """Return ``{"correct": n}`` where ``n`` counts correct predictions in one batch."""
        x = inp["x"]
        t = inp["t"]
        # The model is called with both inputs and targets; only its "y"
        # output (class scores, presumably (batch, num_classes)) is used here.
        output = model_with_loss_fn(x, t)
        predicted = output["y"].argmax(dim=1)
        correct = (predicted == t).sum()
        return {"correct": correct}

    num_samples = len(eval_loader.dataset)
    correct = 0
    # Evaluation needs no gradients; skipping autograd bookkeeping saves
    # memory and time.
    with torch.no_grad():
        for sample in eval_loader:
            # .item() keeps the running total a plain Python int, so the count
            # and the accuracy below are printed as ordinary numbers rather
            # than float32-rounded tensor values.
            correct += eval_step(sample)["correct"].item()
    accuracy = correct / num_samples
    print(f"Correct: {correct} / {num_samples}. Accuracy: {accuracy}")
    assert 0.94 < accuracy
46
47
if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    # NOTE(review): main() in this file only reads checkpoint_path; the other
    # options appear to be kept for CLI parity with the MLSDK version.
    arg_parser = argparse.ArgumentParser(
        description="""
        A script designed to be used with mnist_train.py,
        specifically for running MNIST inference operations.
        """
    )
    arg_parser.add_argument("checkpoint_path", type=str)
    for option, option_type, option_default in (
        ("--outdir", str, "/tmp/mlsdk_mnist_infer"),
        ("--option_json", Path, None),
        ("--device", str, "mncore2:auto"),
    ):
        arg_parser.add_argument(option, type=option_type, default=option_default)
    cli_args = arg_parser.parse_args()
    main(
        checkpoint_path=cli_args.checkpoint_path,
        outdir=cli_args.outdir,
        option_json_path=cli_args.option_json,
        device_str=cli_args.device,
    )