6.4.2. FX2ONNX Linter
FX2ONNX linter は、 mlsdk.Context.compile() に渡す関数を対象とする、コンパイル前の診断ツールです。サンプル入力を用いて関数をトレースし、FX2ONNX exporter と互換性のない Python のコードを報告します。
背景となる exporter の制約については、 FX2ONNX Exporter および Compile Target Limitations & Code Modification を参照してください。
6.4.2.1. 目的
FX2ONNX は torch.Tensor に対する操作をトレースすることで ONNX グラフを構築しますが、以下のパターンは FX2ONNX で正しく表現できません。
Python scalar の作成
Global Python Object の使用
動的 shape (将来的なサポートの予定あり)
関数スコープ外の変数への
torch.Tensorの代入
さらに、MLSDK はコンパイル対象の関数が dict[str, torch.Tensor] 型の位置引数を 1 つ持ち、 dict[str, torch.Tensor] を返すことを期待します。 linter は、コンパイル前にソースプログラムを修正できるように、これらのパターンを警告として報告します。
6.4.2.2. 基本的な使い方
import torch
from fx2onnx.linter import LintLevel, lint
def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
x = inp["x"]
return {"out": x + 1}
sample = {"x": torch.zeros(1)}
result = lint(f, args=(sample,), lint_level=LintLevel.INFO)
result.dump()
lint() は対象となる関数を 1 回実行し警告を生成しますが、これは実際のテンソル計算を行う実行ではなく、Dynamo を利用した fake tensor によるトレース実行です。代表的なサンプル入力を指定し、 result.dump() を呼び出すことで、人間が読める形式のレポートを出力できます。
fx2onnx.linter.LintLevel は、linter の警告フィルタレベルを表します。詳細は API リファレンス を参照してください。
6.4.2.3. 警告と例
スタックトレースに表示される正確なファイル名と行番号は、ソースファイルに依存します。
6.4.2.3.1. Python scalar の作成
pattern 1.
以下の関数は、 x.item() によって Python scalar を作成するため、FX2ONNX で正しく export できません。
import torch
from fx2onnx.linter import LintLevel, lint
a = 0
def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
global a
x = inp["x"]
a += x.item()
return {"out": x + a}
sample = {"x": torch.zeros(1)}
result = lint(f, args=(sample,), lint_level=LintLevel.INFO)
result.dump()
このコードは、以下のような警告を出力します。
WARNING: Detected 1 reference(s) to global object(s):
Global object `G['__import___main__'].a` accessed at:
File "example.py", line 9, in f
WARNING: Detected 1 python scalar creation(s) from tensor(s):
File "example.py", line 9, in f
トレースされる計算に関わる値は torch.Tensor のまま保持してください。たとえば、状態自体をグラフの一部にする必要がある場合は、scalar の加算を a.add_(x) のような tensor 操作に置き換えてください。
pattern 2.
以下の関数は、条件分岐が torch.Tensor の値に依存しているため、FX2ONNX で正しく export できません。 torch.Tensor を使用する条件は暗黙的に Python scalar に変換されます。
import torch
from fx2onnx.linter import LintLevel, lint
def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
x = inp["x"]
if (x > 0).all():
y = x + 1
else:
y = x - 1
return {"out": y}
sample = {"x": torch.zeros(1)}
result = lint(f, args=(sample,), lint_level=LintLevel.INFO)
result.dump()
このコードは、以下のような警告を出力します。
WARNING: Detected 1 python scalar creation(s) from tensor(s):
File "example.py", line 6, in f
tensor に依存する分岐は、 torch.where() などの tensor 操作に置き換えてください。
6.4.2.3.2. Global Python Object の使用
pattern 1.
以下の関数は、export される ONNX グラフが trace 時点の cond の値に固定されるため、FX2ONNX で正しく export できません。
import torch
from fx2onnx.linter import LintLevel, lint
cond = False
def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
x = inp["x"]
if cond:
y = x + 1
else:
y = x - 1
return {"out": y}
sample = {"x": torch.zeros(1)}
result = lint(f, args=(sample,), lint_level=LintLevel.INFO)
result.dump()
このコードは、以下のような警告を出力します。
WARNING: Detected 1 reference(s) to global object(s):
Global object `G['__import___main__'].cond` accessed at:
File "example.py", line 8, in f
pattern 2.
以下の関数は、export される ONNX グラフが trace 時点の a の値に固定されるため、FX2ONNX で正しく export できません。
import torch
from fx2onnx.linter import LintLevel, lint
a = 0
def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
x = inp["x"]
y = x + a
return {"out": y}
sample = {"x": torch.zeros(1)}
result = lint(f, args=(sample,), lint_level=LintLevel.INFO)
result.dump()
このコードは、以下のような警告を出力します。
WARNING: Detected 1 reference(s) to global object(s):
Global object `G['__import___main__'].a` accessed at:
File "example.py", line 8, in f
pattern 1 と 2 については、ユーザーがその挙動を意図している場合、これらの警告は無視して構いません。
pattern 3.
以下の関数は、FX2ONNX が torch.Tensor ではない値に対する操作を追跡できないため、正しく export できません。
import torch
from fx2onnx.linter import LintLevel, lint
a = 0
def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
global a
x = inp["x"]
y = x + a
a += 1
return {"out": y}
sample = {"x": torch.zeros(1)}
result = lint(f, args=(sample,), lint_level=LintLevel.INFO)
result.dump()
このコードは、以下のような警告を出力します。
WARNING: Detected 1 reference(s) to global object(s):
Global object `G['__import___main__'].a` accessed at:
File "example.py", line 9, in f
pattern 4.
torch.optim 以下の optimizer や、 torch.optim.lr_scheduler.LRScheduler から派生した scheduler を使用する関数は、FX2ONNX で正しく export できません。これらの class は、FX2ONNX がグラフ操作として表現できない操作を行います。
import torch
from fx2onnx.linter import LintLevel, lint
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 10)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(x)
model = Model()
model.train()
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=torch.tensor(0.001))
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
optimizer=optimizer, lr_lambda=lambda epoch: 0.95**epoch
)
def train_step(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
outputs = model(inp["x"])
loss = criterion.forward(outputs, inp["y"])
loss.backward()
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
return {"outputs": outputs, "loss": loss}
x = torch.randn(1, 10)
y = torch.randn(1, 10)
result = lint(train_step, args=({"x": x, "y": y},), lint_level=LintLevel.INFO)
result.dump()
このコードは、以下のような警告を出力します。
WARNING: Detected ... reference(s) to global object(s):
...
WARNING: `torch.optim.Optimizer` is used at:
File "example.py", line 27, in train_step
WARNING: `torch.optim.lr_scheduler.LRScheduler` is used at:
File "example.py", line 26, in train_step
学習用の関数をコンパイルする際は、 fx2onnx.optim や MLSDK が提供する optimizer および scheduler utility など、FX2ONNX 向けに設計された実装を使用してください。
6.4.2.3.3. 動的 shape
FX2ONNX は現時点では静的 shape を仮定しています。動的 shape のサポートは、将来のバージョンで予定されています。
pattern 1.
以下の関数は、bool mask による indexing が動的 shape を生成するため、現時点では FX2ONNX で正しく export できません。
import torch
from fx2onnx.linter import LintLevel, lint
def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
x = inp["x"]
y = inp["y"]
z = y.clone()
z[x > 0] = x[x > 0]
return {"out": z}
x = torch.randn(1, 10)
y = torch.randn(1, 10)
result = lint(f, args=({"x": x, "y": y},), lint_level=LintLevel.INFO)
result.dump()
このコードは、以下のような警告を出力します。
WARNING: Detected 1 dynamic shape(s):
File "example.py", line 8, in f
このような indexing には、 torch.where(x > 0, x, y) のような静的 shape の tensor 操作を優先して使用してください。
6.4.2.3.4. 関数スコープ外の変数への torch.Tensor の代入
pattern 1.
以下の関数は、FX2ONNX が関数スコープ外の変数への torch.Tensor の代入を追跡できないため、正しく export できません。
import torch
from fx2onnx.linter import LintLevel, lint
a = None
def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
global a
x = inp["x"]
a = x + 1
return {"out": x}
sample = {"x": torch.zeros(1)}
result = lint(f, args=(sample,), lint_level=LintLevel.INFO)
result.dump()
このコードは、以下のような警告を出力します。
WARNING: Detected 1 tensor leak(s) to global/nonlocal variable(s):
File "example.py", line 9, in f
永続的な tensor が必要な場合は、代入先を torch.Tensor として保持し、 copy_() や add_() などの追跡可能な tensor 操作で置き換えてください。
6.4.2.3.5. 入出力型の不一致
MLSDK は、コンパイル対象の関数が dict[str, torch.Tensor] 型の位置引数を 1 つ受け取り、 dict[str, torch.Tensor] を返すことを期待します。そのため、以下の関数は入出力 check に失敗します。
import torch
from fx2onnx.linter import LintLevel, lint
def f(
inp: dict[str, object],
extra: torch.Tensor,
*,
y: torch.Tensor,
) -> dict[str, object]:
return {"out": extra + y, "z": inp["x"] + 1}
args = ({"x": 1}, torch.zeros(1))
kwargs = {"y": torch.zeros(1)}
result = lint(f, args=args, kwargs=kwargs, lint_level=LintLevel.INFO)
result.dump()
このコードは、以下のような警告を出力します。
WARNING: Input/output types do not match the expected signature `(args=(dict[str, torch.Tensor],), kwargs={}) -> dict[str, torch.Tensor]`:
Expected exactly one positional argument, but got 2.
Expected args[0]['x'] to be torch.Tensor, but got int.
Expected no keyword arguments, but got 'y'.
Expected output['z'] to be torch.Tensor, but got int.