基于另一个数组跨通道广播 pytorch 数组

问题描述 投票:0回答:1

我有两个数组,

x
y
,具有相同的形状。
x
代表数据,
y
代表
x
中的每个数据点属于哪个类。我想创建一个新的张量,其中
x
中的数据根据
y
中的类划分为通道。

如果我使用 one-hot 编码,我就可以实现这一点。然而,对于大型张量(尤其是具有大量类),PyTorch 的 one-hot 编码很快就会耗尽 GPU 上的所有内存。

是否有更节省内存的方法来进行此广播?

import torch

B, C, N = 2, 10, 1000

x = torch.randn(B, 1, N)
y = torch.randint(low=0, high=C, size=(B, 1, N))

one_hot = torch.nn.functional.one_hot(y, C)  # B 1 N C
one_hot = one_hot.squeeze().permute(0, -1, 1)  # B C N

z = x * one_hot  # B C N
python pytorch
1个回答
0
投票

如果

z
是所需的输出张量,那么您必须以某种方式在内存中分配
BxCxN
。另一种解决方案是扩展
x
y
并将值分散到零张量中:

>>> z = torch.zeros(B,C,N).scatter_(1,y.expand(-1,C,-1),x.expand(-1,C,-1))
© www.soinside.com 2019 - 2024. All rights reserved.