Torch.onnx.export 使用位置和关键字参数的模块

问题描述 投票:0回答:1

我使用位置参数和关键字参数定义了一个带有forward(..) 函数的简单 nn.Module:

import torch
import torch.nn as nn

cuda0 = torch.device('cuda:0')
x = torch.tensor([[[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]]]).to(device=cuda0)

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()

        self.net = nn.Sequential(
            nn.Conv2d(1, 1, 2)
        ).to(device=cuda0)

    def forward(self, cond, **kwargs):
        if (cond):
            return self.net(kwargs['input'])
        else:
            return torch.tensor(0).to(device=cuda0)

module = MyModule()
module(torch.tensor(True).to(device=cuda0), **{'input': x})

接下来,我尝试将此模块导出到onnx:

torch.onnx.export(module,
                  args=(torch.tensor(True).to(device=cuda0), {'input': x}), 
                  f='sample.onnx', input_names=['input'], output_names=['output'], export_params=True)

但这会导致错误:

TypeError: forward() takes 2 positional arguments but 3 were given

我想,我正在根据文档这样做:

元组中除最后一个元素之外的所有元素都将作为非关键字传递 参数和命名参数将从最后一个元素开始设置。

https://pytorch.org/docs/stable/onnx.html

我做错了什么?

火炬1.8.0

pytorch export onnx
1个回答
0
投票

您可能需要将命名参数排列为包含在元组中的字典,如下所示:

参数 = ( X, { “y”:输入_y, “z”:输入_z } )

参考:https://pytorch.org/docs/stable/onnx_torchscript.html#module-torch.onnx

© www.soinside.com 2019 - 2024. All rights reserved.