我有一个 onnx 模型,它有一些(理想情况下)布尔输入,仅用于模型内的控制流。
我尝试做的一些最小代码:
import onnx
import onnxruntime
import torch.onnx
class SumModule(torch.nn.Module):
def forward(self, x1, x2):
if x2 is not None:
x1 *= 1
return torch.sum(x1)
torch_model = SumModule()
torch_model.eval()
model_inputs = {'x1': torch.tensor([1, 2]), 'x2': torch.tensor([1, 2])}
torch_out = torch_model(**model_inputs)
torch.onnx.export(torch_model,
tuple(model_inputs.values()),
'model.onnx',
export_params=True,
opset_version=16,
do_constant_folding=True,
input_names=list(model_inputs.keys()),
output_names=['output'],
dynamic_axes={'x1': {0: 'batch_size'}, })
onnx_model = onnx.load('model.onnx')
onnx.checker.check_model(onnx_model)
ort_session = onnxruntime.InferenceSession('model.onnx')
def to_numpy(tensor):
if isinstance(tensor, torch.Tensor):
return tensor.detach().cpu().numpy()
return tensor
model_inputs_np = {k: to_numpy(v) for k, v in model_inputs.items()}
ort_outs = ort_session.run(None, input_feed=model_inputs_np)
当 onnx 导出完成时,我无法在没有错误的情况下运行推理模型
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Invalid Feed Input Name:x2
我想我在这里误解了 onnx 的一些基本知识。
x2
的论点明明是存在的,为什么 onnx 会以某种方式丢弃它呢?如果我根本不使用参数x2
(但将其作为输入参数),也会发生同样的情况,我也觉得很奇怪。
在我的实际代码中,我想要做的控制流程如下: 我有 3 个可选的输入,所以理想情况下是
Optional[torch.Tensor]
。然而,onnx似乎无法处理None
。所以相反,我想要 3 个输入 + 3 个布尔标志(torch.tensor(True)
或者如果需要的话是 torch.tensor([True]) or replace
Truewith
1or
1.0` --> 所有这些都有同样的问题)。
然后在代码中我根据这些标志做不同的事情。
为什么 onnx 不允许这样做?我发现如果我有时以某种方式将它们包含在某些计算中,那么拥有这些变量就可以了,但我无法弄清楚所有这一切背后的规则。
您的问题与
torch.onnx.export
的工作方式有关。
在生成 ONNX 模型时,torch 使用给定的输入执行(跟踪)一次模块,同时跟踪所有执行的计算,然后将它们映射到相应的 ONNX Operators,最后简化图形。在您的情况下,值得注意的细节是所有控制流都被评估一次并且Python内置类型被评估为常量。所以代码
if x2 is not None:
x1 *= 1
return torch.sum(x1)
另存为
if True:
x1 *= 1
return torch.sum(x1)
并且当
torch.onnx.export
简化图表时,它会删除所有未使用的变量,包括x2
,因此你的错误。
如果您想在导出的模型中保留控制流,您需要 torch 来使用
torch.jit.script
而不是 torch.jit.trace
来评估您的模型。正如您已经指出的那样,ONNX 需要固定数量的张量作为输入,并且不接受“可选”参数。使用脚本导出模型是这样完成的
scripted_model = torch.jit.script(torch_model)
torch.onnx.export(scripted model, ...)
但是,这样你的模型仍然无法工作。我们注意到前向传递中的
if
语句是 Pythonic 比较,并不对张量本身进行操作。所以x2
在简化过程中仍然会被丢弃。将SumModel
更改为
class SumModule(torch.nn.Module):
def forward(self, x1, x2):
if torch.any(x2):
x1 *= 1
return torch.sum(x1)
将产生正确的图形,因为现在
x2
实际上是在操作。有了这个,您可以使用 x2
作为控制流的布尔标志。
强烈建议查看 torch 文档,因为它解释了很多关于导出的常见错误。
编辑
为了完整起见,我应该补充一点,通常应该避免使用上述方法。许多硬件加速不是为条件设计的,并且尝试运行包含大量控制流的 ONNX 模型,例如 CUDA,通常会导致大部分图形回落到 CPU。当遇到这个问题中描述的情况时,我建议考虑
而不是使用上面介绍的解决方案