Pytorch 中的“unsqueeze”有什么作用?

问题描述 投票:0回答:5

PyTorch 文档 说:

返回一个新的张量,其尺寸为1插入到指定位置。 [...]

>>> x = torch.tensor([1, 2, 3, 4])
>>> torch.unsqueeze(x, 0)
tensor([[ 1,  2,  3,  4]])
>>> torch.unsqueeze(x, 1)
tensor([[ 1],
        [ 2],
        [ 3],
        [ 4]])
python numpy machine-learning deep-learning pytorch
5个回答
150
投票

unsqueeze
变成了nd。张量转换为 (n+1).d。通过添加深度为 1 的额外维度。但是,由于新维度应位于哪个轴(即应“未压缩”的方向)不明确,因此需要通过
dim
参数指定.

例如

unsqueeze
可以通过三种不同的方式应用于二维张量:

生成的未压缩张量具有相同的信息,但用于访问它们的索引不同。


89
投票

如果查看数组前后的形状,您会发现之前是

(4,)
,之后是
(1, 4)
(当第二个参数为
0
时)和
(4, 1)
(当第二个参数为
1 时) 
)。因此,在轴
1
0
处将
1
插入到数组的形状中,具体取决于第二个参数的值。

这与

np.squeeze()
(借用自 MATLAB 的术语)相反,它删除了尺寸为
1
(单例)的轴。


43
投票

表示添加尺寸的位置。

torch.unsqueeze
为张量添加了一个额外的维度。

假设你有一个形状为 (3) 的张量,如果你在 0 位置添加一个维度,它将是形状 (1,3),这意味着 1 行和 3 列:

  • 如果你有一个形状为 (2,2) 的 2D 张量,请在 0 位置添加一个额外的维度,这将导致张量的形状为 (1,2,2),这意味着一个通道,2行和 2 列。如果在 1 位置添加,它将是形状 (2,1,2),因此它将有 2 个通道、1 行和 2 列。
  • 如果在1位置添加的话,就是(3,1),即3行1列。
  • 如果将其添加到2位置,则张量的形状将为(2,2,1),这意味着2个通道,2行和1列。

38
投票

以下是 PyTorch 文档中的描述:

torch.squeeze(input, dim=None, *, out=None)
→ 张量

返回一个张量,其中删除了

input
的所有尺寸(大小为 1)。

例如,如果输入的形状为:(A×1×B×C×1×D),则输出张量的形状为:(A×B×C×D)。

当给出

dim
时,仅在给定维度上进行挤压操作。如果 input 的形状为: (A×1×B) ,则
squeeze(input, 0)
保持张量不变,但
squeeze(input, 1)
会将张量压缩为 (A×B) 形状。

torch.unsqueeze(input, dim)
→ 张量

返回一个新的张量,其尺寸为1插入到指定位置。

返回的张量与该张量共享相同的基础数据。

可以使用

dim

 范围内的 
[-input.dim() - 1, input.dim() + 1)
 值。负 
dim
 将对应于在 
unsqueeze()
 应用的 
dim = dim + input.dim() + 1


8
投票

unsqueeze是一种改变张量维度的方法,使得诸如张量乘法之类的操作成为可能。这基本上改变了维度以产生具有不同维度的张量。

例如:如果你想将大小为 (4) 的张量与大小为

(4, N, N) 的张量相乘,那么你会得到一个错误。 但使用 unsqueeze 方法,您可以将张量转换为大小 (4,1,1)。现在,由于它的操作数大小为 1,因此您将能够将两个张量相乘。

© www.soinside.com 2019 - 2024. All rights reserved.