我有一个张量pred
,它的.size()
为[4, 53, 161]
。我有另一个张量mask
,它的.size()
为[4, 53]
。
mask
仅是0和1。我想做的是获取pred
的值,其中mask
的值为1
。您会注意到pred
的尺寸比mask
大。我该怎么做?
mask = mask.unsqueeze(2)
new_pred = pred * mask
这将增加额外的尺寸。现在是[4, 53, 1]
。休息会照顾广播。 (如果您进行某些操作)
假设您有一个形状为[4, 53, 164]
的张量,现在您想将其简化为[4, 53]
,那么您可以应用像new_pred.mean(2)
这样的算术运算。