6.4.2. FX2ONNX Linter

The FX2ONNX linter is a pre-compilation diagnostic tool for functions that will be passed to mlsdk.Context.compile(). It traces the target function with representative sample inputs and reports two kinds of problems: Python constructs that FX2ONNX is known not to export correctly, and inputs or outputs that do not match the function signature MLSDK expects.

For the underlying exporter restrictions, see FX2ONNX Exporter and Compile Target Limitations & Code Modification.

6.4.2.1. Purpose

FX2ONNX builds an ONNX graph by tracing operations on torch.Tensor objects. The following patterns cannot be represented correctly by FX2ONNX:

  • Creation of Python scalars

  • Global Python object usage

  • Dynamic shapes (though future support is planned)

  • Assignment of torch.Tensor values to variables outside the function scope

In addition, MLSDK expects the compiled function to take exactly one positional argument of type dict[str, torch.Tensor] and to return dict[str, torch.Tensor]. The linter reports all of these patterns as warnings so that the source program can be rewritten before compilation.

6.4.2.2. Basic Usage

import torch
from fx2onnx.linter import LintLevel, lint

def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    x = inp["x"]
    return {"out": x + 1}

sample = {"x": torch.zeros(1)}

result = lint(f, args=(sample,), lint_level=LintLevel.INFO)
result.dump()

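lint() runs the target function once through Dynamo tracing with fake tensors, without performing actual tensor computation. Because the linter observes a single run, constructs on code paths that the sample inputs do not reach are not reported, so use representative sample inputs. Call result.dump() to print a human-readable report.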

fx2onnx.linter.LintLevel represents the linter warning filter level. For details, see API Reference.

6.4.2.3. Warnings and Examples

The file names and line numbers in the reported locations depend on your source file. The sample output below assumes that each snippet is saved on its own as example.py.

6.4.2.3.1. Creation of Python scalars

pattern 1.

The following function cannot be correctly exported by FX2ONNX because it creates a Python scalar with x.item():

import torch
from fx2onnx.linter import LintLevel, lint

a = 0

def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    global a
    x = inp["x"]
    a += x.item()
    return {"out": x + a}

sample = {"x": torch.zeros(1)}

result = lint(f, args=(sample,), lint_level=LintLevel.INFO)
result.dump()

This emits warnings similar to the following:

WARNING: Detected 1 reference(s) to global object(s):
  Global object `G['__import___main__'].a` accessed at:
    File "example.py", line 9, in f

WARNING: Detected 1 python scalar creation(s) from tensor(s):
  File "example.py", line 9, in f

Keep values that participate in the traced computation as torch.Tensor objects. For example, when the accumulated state itself must be part of the graph, keep a as a torch.Tensor and replace the scalar accumulation with an in-place tensor operation such as a.add_(x).
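A minimal sketch of that rewrite is shown below, runnable in plain eager PyTorch. It only removes the x.item() call: because a is still a module-level object, the global-object warning described under Global Python object usage may still be reported.

```python
import torch

# State kept as a tensor instead of a Python int, so updating it does not
# require converting x to a Python scalar.
a = torch.zeros(1)

def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    x = inp["x"]
    a.add_(x)  # in-place tensor update replaces `a += x.item()`
    return {"out": x + a}

out = f({"x": torch.ones(1)})
print(out["out"])  # tensor([2.])
```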

pattern 2.

The following function cannot be correctly exported by FX2ONNX because the conditional branch depends on a torch.Tensor value: the if statement implicitly converts the result of (x > 0).all() to a Python bool, which counts as creating a Python scalar:

import torch
from fx2onnx.linter import LintLevel, lint

def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    x = inp["x"]
    if (x > 0).all():
        y = x + 1
    else:
        y = x - 1
    return {"out": y}

sample = {"x": torch.zeros(1)}

result = lint(f, args=(sample,), lint_level=LintLevel.INFO)
result.dump()

This emits a warning similar to the following:

WARNING: Detected 1 python scalar creation(s) from tensor(s):
  File "example.py", line 6, in f

Rewrite tensor-dependent branches with tensor operations such as torch.where().
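For example, the branch above can be sketched with torch.where() as follows. Both x + 1 and x - 1 are computed and the 0-dimensional condition is broadcast against them, so no Python bool is created. The trade-off is that both branches are always evaluated.

```python
import torch

def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    x = inp["x"]
    # 0-dim bool tensor; it stays a tensor instead of becoming a Python bool.
    cond = (x > 0).all()
    # Select between the two precomputed branch results.
    y = torch.where(cond, x + 1, x - 1)
    return {"out": y}

print(f({"x": torch.zeros(1)})["out"])  # tensor([-1.])
print(f({"x": torch.ones(2)})["out"])   # tensor([2., 2.])
```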

6.4.2.3.2. Global Python object usage

pattern 1.

The following function cannot be correctly exported by FX2ONNX because cond is read at trace time and the exported ONNX graph is fixed to the branch selected by that value; later changes to cond do not affect the compiled function:

import torch
from fx2onnx.linter import LintLevel, lint

cond = False

def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    x = inp["x"]
    if cond:
        y = x + 1
    else:
        y = x - 1
    return {"out": y}

sample = {"x": torch.zeros(1)}

result = lint(f, args=(sample,), lint_level=LintLevel.INFO)
result.dump()

This emits a warning similar to the following:

WARNING: Detected 1 reference(s) to global object(s):
  Global object `G['__import___main__'].cond` accessed at:
    File "example.py", line 8, in f

pattern 2.

The following function cannot be correctly exported by FX2ONNX because a is read at trace time and its value is fixed in the exported ONNX graph:

import torch
from fx2onnx.linter import LintLevel, lint

a = 0

def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    x = inp["x"]
    y = x + a
    return {"out": y}

sample = {"x": torch.zeros(1)}

result = lint(f, args=(sample,), lint_level=LintLevel.INFO)
result.dump()

This emits a warning similar to the following:

WARNING: Detected 1 reference(s) to global object(s):
  Global object `G['__import___main__'].a` accessed at:
    File "example.py", line 8, in f

For patterns 1 and 2, you can ignore these warnings if the value is intended to stay constant after tracing.
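If the value is meant to change between calls, one option is to pass it through the input dictionary as a tensor, so that it becomes a graph input rather than a trace-time constant. The sketch below combines patterns 1 and 2; the keys "cond" and "a" are hypothetical names chosen for illustration.

```python
import torch

def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    x = inp["x"]
    # Former globals are read from the input dict as tensors.
    cond = inp["cond"]  # bool tensor, replaces the global `cond`
    a = inp["a"]        # float tensor, replaces the global `a`
    y = torch.where(cond, x + 1, x - 1) + a
    return {"out": y}

sample = {
    "x": torch.zeros(1),
    "cond": torch.tensor([False]),
    "a": torch.tensor([0.0]),
}
print(f(sample)["out"])  # tensor([-1.])
```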

pattern 3.

The following function cannot be correctly exported by FX2ONNX because a += 1 updates a Python int, and FX2ONNX cannot track operations on values that are not torch.Tensor objects, so the increment is not recorded in the exported graph:

import torch
from fx2onnx.linter import LintLevel, lint

a = 0

def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    global a
    x = inp["x"]
    y = x + a
    a += 1
    return {"out": y}

sample = {"x": torch.zeros(1)}

result = lint(f, args=(sample,), lint_level=LintLevel.INFO)
result.dump()

This emits a warning similar to the following:

WARNING: Detected 1 reference(s) to global object(s):
  Global object `G['__import___main__'].a` accessed at:
    File "example.py", line 9, in f
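One way to avoid the untracked Python counter, sketched under the assumption that the caller can carry state between calls, is to pass the counter in as a tensor and return its updated value as an output. The "step" key is a hypothetical name.

```python
import torch

def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    x = inp["x"]
    step = inp["step"]
    y = x + step
    # Return the incremented counter instead of mutating a Python global.
    return {"out": y, "step": step + 1}

state = torch.zeros(1)
for _ in range(3):
    result = f({"x": torch.zeros(1), "step": state})
    state = result["step"]  # carry the counter into the next call

print(result["out"], state)  # tensor([2.]) tensor([3.])
```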

pattern 4.

Functions that use optimizers from torch.optim, or schedulers derived from torch.optim.lr_scheduler.LRScheduler, cannot be correctly exported by FX2ONNX, because these classes update Python-side state that FX2ONNX cannot represent as graph operations.

import torch
from fx2onnx.linter import LintLevel, lint

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

model = Model()
model.train()

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=torch.tensor(0.001))
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer=optimizer, lr_lambda=lambda epoch: 0.95**epoch
)

def train_step(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    outputs = model(inp["x"])
    loss = criterion(outputs, inp["y"])
    loss.backward()
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()
    return {"outputs": outputs, "loss": loss}

x = torch.randn(1, 10)
y = torch.randn(1, 10)

result = lint(train_step, args=({"x": x, "y": y},), lint_level=LintLevel.INFO)
result.dump()

This emits warnings similar to the following:

WARNING: Detected ... reference(s) to global object(s):
  ...

WARNING: `torch.optim.Optimizer` is used at:
  File "example.py", line 27, in train_step

WARNING: `torch.optim.lr_scheduler.LRScheduler` is used at:
  File "example.py", line 26, in train_step

When compiling training steps, use optimizer and scheduler implementations designed for FX2ONNX, such as the utilities provided by fx2onnx.optim or MLSDK.

6.4.2.3.3. Dynamic shapes

FX2ONNX currently assumes static shapes. Support for dynamic shapes is planned for a future version.

pattern 1.

The following function cannot be correctly exported by FX2ONNX at present because boolean-mask indexing produces a tensor whose shape depends on the data (the number of elements for which x > 0):

import torch
from fx2onnx.linter import LintLevel, lint

def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    x = inp["x"]
    y = inp["y"]
    z = y.clone()
    z[x > 0] = x[x > 0]
    return {"out": z}

x = torch.randn(1, 10)
y = torch.randn(1, 10)

result = lint(f, args=({"x": x, "y": y},), lint_level=LintLevel.INFO)
result.dump()

This emits a warning similar to the following:

WARNING: Detected 1 dynamic shape(s):
  File "example.py", line 8, in f

For this kind of selection, prefer static-shape tensor operations such as torch.where(x > 0, x, y).
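A sketch of that rewrite follows. It checks, in eager mode, that torch.where(x > 0, x, y) produces the same result as the original clone-and-mask assignment, while the output keeps the static shape of the inputs.

```python
import torch

def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    x = inp["x"]
    y = inp["y"]
    # Elementwise select: take x where x > 0, otherwise y.
    # The result always has the shape of x and y.
    z = torch.where(x > 0, x, y)
    return {"out": z}

torch.manual_seed(0)
x = torch.randn(1, 10)
y = torch.randn(1, 10)

# Reference: the original boolean-mask version, run eagerly.
expected = y.clone()
expected[x > 0] = x[x > 0]
assert torch.equal(f({"x": x, "y": y})["out"], expected)
```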

6.4.2.3.4. Assignment of torch.Tensor values to variables outside the function scope

pattern 1.

The following function cannot be correctly exported by FX2ONNX because FX2ONNX cannot track assignment of a torch.Tensor value to a variable outside the function scope:

import torch
from fx2onnx.linter import LintLevel, lint

a = None

def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    global a
    x = inp["x"]
    a = x + 1
    return {"out": x}

sample = {"x": torch.zeros(1)}

result = lint(f, args=(sample,), lint_level=LintLevel.INFO)
result.dump()

This emits a warning similar to the following:

WARNING: Detected 1 tensor leak(s) to global/nonlocal variable(s):
  File "example.py", line 9, in f

If persistent tensor state is required, keep the destination as a torch.Tensor and update it with traceable tensor operations such as copy_() or add_().
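A minimal sketch of that approach, assuming the destination shape is known in advance: a is allocated once as a tensor and overwritten in place with copy_(). Because a is no longer rebound, the global statement is no longer needed. The global-object warning may still apply to a.

```python
import torch

# Destination allocated once with the shape of the value it will hold.
a = torch.zeros(1)

def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    x = inp["x"]
    # In-place write into the existing tensor replaces `a = x + 1`,
    # so no `global a` rebinding is required.
    a.copy_(x + 1)
    return {"out": x}

f({"x": torch.ones(1)})
print(a)  # tensor([2.])
```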

6.4.2.3.5. Input and output type mismatch

MLSDK expects a compile target to receive exactly one positional argument of type dict[str, torch.Tensor] and return dict[str, torch.Tensor]. The following function is therefore reported by the input/output checker:

import torch
from fx2onnx.linter import LintLevel, lint

def f(
    inp: dict[str, object],
    extra: torch.Tensor,
    *,
    y: torch.Tensor,
) -> dict[str, object]:
    return {"out": extra + y, "z": inp["x"] + 1}

args = ({"x": 1}, torch.zeros(1))
kwargs = {"y": torch.zeros(1)}

result = lint(f, args=args, kwargs=kwargs, lint_level=LintLevel.INFO)
result.dump()

This emits warnings similar to the following:

WARNING: Input/output types do not match the expected signature `(args=(dict[str, torch.Tensor],), kwargs={}) -> dict[str, torch.Tensor]`:
  Expected exactly one positional argument, but got 2.

  Expected args[0]['x'] to be torch.Tensor, but got int.

  Expected no keyword arguments, but got 'y'.

  Expected output['z'] to be torch.Tensor, but got int.
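A version that matches the expected signature can be sketched by moving every argument into the single input dictionary and making all values tensors. The keys "extra" and "y" are hypothetical names for the former positional and keyword arguments.

```python
import torch

def f(inp: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    # The former positional argument `extra` and keyword argument `y`
    # are now entries of the single input dict.
    extra = inp["extra"]
    y = inp["y"]
    # inp["x"] is a tensor, so `+ 1` yields a tensor rather than an int.
    return {"out": extra + y, "z": inp["x"] + 1}

sample = {
    "x": torch.tensor([1.0]),
    "extra": torch.zeros(1),
    "y": torch.zeros(1),
}
out = f(sample)
assert all(isinstance(v, torch.Tensor) for v in out.values())
```

With fx2onnx installed, this function can be checked as in Basic Usage with lint(f, args=(sample,), lint_level=LintLevel.INFO).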