检查一个张量值是否包含在另一个张量中

问题描述 投票:0回答:1

我有一个像这样的火炬张量:

a=[1, 234, 54, 6543, 55, 776]

以及其他张量,如下所示:

b=[234, 54]
c=[55, 776]

我想创建一个新的掩模张量,如果有另一个张量(

a
b
)等于它,则
c
的值将为真。

例如,在上面的张量中,我想创建以下掩蔽张量:

a_masked =[False, True, True, False, True, True]
# The first two True values correspond to tensor `b` while the last two True values 
correspond to tensor `c`.

我见过其他方法来检查完整张量是否包含在另一个张量中,但这里不是这种情况。

有没有一种火炬方式可以有效地做到这一点? 谢谢!

python pytorch torch
1个回答
0
投票

根据 PyTorch 论坛 here 上的答案,看起来您只需要一个显式的 for 循环,例如,

import torch

a = torch.tensor([1, 234, 54, 6543, 55, 776])
b = torch.tensor([234, 54])

a_masked = sum(a == i for i in b).bool()

print(a_masked)
tensor([False,  True,  True, False, False, False])
© www.soinside.com 2019 - 2024. All rights reserved.