我不明白线labels.size(0)
。我是Pytorch的新手,对数据结构非常困惑。
correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (
100 * correct / total))`
回答你的问题
在PyTorch中,tensor.size()
允许您检查张量的形状。
在你的代码中,
images, labels = data
images
和labels
将各自包含N
数量的培训示例取决于您的批量大小。如果你看看标签的形状,它应该是[N, 1]
,其中N是小批量训练示例的大小。
对于那些刚接触神经网络的人来说,这是一种先见之明。
在训练神经网络时,从业者将转发数据集通过网络并优化梯度。
假设您的训练数据集包含100万张图像,您的训练脚本的设计方式是在一个时期内传递所有100万张图像。这种方法的问题是您需要很长时间才能从神经网络接收反馈。这是小批量培训的用武之地。
在PyTorch中,DataLoader类允许我们将数据集拆分为多个批次。如果您的训练装载器包含1百万个示例且批量大小为1000,那么您将期望每个时期将遍历所有小批量的1000步。这样,您可以更好地观察和优化训练表现。
labels
是尺寸为[N, 1]
的张量,其中N
等于批次中的样本数量。 .size(...)
返回具有Tensor维度的元组(torch.Size
)的子类,.size(0)
返回一个整数,其值为第一个(从0开始)维度(即N
)。