我是 PyTorch 新手,尝试编写自己的训练函数,但遇到了下面的错误。代码如下:
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim
import tqdm
class MLPNet(nn.Module):
    """Two-layer fully connected network for 8x8 inputs with 10 outputs.

    Each sample is flattened to 64 features, passed through a 64 -> 100
    linear layer with a sigmoid, then a 100 -> 10 linear layer with a sigmoid.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are kept as-is so existing state_dict keys still load.
        self.first_fully_connected = nn.Linear(8 * 8, 100)
        self.last_fully_connected = nn.Linear(100, 10)

    def forward(self, x):
        """Return sigmoid activations of shape ``(batch_size, 10)``.

        ``x`` is reshaped to ``(-1, 64)``, so its element count must be a
        multiple of 64 (e.g. a ``(batch_size, 8, 8)`` tensor). ``reshape`` is
        used instead of ``view`` so non-contiguous inputs are accepted too.
        """
        x = x.reshape(-1, 8 * 8)
        # torch.sigmoid replaces the deprecated F.sigmoid (same values).
        x = torch.sigmoid(self.first_fully_connected(x))
        # NOTE(review): the output is squashed into (0, 1). nn.CrossEntropyLoss
        # expects raw logits, so if this net is trained with it, consider
        # returning the last linear layer's output without the sigmoid.
        return torch.sigmoid(self.last_fully_connected(x))
def training(mlp, X, y, epochs=1, lr=.2, batch_size=101):
    """Train ``mlp`` in place with SGD (momentum 0.9) and cross-entropy loss.

    The data are split into ``ceil(len(X) / batch_size)`` interleaved
    mini-batches: batch ``i`` takes rows ``i, i + n_batches, i + 2*n_batches,
    ...``. The batch composition is fixed (no shuffling between epochs).

    Args:
        mlp: the ``nn.Module`` to train; its parameters are updated in place.
            It should output raw scores of shape ``(batch, n_classes)``.
        X: NumPy array of input features, sliced along its first axis and
            converted to float32 per batch.
        y: NumPy array of integer class labels, shape ``(n,)`` or ``(n, 1)``.
        epochs: number of passes over the data.
        lr: SGD learning rate.
        batch_size: upper bound on the number of samples per mini-batch
            (must be >= 1).

    Returns:
        The same ``mlp`` object, after training.

    Raises:
        ValueError: if ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    solver = torch.optim.SGD(mlp.parameters(), lr=lr, momentum=0.9)
    # Takes logits of shape (N, C) and class-index targets of shape (N,).
    criterion = nn.CrossEntropyLoss()
    n_batches = (len(X) + batch_size - 1) // batch_size
    for _ in tqdm.tqdm(range(epochs)):
        for i in range(n_batches):
            slice_ = np.s_[i::n_batches]
            # Inputs must be float to match the float weights of nn.Linear
            # (a long input raises the "mat1" type error).
            X_batch = torch.from_numpy(X[slice_]).float()
            # Targets must be 1-D long class indices; a float (N, 1) column
            # is rejected by CrossEntropyLoss.
            y_batch = torch.from_numpy(y[slice_]).long().reshape(-1)
            solver.zero_grad()  # clear gradients left over from the last step
            prediction = mlp(X_batch)
            loss_f = criterion(prediction, y_batch)
            loss_f.backward()
            solver.step()
    return mlp
# "Your net here": 2 input features -> 64 hidden units with ReLU -> 2 output
# scores (raw logits, which is what nn.CrossEntropyLoss expects).
mlp = nn.Sequential(*(
    nn.Linear(2, 64),
    nn.ReLU(),
    nn.Linear(64, 2),
))
# NOTE(review): X_std / y_std are not defined in this snippet; presumably they
# are the standardized features and labels prepared earlier - confirm.
model_mlp = training(mlp, X_std, y_std)
但运行时出现了下面的错误。我尝试过更改张量类型,也尝试过更换损失函数,但错误依旧。
RuntimeError                              Traceback (most recent call last)
----> 1 model_mlp = fit(mlp, X_std, y_std)
RuntimeError: Expected object of type Variable[torch.FloatTensor] but found type Variable[torch.LongTensor] for argument #1 'mat1'
非常感谢您提供的任何帮助!
正如 @Ioannis Nasios 对问题所做的编辑中显示的那样,你的代码里写的是:
X_batch = Variable(torch.from_numpy(X[slice_])).long()
这意味着输入 MLP 的张量是 long(64 位整数)类型,而网络的权重是浮点数,矩阵乘法要求两者类型一致。所以应该改成:
X_batch = Variable(torch.from_numpy(X[slice_])).float()
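这样就能解决这个错误。另外请注意:`nn.CrossEntropyLoss` 要求目标是形状为 `(N,)`、类型为 long 的类别索引。因此标签应写成 `y_batch = torch.from_numpy(y[slice_]).long()`,不要转成 float,也不要加 `np.newaxis`。否则修好输入类型之后,损失函数会报出新的类型或形状错误。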