我有一个张量,
tensor([[ 0.2213, -0.1180, 1.1186],
[-0.9943, -0.7679, -1.7057]])
现在我想根据条件给每个元素赋值,像这样:
torch.where(x < 0, 0, x)
torch.where(0 <= x <= 1, 5, x)
torch.where(2 <= x <= 3, 7, x)
当然,'where'不接受这样的条件。我怎样才能实现它? torch.isin 在这种情况下也无济于事。
第一个语句是有效的,因为
x < 0
产生一个布尔张量。
torch.logical_and
:
>>> torch.where(torch.logical_and(0 <= x, x <= 1), 5, x)
请注意,对于两个给定的布尔张量:
a
和 b
,torch.logical_and(a, b)
将等同于 a*b
,因为 *
是布尔值的与运算符。
>>> torch.where((0 <= x)*(x <= 1), 5, x)
tensor([[ 5.0000, -0.1180, 1.1186],
[-0.9943, -0.7679, -1.7057]])
>>> torch.where((2 <= x)*(x <= 3), 7, x)
tensor([[ 0.2213, -0.1180, 1.1186],
[-0.9943, -0.7679, -1.7057]])