如何通过索引张量将一个值张量分配给另一个值张量(Pytorch)

问题描述 投票:0回答:2

鉴于:

  • 零张量 A 的形状为: (batch_size, vocab_size),比如说 (16, 10000)
  • 张量 B 是索引张量,其形状为: (batch_size, seq_len),可以说 (16, 20)
  • 张量 C 是值的张量,其形状为: (batch_size, seq_len), -> (16, 20)
  1. 现在我想用 C 替换索引 B 处的值。例如:A[B] = C

  2. 除了某些索引之外,我想用索引 B 处的 C 替换 A 中的值。例如。 B = [[0, 1, 2, 3], [0, 1, 2, 4]] 将索引 B 处的 A 值替换为 C,索引 3 和 5 除外(适用于所有行) -> 现在 B 可以' t 用张量表示,因为它在滤波器后没有相等的暗淡。像这样的:A[B[valid_indices]] = C[valid_indices]

我尝试使用 for 循环,但它花费了 2 个内循环并且花费了太长的时间。

for i,row in enumerate(probs): 
            valid_indices = torch.tensor([idx[0] for idx in enumerate(encoder_input_ids[i]) if idx[1] not in [vocab['<pad>'],vocab['<unk>'], vocab['</s>']]])
            valid_ids = torch.tensor([idx[0] for idx in enumerate(encoder_input_ids[i]) if idx[1] not in [vocab['<pad>'],vocab['<unk>'], vocab['</s>']]])
            # print(valid_ids)
            # value = probs_c[i][valid_indices]
            # probs[i][tmp] = value #probs_c[i]
python pytorch tensor
2个回答
0
投票

这是我在您删除并稍后编辑您的帖子之前的最初回答。它回答了如何在条件

A
下使用
B
中的值对C
M
进行索引。


您可以通过组合

torch.scatter
torch.where
来实现此目的:

我们从最小的设置开始:

>>> A = torch.rand(2,8)
>>> B = torch.randint(0, A.size(-1), size=(len(A), 5))
>>> C = torch.rand_like(B.float())
>>> M = torch.ones_like(A).bool()

按照您的示例“除了所有行的索引 3 和 5”,我们设置了

M[:,3]
M[:,5] = False
。然后我们可以使用
torch.scatter
执行索引(我们不能直接在
scatter_
上进行
A
,因为我们不知道特定索引是否有效。所以我们改为不恰当地执行:

>>> O = torch.zeros_like(A).scatter_(dim=1, index=B, src=C)

张量

O
充当缓冲区 ie. 就好像所有行都是有效的。上面的行相当于伪代码中的
O[j, B[b,j]] = C
。然后,您可以使用
A
 将其与基于掩码 M
 的初始张量 
torch.where
结合起来:

>>> O.where(M, A) # equivalent to torch.where(M, O, A)

0
投票

使用 PyTorch 使用张量索引:

import torch

A = torch.tensor([[0.6301,  0.2310, -1.1964, -1.1293, -1.0428,  0.4011,  0.0519, -3.0591],
                  [1.3691, -0.7477, -0.9323, -0.3670, -0.1568, -0.1282, -0.4557,  0.1747]])

B = torch.tensor([[0, 2, 7, 7, 5],
                  [2, 4, 7, 2, 5]])

output = A[torch.arange(B.size(0)).unsqueeze(1), B]

print(output)
© www.soinside.com 2019 - 2024. All rights reserved.