PyTorch 中的 .flatten() 和 .view(-1) 有什么区别?

问题描述 投票:0回答:3

.flatten()
.view(-1)
都可以在 PyTorch 中压平张量。有什么区别?

  1. .flatten()
    是否复制张量的数据?
  2. .view(-1)
    更快吗?
  3. 有没有什么情况
    .flatten()
    不起作用?
python machine-learning deep-learning pytorch
3个回答
20
投票

除了@adeelh的评论之外,还有另一个区别:

torch.flatten()
导致
.reshape()
.reshape()
.view()
之间的
区别是:

  • [...]

    torch.reshape
    可能会返回原始张量的副本或视图。您不能指望它会返回视图或副本。

  • 另一个区别是 reshape() 可以对连续张量和非连续张量进行操作,而 view() 只能对连续张量进行操作。另请参阅这里有关连续的含义

上下文:

  • 社区一度要求提供

    flatten
    功能,在Issue #7743之后,该功能在PR #8578中实现。

  • 您可以在这里看到展平的实现,其中可以在

    .reshape()
    行中看到对
    return
    的调用。


15
投票

flatten
只是 view 常见用例的
方便别名
1

还有其他几个:

功能 等价
view
逻辑
flatten()
view(-1)
flatten(start, end)
view(*t.shape[:start], -1, *t.shape[end+1:])
squeeze()
view(*[s for s in t.shape if s != 1])
unsqueeze(i)
view(*t.shape[:i-1], 1, *t.shape[i:])

请注意,

flatten
允许您使用
start_dim
end_dim
参数展平特定的连续维度子集。


  1. 实际上表面上相当于引擎盖下的
    reshape

2
投票

首先,

.view()
仅适用于连续数据,而
.flatten()
适用于连续非连续数据。像 transpose 这样的函数会生成 非连续数据,可以由
.flatten()
操作,但不能由
.view()
操作。

说到复制数据,
.view()
.flatten()
都不会复制数据。适用于连续数据。但是,如果是非连续数据
.flatten()
首先将数据复制到连续内存,然后更改维度。新张量的任何变化都不会影响原始张量。

 ten=torch.zeros(2,3)
 ten_view=ten.view(-1)
 ten_view[0]=123
 ten 

>>tensor([[123.,   0.,   0.],
           [  0.,   0.,   0.]])

 ten=torch.zeros(2,3)
 ten_flat=ten.flatten()
 ten_flat[0]=123
 ten

>>tensor([[123.,   0.,   0.],
        [  0.,   0.,   0.]])

在上面的代码中,张量ten具有连续的内存分配。对 ten_viewten_flat 的任何更改都会反映在张量 ten

ten=torch.zeros(2,3).transpose(0,1)
ten_flat=ten.flatten()
ten_flat[0]=123
ten

>>tensor([[0., 0.],
        [0., 0.],
        [0., 0.]])

在这种情况下,非连续转置张量 ten 用于 flatten()。对 ten_flat 所做的任何更改都不会反映在 ten 上。

© www.soinside.com 2019 - 2024. All rights reserved.