6.4.2. FX2ONNX Linter
The FX2ONNX linter is a pre-compilation diagnostic tool for functions that will
be passed to mlsdk.Context.compile(). It traces the target function
with representative sample inputs and reports Python constructs that are known
to be incompatible with FX2ONNX export, or with the MLSDK function signature.
For the underlying exporter restrictions, see FX2ONNX Exporter and Compile Target Limitations & Code Modification.
6.4.2.1. Purpose
FX2ONNX builds an ONNX graph by tracing operations on torch.Tensor objects.
The following patterns cannot be represented correctly by FX2ONNX:
Creation of Python scalars
Global Python object usage
Dynamic shapes (though future support is planned)
Assignment of torch.Tensor values to variables outside the function scope
In addition, MLSDK expects the compiled function to have one positional argument
of type dict[str, torch.Tensor] and to return dict[str, torch.Tensor].
The linter reports these patterns as warnings so that the source program can be
rewritten before compilation.
6.4.2.2. Basic Usage
import torch
from fx2onnx.linter import LintLevel, lint

def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    x = inp["x"]
    return {"out": x + 1}

sample = {"x": torch.zeros(1)}
result = lint(f, args=(sample,), lint_level=LintLevel.INFO)
result.dump()
lint() runs the target function once through Dynamo tracing with fake
tensors, without performing actual tensor computation.
Use representative sample inputs, and call result.dump() to print a
human-readable report.
fx2onnx.linter.LintLevel represents the linter warning filter
level. For details, see API Reference.
6.4.2.3. Warnings and Examples
The exact file names and line numbers in the stack traces depend on your source file.
6.4.2.3.1. Creation of Python scalars
pattern 1.
The following function cannot be correctly exported by FX2ONNX because it
creates a Python scalar with x.item(). It also reads and updates the global
variable a, so the global-object check reports it as well:
import torch
from fx2onnx.linter import LintLevel, lint

a = 0

def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    global a
    x = inp["x"]
    a += x.item()
    return {"out": x + a}

sample = {"x": torch.zeros(1)}
result = lint(f, args=(sample,), lint_level=LintLevel.INFO)
result.dump()
This emits warnings similar to the following:
WARNING: Detected 1 reference(s) to global object(s):
Global object `G['__import___main__'].a` accessed at:
File "example.py", line 9, in f
WARNING: Detected 1 python scalar creation(s) from tensor(s):
File "example.py", line 9, in f
Keep values that participate in the traced computation as torch.Tensor
objects. For example, when the accumulated state itself must be part of the
graph, hold a as a torch.Tensor and replace the scalar accumulation with an
in-place tensor operation such as a.add_(x).
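The rewrite suggested above can be sketched as follows. This is a minimal illustration in plain PyTorch, not output verified against the linter: the accumulator name total is hypothetical, and the linter may still report the reference to the module-level tensor.

```python
import torch

# Assumption: the running total lives in a module-level tensor
# (hypothetical name `total`) instead of a Python number.
total = torch.zeros(1)


def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    x = inp["x"]
    total.add_(x)  # in-place tensor update; no x.item() call
    return {"out": x + total}


first = f({"x": torch.ones(1)})
second = f({"x": torch.ones(1)})
print(first["out"], second["out"], total)
```

Because the accumulation stays a tensor operation, no Python scalar is created from x at any point in the function body.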
pattern 2.
The following function cannot be correctly exported by FX2ONNX because the
conditional branch depends on a torch.Tensor value. The branch condition is
implicitly converted to a Python scalar:
import torch
from fx2onnx.linter import LintLevel, lint

def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    x = inp["x"]
    if (x > 0).all():
        y = x + 1
    else:
        y = x - 1
    return {"out": y}

sample = {"x": torch.zeros(1)}
result = lint(f, args=(sample,), lint_level=LintLevel.INFO)
result.dump()
This emits a warning similar to the following:
WARNING: Detected 1 python scalar creation(s) from tensor(s):
File "example.py", line 6, in f
Rewrite tensor-dependent branches with tensor operations such as
torch.where().
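As a sketch of that rewrite in plain PyTorch (not checked against the linter here), the branch condition can stay a 0-dimensional boolean tensor and select between both candidate results:

```python
import torch


def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    x = inp["x"]
    cond = (x > 0).all()  # 0-dim bool tensor; never converted to a Python bool
    # Both candidates are computed, and torch.where picks one of them.
    y = torch.where(cond, x + 1, x - 1)
    return {"out": y}


neg_case = f({"x": torch.zeros(1)})["out"]
pos_case = f({"x": torch.ones(3)})["out"]
print(neg_case, pos_case)
```

Note that, unlike the if/else version, both x + 1 and x - 1 are always evaluated; this is the usual cost of expressing a data-dependent branch as a tensor operation.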
6.4.2.3.2. Global Python object usage
pattern 1.
The following function cannot be correctly exported by FX2ONNX because the
exported ONNX graph is fixed to the value of cond at trace time:
import torch
from fx2onnx.linter import LintLevel, lint

cond = False

def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    x = inp["x"]
    if cond:
        y = x + 1
    else:
        y = x - 1
    return {"out": y}

sample = {"x": torch.zeros(1)}
result = lint(f, args=(sample,), lint_level=LintLevel.INFO)
result.dump()
This emits a warning similar to the following:
WARNING: Detected 1 reference(s) to global object(s):
Global object `G['__import___main__'].cond` accessed at:
File "example.py", line 8, in f
pattern 2.
The following function cannot be correctly exported by FX2ONNX because the
exported ONNX graph is fixed to the value of a at trace time:
import torch
from fx2onnx.linter import LintLevel, lint

a = 0

def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    x = inp["x"]
    y = x + a
    return {"out": y}

sample = {"x": torch.zeros(1)}
result = lint(f, args=(sample,), lint_level=LintLevel.INFO)
result.dump()
This emits a warning similar to the following:
WARNING: Detected 1 reference(s) to global object(s):
Global object `G['__import___main__'].a` accessed at:
File "example.py", line 8, in f
For patterns 1 and 2, the warnings can be ignored if fixing the global value at trace time is the intended behavior.
pattern 3.
The following function cannot be correctly exported by FX2ONNX because FX2ONNX
cannot track operations on values that are not torch.Tensor objects. The
update a += 1 is therefore not captured in the exported graph:
import torch
from fx2onnx.linter import LintLevel, lint

a = 0

def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    global a
    x = inp["x"]
    y = x + a
    a += 1
    return {"out": y}

sample = {"x": torch.zeros(1)}
result = lint(f, args=(sample,), lint_level=LintLevel.INFO)
result.dump()
This emits a warning similar to the following:
WARNING: Detected 1 reference(s) to global object(s):
Global object `G['__import___main__'].a` accessed at:
File "example.py", line 9, in f
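One possible rewrite for a counter like this, shown here as an assumption rather than a documented recommendation, is to carry the state through the input and output dictionaries so that every operation on it is a tensor operation. The key name "a" and the driver loop are hypothetical.

```python
import torch


def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    x = inp["x"]
    a = inp["a"]  # state arrives as a tensor input instead of a global
    # The updated state is returned so the caller can feed it back in.
    return {"out": x + a, "a": a + 1}


state = torch.zeros(1)
for _ in range(3):
    result = f({"x": torch.ones(1), "a": state})
    state = result["a"]
print(result["out"], state)
```

The function no longer reads or writes anything outside its own scope, and it still matches the dict[str, torch.Tensor] signature that MLSDK expects.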
pattern 4.
Functions using optimizers under torch.optim or schedulers derived from
torch.optim.lr_scheduler.LRScheduler cannot be correctly exported by
FX2ONNX. These classes update Python-side state that FX2ONNX cannot represent
as graph operations.
import torch
from fx2onnx.linter import LintLevel, lint

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

model = Model()
model.train()

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=torch.tensor(0.001))
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer=optimizer, lr_lambda=lambda epoch: 0.95**epoch
)

def train_step(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    outputs = model(inp["x"])
    loss = criterion.forward(outputs, inp["y"])
    loss.backward()
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()
    return {"outputs": outputs, "loss": loss}

x = torch.randn(1, 10)
y = torch.randn(1, 10)
result = lint(train_step, args=({"x": x, "y": y},), lint_level=LintLevel.INFO)
result.dump()
This emits warnings similar to the following:
WARNING: Detected ... reference(s) to global object(s):
...
WARNING: `torch.optim.Optimizer` is used at:
File "example.py", line 27, in train_step
WARNING: `torch.optim.lr_scheduler.LRScheduler` is used at:
File "example.py", line 26, in train_step
Use optimizer and scheduler implementations that are designed for FX2ONNX, such
as the optimizer and scheduler utilities provided by fx2onnx.optim or MLSDK,
when compiling training steps.
6.4.2.3.3. Dynamic shapes
FX2ONNX currently assumes static shapes. Support for dynamic shapes is planned for a future version.
pattern 1.
The following function cannot be correctly exported by FX2ONNX at present
because boolean-mask indexing produces a tensor whose shape depends on the
input values:
import torch
from fx2onnx.linter import LintLevel, lint

def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    x = inp["x"]
    y = inp["y"]
    z = y.clone()
    z[x > 0] = x[x > 0]
    return {"out": z}

x = torch.randn(1, 10)
y = torch.randn(1, 10)
result = lint(f, args=({"x": x, "y": y},), lint_level=LintLevel.INFO)
result.dump()
This emits a warning similar to the following:
WARNING: Detected 1 dynamic shape(s):
File "example.py", line 8, in f
For this kind of selection, prefer static-shape tensor operations such as
torch.where(x > 0, x, y).
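The following sketch, in plain PyTorch and without running the linter, rewrites the function with torch.where() and compares the result with the mask-indexing version on the same inputs:

```python
import torch


def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    x = inp["x"]
    y = inp["y"]
    # Element-wise selection; the output shape always equals the input shape.
    return {"out": torch.where(x > 0, x, y)}


torch.manual_seed(0)
x = torch.randn(1, 10)
y = torch.randn(1, 10)

# Reference result computed with the original boolean-mask assignment.
expected = y.clone()
expected[x > 0] = x[x > 0]

out = f({"x": x, "y": y})["out"]
print(torch.equal(out, expected))
```

The intermediate x[x > 0] has a length that depends on how many elements are positive, whereas every tensor in the torch.where() version keeps the static shape (1, 10).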
6.4.2.3.4. Assignment of torch.Tensor to variables outside function scope
pattern 1.
The following function cannot be correctly exported by FX2ONNX because FX2ONNX
cannot track assignment of a torch.Tensor value to a variable outside the
function scope:
import torch
from fx2onnx.linter import LintLevel, lint

a = None

def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    global a
    x = inp["x"]
    a = x + 1
    return {"out": x}

sample = {"x": torch.zeros(1)}
result = lint(f, args=(sample,), lint_level=LintLevel.INFO)
result.dump()
This emits a warning similar to the following:
WARNING: Detected 1 tensor leak(s) to global/nonlocal variable(s):
File "example.py", line 9, in f
If persistent tensor state is required, keep the destination as a
torch.Tensor and update it with traceable tensor operations such as
copy_() or add_().
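A minimal sketch of the copy_() variant is shown below. It assumes the destination is preallocated with the shape of the value it will hold, and it has not been checked against the linter, which may still report the reference to the module-level tensor.

```python
import torch

# Assumption: the persistent state is preallocated as a tensor
# rather than initialised to None and rebound later.
a = torch.zeros(1)


def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    x = inp["x"]
    a.copy_(x + 1)  # overwrite the existing tensor in place; `a` is not rebound
    return {"out": x}


buffer_before = a
f({"x": torch.full((1,), 4.0)})
print(a)
```

The name a keeps referring to the same tensor object throughout, so the update is an operation on an existing tensor rather than an assignment of a new one.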
6.4.2.3.5. Input and output type mismatch
MLSDK expects a compile target to receive exactly one positional argument of
type dict[str, torch.Tensor] and return dict[str, torch.Tensor]. The
following function is therefore reported by the input/output checker:
import torch
from fx2onnx.linter import LintLevel, lint

def f(
    inp: dict[str, object],
    extra: torch.Tensor,
    *,
    y: torch.Tensor,
) -> dict[str, object]:
    return {"out": extra + y, "z": inp["x"] + 1}

args = ({"x": 1}, torch.zeros(1))
kwargs = {"y": torch.zeros(1)}
result = lint(f, args=args, kwargs=kwargs, lint_level=LintLevel.INFO)
result.dump()
This emits warnings similar to the following:
WARNING: Input/output types do not match the expected signature `(args=(dict[str, torch.Tensor],), kwargs={}) -> dict[str, torch.Tensor]`:
Expected exactly one positional argument, but got 2.
Expected args[0]['x'] to be torch.Tensor, but got int.
Expected no keyword arguments, but got 'y'.
Expected output['z'] to be torch.Tensor, but got int.
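A conforming version of this function could look like the sketch below, which folds every value into the single input dictionary as a tensor. The key names "extra" and "y" are carried over from the parameter names above as an illustrative choice; the lint() call is omitted, so this only demonstrates the signature in plain PyTorch.

```python
import torch


def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    # All former positional and keyword arguments are read from the dictionary.
    return {"out": inp["extra"] + inp["y"], "z": inp["x"] + 1}


sample = {
    "x": torch.ones(1),      # was the Python int 1
    "extra": torch.zeros(1),  # was the second positional argument
    "y": torch.zeros(1),      # was the keyword-only argument
}
out = f(sample)
print(out)
```

With this shape, the function would be called with args=(sample,) and no keyword arguments, matching the expected signature quoted in the warning.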