鉴于:
现在我想用 C 替换索引 B 处的值。例如:A[B] = C
除了某些索引之外,我想用索引 B 处的 C 替换 A 中的值。例如。 B = [[0, 1, 2, 3], [0, 1, 2, 4]] 将索引 B 处的 A 值替换为 C,索引 3 和 5 除外(适用于所有行) -> 现在 B 可以' t 用张量表示,因为它在滤波器后没有相等的暗淡。像这样的:A[B[valid_indices]] = C[valid_indices]
我尝试使用 for 循环,但它花费了 2 个内循环并且花费了太长的时间。
for i,row in enumerate(probs):
valid_indices = torch.tensor([idx[0] for idx in enumerate(encoder_input_ids[i]) if idx[1] not in [vocab['<pad>'],vocab['<unk>'], vocab['</s>']]])
valid_ids = torch.tensor([idx[0] for idx in enumerate(encoder_input_ids[i]) if idx[1] not in [vocab['<pad>'],vocab['<unk>'], vocab['</s>']]])
# print(valid_ids)
# value = probs_c[i][valid_indices]
# probs[i][tmp] = value #probs_c[i]
这是我在您删除并稍后编辑您的帖子之前的最初回答。它回答了如何在条件
A
下使用
B
中的值对C
与M
进行索引。
torch.scatter
和 torch.where
来实现此目的:
我们从最小的设置开始:
>>> A = torch.rand(2,8)
>>> B = torch.randint(0, A.size(-1), size=(len(A), 5))
>>> C = torch.rand_like(B.float())
>>> M = torch.ones_like(A).bool()
按照您的示例“除了所有行的索引 3 和 5”,我们设置了
M[:,3]
和 M[:,5] = False
。然后我们可以使用 torch.scatter
执行索引(我们不能直接在 scatter_
上进行 A
,因为我们不知道特定索引是否有效。所以我们改为不恰当地执行:
>>> O = torch.zeros_like(A).scatter_(dim=1, index=B, src=C)
张量
O
充当缓冲区 ie. 就好像所有行都是有效的。上面的行相当于伪代码中的O[j, B[b,j]] = C
。然后,您可以使用 A
将其与基于掩码 M
的初始张量
torch.where
结合起来:
>>> O.where(M, A) # equivalent to torch.where(M, O, A)
使用 PyTorch 使用张量索引:
import torch
A = torch.tensor([[0.6301, 0.2310, -1.1964, -1.1293, -1.0428, 0.4011, 0.0519, -3.0591],
[1.3691, -0.7477, -0.9323, -0.3670, -0.1568, -0.1282, -0.4557, 0.1747]])
B = torch.tensor([[0, 2, 7, 7, 5],
[2, 4, 7, 2, 5]])
output = A[torch.arange(B.size(0)).unsqueeze(1), B]
print(output)