如何减少 3D 张量中的均值忽略填充行

问题描述 投票:0回答:1

我有一个 3D 张量

A
的形状为
(batch_size, N, dim)
,3D 张量
B
的形状为
(batch_size, N, 2)
。其中
B
有一些填充行来填充 N (这不是零向量,因为它已经传入了一些函数)。要知道填充了哪一行,我必须查找张量
B
,如果行
k
被填充,则张量
k
B
行的值为
[0, 0]
。我想在计算平均值之前过滤掉
A
中的这些填充行。

沿

dim=1
将 A 减少到其平均值后,结果的形状为
(batch_size, dim)

编辑:我想出了一个解决方案。欢迎任何其他解决方案!

# squeeze to 2D
A = A.view(-1, A.size(-1))
B = B.view(-1, B.size(-1))

# mask out padded row
mask = torch.sum(B, dim=-1)
mask[mask!=0] = 1

# the point is I need to assign padded row in A by zero
A = mask.unsqueeze(-1)*A

# calculate the real size of each cluster. 
mask = mask.view(batch_size, -1)
batch_cluster_size = torch.sum(mask, dim=-1, keepdim=True)

# convert back to original shape
A = A.view(batch_size, -1, A.size(-1))

output = torch.sum(A, dim=1)/batch_cluster_size

python pytorch tensor
1个回答
0
投票

稍微更紧凑的版本使用

torch.all
检查行是否已填充并避免将张量广播到不同的形状:

>>> mask = B.all(dim=-1)                 # (bs, N)
>>> A *= mask.unsqueeze(-1)              # (bs, N, dim)
>>> output = A.sum(1)/mask.sum(1, True)  # (bs, dim)
© www.soinside.com 2019 - 2024. All rights reserved.