在numpy中,我们使用ndarray.reshape()
重塑数组。
[我注意到在pytorch中,人们出于相同的目的使用torch.view(...)
,但同时也存在torch.reshape(...)
。
所以我想知道它们之间有什么区别以及何时应使用它们中的任何一个?
torch.view
已经存在很长时间了。它将返回具有新形状的张量。返回的张量将与原始张量共享基础数据。参见documentation here。
另一方面,似乎torch.reshape
has been introduced recently in version 0.4。根据document,此方法将
返回一个张量,该张量具有与输入相同的数据和元素数量,但具有指定的形状。如果可能,返回的张量将是输入视图。否则,它将是副本。连续输入和具有兼容步幅的输入可以在不复制的情况下进行重塑,但是您不应该依赖复制与查看行为。
这意味着torch.reshape
可能返回原始张量的副本或视图。您不能指望它返回视图或副本。根据开发商的说法:
如果需要复制,请使用clone(),如果需要相同的存储,请使用view()。 reshape()的语义是它可能共享或可能不共享存储,并且您事先不知道。
另一个区别是,reshape()
可以在连续和非连续张量上运行,而view()
仅可以在连续张量上运行。另请参见here,以了解contiguous
的含义。
尽管torch.view
和torch.reshape
都用于重整张量,但这是它们之间的差异。
torch.view
仅创建原始张量的view。新的张量将总是与原始张量共享其数据。这意味着,如果您更改原始张量,则重塑后的张量将会更改,反之亦然。>>> z = torch.zeros(3, 2)
>>> x = z.view(2, 3)
>>> z.fill_(1)
>>> x
tensor([[1., 1., 1.],
[1., 1., 1.]])
torch.view
对两个张量[docs]的形状施加了一些连续性约束。通常这不是一个问题,但是有时即使两个张量的形状兼容,torch.view
也会引发错误。这是一个著名的反例。>>> z = torch.zeros(3, 2)
>>> y = z.t()
>>> y.size()
torch.Size([2, 3])
>>> y.view(6)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: invalid argument 2: view size is not compatible with input tensor's
size and stride (at least one dimension spans across two contiguous subspaces).
Call .contiguous() before .view().
torch.reshape
不施加任何连续性约束,但是也不保证数据共享。新的张量可以是原始张量的视图,也可以是新的张量。>>> z = torch.zeros(3, 2)
>>> y = z.reshape(6)
>>> x = z.t().reshape(6)
>>> z.fill_(1)
tensor([[1., 1.],
[1., 1.],
[1., 1.]])
>>> y
tensor([1., 1., 1., 1., 1., 1.])
>>> x
tensor([0., 0., 0., 0., 0., 0.])
TL; DR:如果只想重塑张量,请使用torch.reshape
。如果您还担心内存使用情况,并且要确保两个张量共享相同的数据,请使用torch.view
。
Tensor.reshape()
更可靠。它适用于任何张量,而Tensor.view()
仅适用于张量t
,其中t.is_contiguous()==True
。
解释不连续和连续是另一个时间故事,但始终可以将张量t
设为连续,只要调用t.contiguous()
,然后就可以调用view()
而不会出现错误。