.flatten()
和 .view(-1)
都可以在 PyTorch 中压平张量。有什么区别?
.flatten()
是否复制张量的数据?.view(-1)
更快吗?.flatten()
不起作用?除了@adeelh的评论之外,还有另一个区别:
torch.flatten()
导致.reshape()
,.reshape()
和
.view()
之间的区别是:
[...]
可能会返回原始张量的副本或视图。您不能指望它会返回视图或副本。torch.reshape
另一个区别是 reshape() 可以对连续张量和非连续张量进行操作,而 view() 只能对连续张量进行操作。另请参阅这里有关连续的含义
上下文:
社区一度要求提供
flatten
功能,在Issue #7743之后,该功能在PR #8578中实现。
您可以在这里看到展平的实现,其中可以在
.reshape()
行中看到对return
的调用。
flatten
只是 view
常见用例的 方便别名。1
还有其他几个:
功能 | 等价 逻辑 |
---|---|
|
|
|
|
|
|
|
|
请注意,
flatten
允许您使用 start_dim
和 end_dim
参数展平特定的连续维度子集。
reshape
。首先,
.view()
仅适用于连续数据,而.flatten()
适用于连续和非连续数据。像 transpose 这样的函数会生成 非连续数据,可以由 .flatten()
操作,但不能由 .view()
操作。.view()
和 .flatten()
都不会复制数据。适用于连续数据。但是,如果是非连续数据,.flatten()
首先将数据复制到连续内存,然后更改维度。新张量的任何变化都不会影响原始张量。 ten=torch.zeros(2,3)
ten_view=ten.view(-1)
ten_view[0]=123
ten
>>tensor([[123., 0., 0.],
[ 0., 0., 0.]])
ten=torch.zeros(2,3)
ten_flat=ten.flatten()
ten_flat[0]=123
ten
>>tensor([[123., 0., 0.],
[ 0., 0., 0.]])
在上面的代码中,张量ten具有连续的内存分配。对 ten_view 或 ten_flat 的任何更改都会反映在张量 ten
上ten=torch.zeros(2,3).transpose(0,1)
ten_flat=ten.flatten()
ten_flat[0]=123
ten
>>tensor([[0., 0.],
[0., 0.],
[0., 0.]])
在这种情况下,非连续转置张量 ten 用于 flatten()。对 ten_flat 所做的任何更改都不会反映在 ten 上。