Problem tracing a torch module with multiple input tensors


Some networks take several tensors of different dimensions as input.

Tracing such a network with torch.jit.trace fails, apparently because of an error in how the inputs are handled internally.

Here is a minimal reproducible example:

import torch

class SimpleModel(torch.nn.Module):

    def __init__(self):
        super(SimpleModel, self).__init__()

        self.linear1 = torch.nn.Linear(100, 200)
        self.linear2 = torch.nn.Linear(50, 200)        
        self.activation = torch.nn.ReLU()
        self.softmax = torch.nn.Softmax()  # no dim given: causes the UserWarning in the output below

    def forward(self, t):
        x1, x2 = t
        x = self.linear1(x1) + self.linear2(x2)
        x = self.activation(x)
        x = self.softmax(x)
        return x

model = SimpleModel()

sample_input = (torch.rand(1, 100), torch.rand(1, 50))

# This works as intended
output = model(sample_input)

# This breaks
traced_model = torch.jit.trace(model, sample_input)

With torch==1.9.0 this produces the following error message:

/home/user/test_error.py:17: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.
  x = self.softmax(x)
Traceback (most recent call last):
  File "/home/user/test_error.py", line 28, in <module>
    traced_model = torch.jit.trace(model, sample_input)
  File "/home/antonio/Sources/unicc/face_detection/poc_pytorch_mobile/.venv/lib/python3.9/site-packages/torch/jit/_trace.py", line 735, in trace
    return trace_module(
  File "/home/antonio/Sources/unicc/face_detection/poc_pytorch_mobile/.venv/lib/python3.9/site-packages/torch/jit/_trace.py", line 952, in trace_module
    module._c._create_method_from_trace(
  File "/home/antonio/Sources/unicc/face_detection/poc_pytorch_mobile/.venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/antonio/Sources/unicc/face_detection/poc_pytorch_mobile/.venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1039, in _slow_forward
    result = self.forward(*input, **kwargs)
TypeError: forward() takes 2 positional arguments but 3 were given
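The last frames of the traceback show the call `forward_call(*input, **kwargs)`, i.e. the tuple passed to torch.jit.trace is unpacked into positional arguments, so forward receives self plus two tensors where it declares only one parameter besides self. The pure-Python sketch below reproduces just that unpacking step; `Model` and `fake_trace` are hypothetical stand-ins written for illustration, not part of PyTorch.

```python
class Model:
    """Hypothetical stand-in for SimpleModel: forward takes one argument besides self."""

    def forward(self, t):
        x1, x2 = t
        return x1 + x2

    def __call__(self, *args):
        # Mirrors nn.Module.__call__, which passes positional args on to forward()
        return self.forward(*args)


def fake_trace(module, example_inputs):
    """Hypothetical simplification of the tracer's call: unpack the tuple of
    example inputs into positional arguments, as in the traceback above."""
    if not isinstance(example_inputs, tuple):
        example_inputs = (example_inputs,)  # a lone input is wrapped in a tuple
    return module(*example_inputs)


sample_input = (1, 2)

error = ""
try:
    fake_trace(Model(), sample_input)  # becomes forward(self, 1, 2): too many args
except TypeError as exc:
    error = str(exc)
    print(error)

result = fake_trace(Model(), (sample_input,))  # becomes forward(self, (1, 2))
print(result)  # → 3
```

The exact wording of the error includes the class name on Python 3.10 and later, but the argument counts match the original traceback.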
Answer

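Apparently, simply wrapping sample_input in a tuple solves the problem. torch.jit.trace treats a tuple passed as example_inputs as the list of positional arguments for forward, so the original 2-tuple was unpacked into forward(self, x1, x2), hence "3 were given". Wrapping it in a one-element tuple passes the whole pair as the single argument t: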

traced_model = torch.jit.trace(model, (sample_input,))
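A complete version of the fixed script, as a sketch for checking that the traced module reproduces the eager output. It assumes the SimpleModel from the question with one change: Softmax is given an explicit dim=1 (an assumption that the features should be normalized) so the deprecation warning disappears. The traced module keeps the input structure it was traced with, so it is called with the same (x1, x2) tuple as one argument.

```python
import torch


class SimpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(100, 200)
        self.linear2 = torch.nn.Linear(50, 200)
        self.activation = torch.nn.ReLU()
        # Explicit dim (assumed: normalize over the feature axis) avoids the warning
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, t):
        x1, x2 = t  # t is a single argument: a tuple of two tensors
        x = self.linear1(x1) + self.linear2(x2)
        return self.softmax(self.activation(x))


model = SimpleModel().eval()
sample_input = (torch.rand(1, 100), torch.rand(1, 50))

# Outer tuple = positional args of forward(); inner tuple = the `t` argument
traced_model = torch.jit.trace(model, (sample_input,))

eager_out = model(sample_input)
traced_out = traced_model(sample_input)  # same structure as during tracing
print(traced_out.shape)  # torch.Size([1, 200])
```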

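This extra wrapping is not needed when the module takes a single tensor as input, because torch.jit.trace automatically wraps a lone tensor in a tuple.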
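For comparison, a sketch of two cases where no extra wrapping is needed: a module whose forward takes one tensor (passed bare to torch.jit.trace), and, as an alternative not mentioned in the answer, a module whose forward declares one parameter per tensor, so the 2-tuple maps one-to-one onto its arguments. SingleInput and TwoInputs are illustrative names only.

```python
import torch


class SingleInput(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(100, 200)

    def forward(self, x):
        return self.linear(x)


class TwoInputs(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(100, 200)
        self.linear2 = torch.nn.Linear(50, 200)

    def forward(self, x1, x2):  # one parameter per tensor
        return torch.relu(self.linear1(x1) + self.linear2(x2))


x1, x2 = torch.rand(1, 100), torch.rand(1, 50)

# A lone tensor needs no tuple: trace wraps it as (x1,)
traced_single = torch.jit.trace(SingleInput(), x1)

# The 2-tuple is unpacked into forward(x1, x2), which now matches the signature
two = TwoInputs()
traced_two = torch.jit.trace(two, (x1, x2))

single_out = traced_single(x1)
two_out = traced_two(x1, x2)  # called with separate arguments, not a tuple
```

Changing the signature this way also changes how callers invoke the traced module, so it is only a drop-in option if the calling code can pass the tensors separately.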
