Pytorch多类Logistic回归类型误差。

问题描述 投票:0回答:1

我是新来的ML,对Pytorch更是幼稚。问题是这样的。(我跳过了某些部分,比如随机拆分(),它看起来工作得很好)

我需要预测葡萄酒的质量(红葡萄酒),从数据集上看,最后一列有6个类。

我的数据集是这样的

数据集的链接(winequality-red.csv)

features = df.drop(['quality'], axis = 1)
targets = df.iloc[:, -1] # theres 6 classes

dataset = TensorDataset(torch.Tensor(np.array(features)).float(), torch.Tensor(targets).float())
# here's where I think the error might be, but I might be wrong

batch_size = 8
# Dataloader

train_loader = DataLoader(train_ds, batch_size, shuffle = True)
val_loader = DataLoader(val_ds, batch_size)
test_ds = DataLoader(test_ds, batch_size)

input_size = len(df.columns) - 1
output_size = 6
threshold = .5

class WineModel(nn.Module):

  def __init__(self):
    super().__init__()
    self.linear = nn.Linear(input_size, output_size)

  def forward(self, xb):
    out = self.linear(xb)
    return out

model = WineModel()
n_iters = 2000
num_epochs = n_iters / (len(train_ds) / batch_size)
num_epochs = int(num_epochs)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

# the part below returns the error on running
iter = 0 
for epoch in range(num_epochs):
  for i, (x, y) in enumerate(train_loader):
    optimizer.zero_grad()
    outputs = model(x)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()

RuntimeError: 预期的标量类型为Long,但找到了Float。

希望这是足够的信息

machine-learning pytorch logistic-regression
1个回答
0
投票

目标是 nn.CrossEntropyLoss 给出的是类的索引,要求是整数,准确地说,它们的类型是 torch.long这相当于 torch.int64.

你将目标转换为浮动,但你应该将它们转换为长线。

dataset = TensorDataset(torch.Tensor(np.array(features)).float(), torch.Tensor(targets).long())

因为target是类的指数, 它们必须在范围内。[0,num_classes - 1]. 由于你有6个班级,将在[0,5]范围内。快速查看了一下你的数据,质量使用的值范围是[3,8]。虽然你有6个类,但值不能直接作为类使用。如果你把类列为 classes = [3, 4, 5, 6, 7, 8],你可以看到,第一类是3。classes[0] == 3,直到最后一个班级是 classes[5] == 8.

你需要用索引来替换类的值,就像对命名类一样(例如,如果你的类是 猫咪, 将是0和 猫咪 会是1),但你可以不用去查,因为这些值只是简单地移位3,即。index = classes[index] - 3. 因此你可以从整个目标张量中减去3。

torch.Tensor(targets).long() - 3
© www.soinside.com 2019 - 2024. All rights reserved.