我的训练循环中出现以下错误,我不太明白问题是什么。我目前正在编写这段代码,所以东西不是最终的,但我无法弄清楚这个问题是什么。
我尝试用谷歌搜索错误并阅读一些答案,但似乎仍然无法理解问题的症结。
数据集和数据加载器 (X和Y已经给我了,它们都是[2000, 40, 1]张量)
class TrainingDataset(data.Dataset):
def __init__(self, X, y):
self.X = X
self.y = y
def __len__(self):
return Nf
# returns corresponding input/output pairs
def __getitem__(self, t):
X = self.X[t]
y = self.y[t]
#print(X.shape, y.shape)
return X, y
# prints torch.Size([2000, 40, 1]) torch.Size([2000, 40, 1])
print(x.size(), y.size())
dataset = TrainingDataset(x,y)
batchSize = 20
dataIter = data.DataLoader(dataset, batchSize)
型号:
class Encoder(nn.Module):
def __init__(self, num_inputs = 40, num_outputs = 40):
super(Encoder, self).__init__()
self.num_inputs = num_inputs
self.num_hidden = num_hidden
self.num_outputs = num_outputs
self.layers = nn.Sequential(
nn.Linear(num_inputs, num_outputs),
nn.ReLU(),
nn.Linear(num_outputs, num_outputs),
nn.ReLU(),
nn.Linear(num_outputs, num_outputs)
)
def forward(self, x_c, y_c):
return self.layers(x_c, y_c)
训练循环:
for epoch in range(epochs):
for batch in dataIter:
optimiser.zero_grad()
l = loss(encoder(x_c=batch[0], y_c=batch[1]), batch[1])
l.backward()
optimiser.step()
错误:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-15-aa1c60616d82> in <module>()
6 for batch in dataIter:
7 optimiser.zero_grad()
----> 8 l = loss(encoder(x_c=batch[0], y_c=batch[1]), batch[1])
9 l.backward()
10 optimiser.step()
2 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
887 result = self._slow_forward(*input, **kwargs)
888 else:
--> 889 result = self.forward(*input, **kwargs)
890 for hook in itertools.chain(
891 _global_forward_hooks.values(),
TypeError: forward() takes 2 positional arguments but 3 were given
有人能指出我正确的方向吗?我刚刚开始学习和做 pytorch 所以我还不擅长这些。
def forward(self, x_c, y_c):
return self.layers(x_c, y_c)
你的错误就在这里,这个函数除了
self
之外应该只有 1 个参数。
我有一个类似的问题并用这个解决了:
class SequentialDecoder(nn.Sequential):
def forward(self, *inputs):
x, y = inputs
for module in self._modules.values():
x = module(x, y)
return x