7.1.9. Example: Inference MNIST

A sample program that removes the MLSDK API calls from mnist.py and runs MNIST inference with plain PyTorch. It is used in the Porting Tutorial.

This example assumes you have already run Example: MNIST on MN-Core 2 at least once, so that a checkpoint.pt file exists (by default at /tmp/mlsdk_mnist/checkpoint.pt).
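
The inference script expects the checkpoint to be a dictionary with a "model_state_dict" entry holding the model weights. The following minimal sketch illustrates that layout with a round trip through torch.save and torch.load. It assumes a hypothetical nn.Linear stand-in in place of MNCoreClassifier and writes to a temporary file rather than the real checkpoint location.

```python
import tempfile
from pathlib import Path

import torch
from torch import nn

# Hypothetical stand-in for MNCoreClassifier (which is defined in mnist_common).
model = nn.Linear(784, 10)

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "checkpoint.pt"
    # Save the weights under the key that mnist_infer.py reads back.
    torch.save({"model_state_dict": model.state_dict()}, path)

    # Load the checkpoint and restore the weights into a fresh model.
    checkpoint = torch.load(path, map_location="cpu")
    restored = nn.Linear(784, 10)
    restored.load_state_dict(checkpoint["model_state_dict"])
    restored.eval()  # switch to inference mode, as the sample script does

# Every restored parameter matches the saved one.
same = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), restored.state_dict().values())
)
print(same)  # True
```

A real checkpoint may contain additional entries; the sample script only reads "model_state_dict".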

Execution Method

$ cd /opt/pfn/pfcomp/codegen/MLSDK/examples/
$ ./exec_with_env.sh python3 mnist_infer.py /tmp/mlsdk_mnist/checkpoint.pt
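
The command above passes only the positional checkpoint path, so the remaining options keep their defaults. The sketch below rebuilds the argument parser from the listing and parses the same argument list explicitly. Note that main in this sample accepts --outdir, --option_json, and --device but does not use them.

```python
import argparse
from pathlib import Path

# Same arguments as the parser in mnist_infer.py.
parser = argparse.ArgumentParser()
parser.add_argument("checkpoint_path", type=str)
parser.add_argument("--outdir", type=str, default="/tmp/mlsdk_mnist_infer")
parser.add_argument("--option_json", type=Path, default=None)
parser.add_argument("--device", type=str, default="mncore2:auto")

# Equivalent to: python3 mnist_infer.py /tmp/mlsdk_mnist/checkpoint.pt
args = parser.parse_args(["/tmp/mlsdk_mnist/checkpoint.pt"])
print(args.checkpoint_path)  # /tmp/mlsdk_mnist/checkpoint.pt
print(args.outdir)           # /tmp/mlsdk_mnist_infer
print(args.option_json)      # None
print(args.device)           # mncore2:auto
```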

Expected Output

Correct: 9609 / 10000. Accuracy: 0.9609
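
The accuracy is the number of correct predictions divided by the size of the evaluation set (10,000 images in the MNIST test split), and the sample script asserts that it exceeds 0.94. A worked example using the counts from the expected output above:

```python
correct = 9609  # number of correctly classified test images
total = 10000   # corresponds to len(eval_loader.dataset)

accuracy = correct / total
print(f"Correct: {correct} / {total}. Accuracy: {accuracy}")
# Correct: 9609 / 10000. Accuracy: 0.9609

# The sample script raises AssertionError if accuracy is 0.94 or lower.
assert 0.94 < accuracy
```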

Related Links

  • Porting Tutorial

    • This material serves as a reference for introducing the MLSDK API step by step.

Sample Program

Listing 7.9 /opt/pfn/pfcomp/codegen/MLSDK/examples/mnist_infer.py
import argparse
import random
from pathlib import Path
from typing import Mapping, Optional

import numpy as np
import torch
from mnist_common import MNCoreClassifier, mnist_loaders

# Fix the random seeds so that runs are reproducible.
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)


def main(
    checkpoint_path: str, outdir: str, option_json_path: Optional[Path], device_str: str
) -> None:
    # outdir, option_json_path, and device_str are not used in this
    # PyTorch-only inference script.
    batch_size = 64
    eval_batch_size = 125

    # Only the evaluation loader is needed for inference.
    _, eval_loader = mnist_loaders(batch_size, eval_batch_size)

    # Load the checkpoint written by the MNIST training example.
    checkpoint = torch.load(checkpoint_path)

    model_with_loss_fn = MNCoreClassifier()
    model_with_loss_fn.load_state_dict(checkpoint["model_state_dict"])
    model_with_loss_fn.eval()

    # Gradients are not needed for inference.
    @torch.no_grad()
    def eval_step(inp: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
        x = inp["x"]
        t = inp["t"]
        output = model_with_loss_fn(x, t)
        y = output["y"]
        # The predicted class is the index of the largest logit in each row.
        _, predicted = torch.max(y, 1)
        correct = (predicted == t).sum()
        return {"correct": correct}

    correct = 0
    for sample in eval_loader:
        # .item() keeps the running total a plain Python int.
        correct += eval_step(sample)["correct"].item()
    num_samples = len(eval_loader.dataset)
    accuracy = correct / num_samples
    print(f"Correct: {correct} / {num_samples}. Accuracy: {accuracy}")
    assert 0.94 < accuracy


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="""
        Runs MNIST inference with plain PyTorch, using a checkpoint
        produced by mnist_train.py.
        """
    )
    parser.add_argument("checkpoint_path", type=str)
    parser.add_argument("--outdir", type=str, default="/tmp/mlsdk_mnist_infer")
    parser.add_argument("--option_json", type=Path, default=None)
    parser.add_argument("--device", type=str, default="mncore2:auto")
    args = parser.parse_args()
    main(args.checkpoint_path, args.outdir, args.option_json, args.device)
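
eval_step counts the correct predictions in a batch by taking the index of the largest logit in each row with torch.max(y, 1) and comparing it with the labels. The sketch below applies the same steps to a hand-written batch; the logits are made up for illustration and are not output from the MNIST model.

```python
import torch

# Made-up logits for a batch of 4 samples and 3 classes (not real model output).
y = torch.tensor(
    [
        [2.0, 0.1, -1.0],  # largest logit at index 0
        [0.3, 1.5, 0.2],   # largest logit at index 1
        [0.0, 0.2, 3.0],   # largest logit at index 2
        [1.0, 0.9, 0.8],   # largest logit at index 0
    ]
)
t = torch.tensor([0, 1, 1, 0])  # ground-truth labels

with torch.no_grad():
    # torch.max along dim 1 returns (max values, indices of the max values).
    _, predicted = torch.max(y, 1)
    correct = (predicted == t).sum()

print(predicted.tolist())  # [0, 1, 2, 0]
print(correct.item())      # 3
```

Only the third sample is misclassified (predicted 2, label 1), so the batch contributes 3 to the running total.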