我有两个数组,
x
和y
,具有相同的形状。 x
代表数据,y
代表x
中的每个数据点属于哪个类。我想创建一个新的张量,其中 x
中的数据根据 y
中的类划分为通道。
如果我使用 one-hot 编码,我就可以实现这一点。然而,对于大型张量(尤其是具有大量类),PyTorch 的 one-hot 编码很快就会耗尽 GPU 上的所有内存。
是否有更节省内存的方法来进行此广播?
import torch
B, C, N = 2, 10, 1000
x = torch.randn(B, 1, N)
y = torch.randint(low=0, high=C, size=(B, 1, N))
one_hot = torch.nn.functional.one_hot(y, C) # B 1 N C
one_hot = one_hot.squeeze().permute(0, -1, 1) # B C N
z = x * one_hot # B C N
如果
z
是所需的输出张量,那么您必须以某种方式在内存中分配 BxCxN
。另一种解决方案是扩展 x
和 y
并将值分散到零张量中:
>>> z = torch.zeros(B,C,N).scatter_(1,y.expand(-1,C,-1),x.expand(-1,C,-1))