如何使用PyTorch基于蒙版选择多维张量的所有值?

问题描述 投票:0回答:1

我有一个张量pred,它的.size()[4, 53, 161]。我有另一个张量mask,它的.size()[4, 53]

mask仅是0和1。我想做的是获取pred的值,其中mask的值为1。您会注意到pred的尺寸比mask大。我该怎么做?

python pytorch tensor
1个回答
1
投票
mask = mask.unsqueeze(2)
new_pred = pred * mask

这将增加额外的尺寸。现在是[4, 53, 1]。休息会照顾广播。 (如果您进行某些操作)

假设您有一个形状为[4, 53, 164]的张量,现在您想将其简化为[4, 53],那么您可以应用像new_pred.mean(2)这样的算术运算。

© www.soinside.com 2019 - 2024. All rights reserved.