我有一个 3D 张量 A 的形状为
[batch_size, N, dim]
,3D 张量 B 的形状为 [batch_size, N, 2]
。其中:
A 有一些填充行要填充到 N (这不是零向量,因为它已经传递到了一些函数中)。要知道哪一行被填充,我必须查找张量 B,如果行 idx k 被填充,则张量 B 中第 k 行的值为 [0, 0]。在计算平均值之前,我需要过滤掉 A 中的这些填充行。
减少 A 的均值后,结果具有形状
[batch_size, dim]
。
谢谢!
编辑:我想出了一个解决方案。欢迎任何其他解决方案!
# squeeze to 2D
A = A.view(-1, A.size(-1))
B = B.view(-1, B.size(-1))
# mask out padded row
mask = torch.sum(B, dim=-1)
mask[mask!=0] = 1
# the point is I need to assign padded row in A by zero
A = mask.unsqueeze(-1)*A
# calculate the real size of each cluster.
mask = mask.view(batch_size, -1)
batch_cluster_size = torch.sum(mask, dim=-1, keepdim=True)
# convert back to original shape
A = A.view(batch_size, -1, A.size(-1))
output = torch.sum(A, dim=1)/batch_cluster_size
torch.any
检查行是否已填充并避免将张量广播到不同的形状:
>>> mask = B.any(dim=-1) # (bs, n)
>>> A *= mask[...,None].float() # (bs, n, dim)
>>> output = A.sum(1)/mask.sum(1, True) # (bs, dim)