我有一个 3D 张量
A
的形状为 (batch_size, N, dim)
,3D 张量 B
的形状为 (batch_size, N, 2)
。其中 B
有一些填充行来填充 N (这不是零向量,因为它已经传入了一些函数)。要知道填充了哪一行,我必须查找张量B
,如果行k
被填充,则张量k
中B
行的值为[0, 0]
。我想在计算平均值之前过滤掉 A
中的这些填充行。
沿
dim=1
将 A 减少到其平均值后,结果的形状为 (batch_size, dim)
。
编辑:我想出了一个解决方案。欢迎任何其他解决方案!
# squeeze to 2D
A = A.view(-1, A.size(-1))
B = B.view(-1, B.size(-1))
# mask out padded row
mask = torch.sum(B, dim=-1)
mask[mask!=0] = 1
# the point is I need to assign padded row in A by zero
A = mask.unsqueeze(-1)*A
# calculate the real size of each cluster.
mask = mask.view(batch_size, -1)
batch_cluster_size = torch.sum(mask, dim=-1, keepdim=True)
# convert back to original shape
A = A.view(batch_size, -1, A.size(-1))
output = torch.sum(A, dim=1)/batch_cluster_size
torch.all
检查行是否已填充并避免将张量广播到不同的形状:
>>> mask = B.all(dim=-1) # (bs, N)
>>> A *= mask.unsqueeze(-1) # (bs, N, dim)
>>> output = A.sum(1)/mask.sum(1, True) # (bs, dim)