我有一个形状为 (batch, channel, N) 的张量
x
和一个形状为 (M, N) 的张量 masks
,其中 masks[i]
是长度为 N 的布尔掩码。
对于
masks
中的每个条目,我想取x
的平均值,由masks[i]
掩盖,即:
out = [torch.mean(x[:, :, masks[i]], -1) for i in range(len(masks))]
但是这个循环非常慢。有什么办法可以在 pytorch 中一次完成所有操作吗? 似乎
x[:, :, masks]
不起作用,因为 masks
是面具列表。
注意,每个掩码都有不同数量的 True 条目,因此简单地从
x
中切出相关元素并进行平均是很困难的,因为它会导致嵌套/参差不齐的张量。
我尝试了一种使用非常大的掩码张量的解决方案,例如
x_masked = masked_tensor(x[:, :, None, :].repeat((1, 1, M, 1)),
masks[None, None, :, :].repeat((b, c, 1, 1)))
out = torch.mean(x_masked, -1).get_data()
虽然这快如闪电,但它会产生非常大的张量并且无法使用。这很不幸,因为我什至不需要访问巨大张量中的大部分元素。
感谢任何帮助! 谢谢