在 pytorch 中布尔掩码指定的位置平均张量值的有效方法

问题描述 投票:0回答:0

我有一个形状为 (batch, channel, N) 的张量

x
和一个形状为 (M, N) 的张量
masks
,其中
masks[i]
是长度为 N 的布尔掩码。

对于

masks
中的每个条目,我想取
x
的平均值,由
masks[i]
掩盖,即:

out = [torch.mean(x[:, :, masks[i]], -1) for i in range(len(masks))]

但是这个循环非常慢。有什么办法可以在 pytorch 中一次完成所有操作吗? 似乎

x[:, :, masks]
不起作用,因为
masks
是面具列表。

注意,每个掩码都有不同数量的 True 条目,因此简单地从

x
中切出相关元素并进行平均是很困难的,因为它会导致嵌套/参差不齐的张量。

我尝试了一种使用非常大的掩码张量的解决方案,例如

x_masked = masked_tensor(x[:, :, None, :].repeat((1, 1, M, 1)), 
                         masks[None, None, :, :].repeat((b, c, 1, 1)))
out = torch.mean(x_masked, -1).get_data()

虽然这快如闪电,但它会产生非常大的张量并且无法使用。这很不幸,因为我什至不需要访问巨大张量中的大部分元素。

感谢任何帮助! 谢谢

tensorflow pytorch torch numpy-slicing
© www.soinside.com 2019 - 2024. All rights reserved.