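I am trying to do batch processing in PyTorch. In the code below, you can think of x as a batch of size 2 (each sample is a 10-d vector). I use x_sep to denote the first sample in x.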
import torch
import torch.nn as nn

class net(nn.Module):
    def __init__(self):
        super(net, self).__init__()
        self.fc1 = nn.Linear(10,10)

    def forward(self, x):
        x = self.fc1(x)
        return x

f = net()
x = torch.randn(2,10)
x_sep = torch.zeros(1,10)
x_sep[0] = x[0]
y = f(x)
y_sep = f(x_sep)
print(y_sep[0]==y[0])
Ideally, y_sep[0]==y[0] should give a tensor whose entries are all True. But the output on my computer is
tensor([False, False, True, True, False, False, False, False, True, False])
Why does this happen? Is it a calculation error? Or does it have to do with how batching is implemented in PyTorch?
Let's walk through your code line by line.
After defining the network structure, you initialize:
# x is a batch of size 2 of 10-d vectors, i.e. shape (2, 10).
x = torch.randn(2,10)
# x_sep is meant to be a single 10-d vector, but it is actually a 1x10 array:
# a row of 10 zeros inside an outer array, something like [[0, 0, ..., 0]].
x_sep = torch.zeros(1,10)
# The next line replaces the first row of x_sep (the 10 zeros) with the first vector in x.
x_sep[0] = x[0]
# f() maps every 10-d vector it receives to a 10-d vector, so y has shape (2, 10), like x.
y = f(x)
# Likewise, y_sep has shape (1, 10), because x_sep is also a 1x10 array.
y_sep = f(x_sep)
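The shapes described in the comments can be checked directly. This is a minimal sketch, assuming a single nn.Linear(10, 10) stands in for net (its forward only applies fc1) and a fixed seed purely for reproducibility:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)          # fixed seed only so runs are reproducible

fc1 = nn.Linear(10, 10)       # stands in for net, whose forward just applies fc1
x = torch.randn(2, 10)        # batch of 2 samples, each a 10-d vector
x_sep = torch.zeros(1, 10)    # batch of 1 sample, initially all zeros
x_sep[0] = x[0]               # copy the first sample of x into x_sep

y = fc1(x)                    # shape (2, 10)
y_sep = fc1(x_sep)            # shape (1, 10)

print(x.shape, x_sep.shape)
print(y.shape, y_sep.shape)
print(y[0].shape, y_sep[0].shape)   # indexing [0] drops the batch dimension
print((y_sep[0] == y[0]).shape)     # elementwise comparison of two 10-element vectors
```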
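So when you check y_sep[0]==y[0], you are comparing two 10-element vectors, and the result is an elementwise boolean tensor of 10 entries, each True only where the two floating-point values are exactly equal.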
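Because == tests exact equality of float32 values, a batch of 2 and a batch of 1 can disagree in the last bits: the two calls may go through different matrix-multiply code paths that accumulate the products in a different order, and floating-point addition is not associative. A more robust check compares within a tolerance. The sketch below assumes nn.Linear(10, 10) in place of net and uses illustrative tolerances (rtol=1e-5, atol=1e-5); whether any exact mismatches appear depends on your hardware and backend.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
fc1 = nn.Linear(10, 10)       # stands in for net
x = torch.randn(2, 10)
x_sep = torch.zeros(1, 10)
x_sep[0] = x[0]               # same first sample as x

y = fc1(x)
y_sep = fc1(x_sep)

exact = y_sep[0] == y[0]                  # bitwise equality; may contain False
diff = (y_sep[0] - y[0]).abs().max()      # largest elementwise discrepancy
close = torch.allclose(y_sep[0], y[0], rtol=1e-5, atol=1e-5)  # equal within tolerance

print(exact)
print(diff.item())
print(close)
```

If diff is on the order of float32 rounding error (roughly 1e-7 relative), the outputs are the same for all practical purposes and the False entries are not a bug in batching.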