未使用或变量仅用于 onnx 模型中的控制流

问题描述 投票:0回答:1

我有一个 onnx 模型,它有一些(理想情况下)布尔输入,仅用于模型内的控制流。

我尝试做的一些最小代码:

import onnx
import onnxruntime
import torch.onnx


class SumModule(torch.nn.Module):
  def forward(self, x1, x2):
    if x2 is not None:
      x1 *= 1
    return torch.sum(x1)


torch_model = SumModule()
torch_model.eval()
model_inputs = {'x1': torch.tensor([1, 2]), 'x2': torch.tensor([1, 2])}

torch_out = torch_model(**model_inputs)
torch.onnx.export(torch_model,
                  tuple(model_inputs.values()),
                  'model.onnx',
                  export_params=True,
                  opset_version=16,
                  do_constant_folding=True,
                  input_names=list(model_inputs.keys()),
                  output_names=['output'],
                  dynamic_axes={'x1': {0: 'batch_size'}, })

onnx_model = onnx.load('model.onnx')
onnx.checker.check_model(onnx_model)

ort_session = onnxruntime.InferenceSession('model.onnx')


def to_numpy(tensor):
  if isinstance(tensor, torch.Tensor):
    return tensor.detach().cpu().numpy()
  return tensor


model_inputs_np = {k: to_numpy(v) for k, v in model_inputs.items()}
ort_outs = ort_session.run(None, input_feed=model_inputs_np)

当 onnx 导出完成时,我无法在没有错误的情况下运行推理模型

onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Invalid Feed Input Name:x2

我想我在这里误解了 onnx 的一些基本知识。

x2
的论点明明是存在的,为什么 onnx 会以某种方式丢弃它呢?如果我根本不使用参数
x2
(但将其作为输入参数),也会发生同样的情况,我也觉得很奇怪。

在我的实际代码中,我想要做的控制流程如下: 我有 3 个可选的输入,所以理想情况下是

Optional[torch.Tensor]
。然而,onnx似乎无法处理
None
。所以相反,我想要 3 个输入 + 3 个布尔标志(
torch.tensor(True)
或者如果需要的话是
torch.tensor([True]) or replace 
True
with
1
or
1.0` --> 所有这些都有同样的问题)。 然后在代码中我根据这些标志做不同的事情。 为什么 onnx 不允许这样做?我发现如果我有时以某种方式将它们包含在某些计算中,那么拥有这些变量就可以了,但我无法弄清楚所有这一切背后的规则。

pytorch torch onnx
1个回答
0
投票

您的问题与

torch.onnx.export
的工作方式有关。

在生成 ONNX 模型时,torch 使用给定的输入执行(跟踪)一次模块,同时跟踪所有执行的计算,然后将它们映射到相应的 ONNX Operators,最后简化图形。在您的情况下,值得注意的细节是所有控制流都被评估一次并且Python内置类型被评估为常量。所以代码

if x2 is not None:
    x1 *= 1
return torch.sum(x1)

另存为

if True:
    x1 *= 1
return torch.sum(x1)

并且当

torch.onnx.export
简化图表时,它会删除所有未使用的变量,包括
x2
,因此你的错误。

如果您想在导出的模型中保留控制流,您需要 torch 来使用

torch.jit.script
而不是
torch.jit.trace
来评估您的模型。正如您已经指出的那样,ONNX 需要固定数量的张量作为输入,并且不接受“可选”参数。使用脚本导出模型是这样完成的

scripted_model = torch.jit.script(torch_model)
torch.onnx.export(scripted model, ...)

但是,这样你的模型仍然无法工作。我们注意到前向传递中的

if
语句是 Pythonic 比较,并不对张量本身进行操作。所以
x2
在简化过程中仍然会被丢弃。将
SumModel
更改为

class SumModule(torch.nn.Module):
  def forward(self, x1, x2):
    if torch.any(x2):
      x1 *= 1
    return torch.sum(x1)

将产生正确的图形,因为现在

x2
实际上是在操作。有了这个,您可以使用
x2
作为控制流的布尔标志。

强烈建议查看 torch 文档,因为它解释了很多关于导出的常见错误。

编辑

为了完整起见,我应该补充一点,通常应该避免使用上述方法。许多硬件加速不是为条件设计的,并且尝试运行包含大量控制流的 ONNX 模型,例如 CUDA,通常会导致大部分图形回落到 CPU。当遇到这个问题中描述的情况时,我建议考虑

  1. 如果您可以将“可选”张量输入为零而不影响结果
  2. 是否可以将模型完全拆分为不同的几个模型,并针对相应的输入运行每个模型

而不是使用上面介绍的解决方案

© www.soinside.com 2019 - 2024. All rights reserved.