如何减少3D张量中的均值忽略零行[Pytorch]

问题描述 投票:0回答:1

我有一个 3D 张量 A 的形状为

[batch_size, N, dim]
,3D 张量 B 的形状为
[batch_size, N, 2]
。其中:

A 有一些填充行要填充到 N (这不是零向量,因为它已经传递到了一些函数中)。要知道哪一行被填充,我必须查找张量 B,如果行 idx k 被填充,则张量 B 中第 k 行的值为 [0, 0]。在计算平均值之前,我需要过滤掉 A 中的这些填充行。

减少 A 的均值后,结果具有形状

[batch_size, dim]

谢谢!

编辑:我想出了一个解决方案。欢迎任何其他解决方案!

# squeeze to 2D
A = A.view(-1, A.size(-1))
B = B.view(-1, B.size(-1))

# mask out padded row
mask = torch.sum(B, dim=-1)
mask[mask!=0] = 1

# the point is I need to assign padded row in A by zero
A = mask.unsqueeze(-1)*A

# calculate the real size of each cluster. 
mask = mask.view(batch_size, -1)
batch_cluster_size = torch.sum(mask, dim=-1, keepdim=True)

# convert back to original shape
A = A.view(batch_size, -1, A.size(-1))

output = torch.sum(A, dim=1)/batch_cluster_size

python pytorch tensor
1个回答
0
投票

稍微更紧凑的版本使用

torch.any
检查行是否已填充并避免将张量广播到不同的形状:

>>> mask = B.any(dim=-1)                 # (bs, n)
>>> A *= mask[...,None].float()          # (bs, n, dim)
>>> output = A.sum(1)/mask.sum(1, True)  # (bs, dim)
© www.soinside.com 2019 - 2024. All rights reserved.