PyTorch JIT script error when a Sequential container takes a tuple input

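Here is a simple network that reproduces my error. I pass a tuple to each layer's forward method and annotate its type as Tuple[torch.Tensor, int]. I think the error happens because JIT infers the input type of Sequential's forward method as Tensor instead of Tuple. How can I fix this error?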

from typing import Tuple

import torch
import torch.nn as nn

class MyBatchNorm(nn.Module):
    def __init__(self, output_size, d_ids):
        super().__init__()
        self.d_ids = d_ids
        self.net = nn.ModuleDict({f"{d}": nn.BatchNorm1d(output_size) for d in d_ids})
    
    def forward(self, input_tuple: Tuple[torch.Tensor, int]) -> Tuple[torch.Tensor, int]:
        input_tensor, d = input_tuple
        output_tensor = torch.tensor([])
        for d_name, d_norm in self.net.items():
            if f"{d}" == d_name:
                output_tensor = d_norm(input_tensor)
        if len(output_tensor) == 0:
            raise ValueError(f"invalid d {d}, must be {self.d_ids}")
        return output_tensor, d

class MyNet(nn.Module):
    def __init__(self, output_size, d_ids):
        super().__init__()
        dense_layers = [
            MyBatchNorm(output_size, d_ids),
            MyBatchNorm(output_size, d_ids)
        ]
        self.net = torch.nn.Sequential(*dense_layers)
        
    def forward(self, input_tensor: torch.Tensor, d_tensor: torch.Tensor) -> torch.Tensor:
        d = d_tensor.squeeze()[0].item()
        output_tensor, _ = self.net((input_tensor, d))
        return torch.squeeze(output_tensor)
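The diagnosis above can be checked in isolation. The sketch below uses two hypothetical functions, `passthrough` and `passthrough_tuple`. The first has an unannotated argument, which TorchScript types as `Tensor`, so a `(Tensor, int)` tuple is rejected when it is called. The second has an explicit `Tuple[torch.Tensor, int]` annotation and accepts the tuple. My understanding is that the unannotated `input` parameter of `nn.Sequential.forward` falls under the same default.

```python
from typing import Tuple

import torch


@torch.jit.script
def passthrough(x):
    # No annotation: TorchScript assumes x is a Tensor.
    return x


@torch.jit.script
def passthrough_tuple(x: Tuple[torch.Tensor, int]) -> Tuple[torch.Tensor, int]:
    # Explicit annotation: TorchScript now expects a (Tensor, int) tuple.
    return x


passthrough(torch.ones(2))  # accepted: matches the inferred Tensor type

try:
    passthrough((torch.ones(2), 1))  # rejected: a tuple is not a Tensor
    rejected = False
except RuntimeError:
    # TorchScript reports the argument/schema mismatch as a RuntimeError.
    rejected = True

out, d = passthrough_tuple((torch.ones(2), 1))  # accepted
```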

Error:

RuntimeError: 

forward(__torch__.___torch_mangle_16.MyBatchNorm self, (Tensor, int) input_tuple) -> ((Tensor, int)):
Expected a value of type 'Tuple[Tensor, int]' for argument 'input_tuple' but instead found type 'Tensor (inferred)'.
Inferred the value for argument 'input_tuple' to be of type 'Tensor' because it was not annotated with an explicit type.
:
  File "/home/ec2-user/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/torch/nn/modules/container.py", line 117
    def forward(self, input):
        for module in self:
            input = module(input)
                    ~~~~~~ <--- HERE
        return input
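One possible workaround, sketched here for a reasonably recent PyTorch release, is to subclass `nn.Sequential` and override only `forward` with the same explicit annotation the layers use. TorchScript then has no reason to infer `Tensor` for the container's input. `TupleSequential`, `AddOffset` and `Net` are hypothetical names. `AddOffset` is a trivial stand-in for `MyBatchNorm` with the same tuple-in, tuple-out signature. `Net.forward` also converts the selected element with `int(...)` instead of `.item()`. As far as I know, `Tensor.item()` is typed as a generic number in TorchScript rather than `int`, so without the cast the tuple would likely fail the `Tuple[Tensor, int]` check once the `Sequential` error is gone.

```python
from typing import Tuple

import torch
import torch.nn as nn


class TupleSequential(nn.Sequential):
    """nn.Sequential with an annotated forward, so TorchScript expects a
    (Tensor, int) tuple instead of inferring a plain Tensor."""

    def forward(self, input_tuple: Tuple[torch.Tensor, int]) -> Tuple[torch.Tensor, int]:
        for module in self:
            input_tuple = module(input_tuple)
        return input_tuple


class AddOffset(nn.Module):
    """Hypothetical stand-in for MyBatchNorm: same tuple-in, tuple-out signature."""

    def forward(self, input_tuple: Tuple[torch.Tensor, int]) -> Tuple[torch.Tensor, int]:
        x, d = input_tuple
        return x + d, d


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = TupleSequential(AddOffset(), AddOffset())

    def forward(self, input_tensor: torch.Tensor, d_tensor: torch.Tensor) -> torch.Tensor:
        # int(...) yields an int in TorchScript; .item() would yield a generic number.
        d = int(d_tensor.squeeze()[0])
        output_tensor, _ = self.net((input_tensor, d))
        return torch.squeeze(output_tensor)


model = Net()
scripted = torch.jit.script(model)  # compiles without the Sequential type error
```

Holding the layers in an `nn.ModuleList` and looping over it in `MyNet.forward` should work as well. The subclass simply keeps the original `torch.nn.Sequential(*dense_layers)` style of construction.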