pytorch 神经网络的 int64 类型和 Long 类型的区别

问题描述 投票:0回答:1

我正在尝试训练 PyTorch 神经网络,但在前向传递时抛出以下错误:

RuntimeError: expected scalar type Long but found Float

尽管将输入数据的类型转换为 Long,但仍然会出现此错误。

数据类型转换为Long的自定义数据集类代码:

class MnistTrainDataset(Dataset):
    def __init__(self, df):
        self.X = torch.tensor(df.iloc[:, 1:].values, dtype=torch.long).reshape((-1, 1, 28, 28))
        self.y = torch.tensor(df.iloc[:, 0].values, dtype=torch.long).reshape((-1, 1))
        print(self.X[0])
    
    def __len__(self):
        return len(self.y)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

当我包含

dtype=torch.long
时仍然出现错误,所以我也尝试通过
type(torch.LongTensor)
将整个张量更改为Long,但这并没有解决问题。

我在前向传递的第一步之前打印了输入数据的数据类型,它给出了这个输出:

torch.int64
,根据我读到的内容,它与 Long 类型相同。如果 int64 和 long 是相同的数据类型,我不确定为什么会引发错误。对其他类似帖子的回复主要是我以前尝试过的方式手动投射,但这些都没有奏效。我应该对 MnistTrain 类或正向传递进行哪些更改才能解决此问题?

转发密码:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        print(x.dtype)
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        print(x.dtype)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        print(x.dtype)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

完整错误:

2 for epoch in range(num_epochs):
      3     for X, y in train_iter:
----> 4         y_hat = net(X.type(torch.LongTensor))
      5         l = loss(y_hat.reshape(y.shape), y)
      7         optimiser.zero_grad()

File ~/opt/anaconda3/envs/d2l/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

/Users/atikshgupta/Desktop/kaggle/digit_mnist/mnist_classifier.ipynb Cell 5' in Net.forward(self, x)
     10 def forward(self, x):
     11     print(x.dtype)
---> 12     x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
     13     print(x.dtype)
     14     x = F.max_pool2d(F.relu(self.conv2(x)), 2)

File ~/opt/anaconda3/envs/d2l/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/opt/anaconda3/envs/d2l/lib/python3.10/site-packages/torch/nn/modules/conv.py:457, in Conv2d.forward(self, input)
    456 def forward(self, input: Tensor) -> Tensor:
--> 457     return self._conv_forward(input, self.weight, self.bias)

File ~/opt/anaconda3/envs/d2l/lib/python3.10/site-packages/torch/nn/modules/conv.py:453, in Conv2d._conv_forward(self, input, weight, bias)
    449 if self.padding_mode != 'zeros':
    450     return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
    451                     weight, bias, self.stride,
    452                     _pair(0), self.dilation, self.groups)
--> 453 return F.conv2d(input, weight, bias, self.stride,
    454                 self.padding, self.dilation, self.groups)

RuntimeError: expected scalar type Long but found Float`
pytorch neural-network casting
1个回答
0
投票

对数据类型转换进行以下更改,要求定标器类型 Long

data= data.type(torch.LongTensor)
© www.soinside.com 2019 - 2024. All rights reserved.