How to get a subset tensor using an index tensor (PyTorch)


Given:

  • Tensor A has shape (batch_size, vocab_size)
  • Tensor B is a tensor of token indices with shape (batch_size, seq_len)

I want to get the values of A at the indices given by B, row by row. For example:

A = [[ 0.6301,  0.2310, -1.1964, -1.1293, -1.0428,  0.4011,  0.0519, -3.0591],
     [ 1.3691, -0.7477, -0.9323, -0.3670, -0.1568, -0.1282, -0.4557,  0.1747]]
B = [[0, 2, 7, 7, 5],
     [2, 4, 7, 2, 5]]

output = [[ 0.6301, -1.1964, -3.0591, -3.0591,  0.4011],
          [-0.9323, -0.1568,  0.1747, -0.9323, -0.1282]]
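Put precisely, the requested result is output[i, j] = A[i, B[i, j]]. A slow, loop-based reference implementation makes those semantics explicit and is handy for checking a vectorized version. This is a minimal sketch; the helper name gather_rows_naive is hypothetical and not part of PyTorch.

```python
import torch

def gather_rows_naive(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference helper: out[i, j] = A[i, B[i, j]], computed with loops."""
    batch_size, seq_len = B.shape
    out = torch.empty(batch_size, seq_len, dtype=A.dtype, device=A.device)
    for i in range(batch_size):          # each row of the batch
        for j in range(seq_len):         # each position in the sequence
            out[i, j] = A[i, int(B[i, j])]  # pick the vocab entry named by B
    return out

A = torch.tensor([[0.6301,  0.2310, -1.1964, -1.1293, -1.0428,  0.4011,  0.0519, -3.0591],
                  [1.3691, -0.7477, -0.9323, -0.3670, -0.1568, -0.1282, -0.4557,  0.1747]])
B = torch.tensor([[0, 2, 7, 7, 5],
                  [2, 4, 7, 2, 5]])

print(gather_rows_naive(A, B))
```

The double loop is far too slow for real batch and sequence sizes; it only pins down what the vectorized answer must reproduce.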

Thanks!

python pytorch tensor
1 answer (0 votes)

Use PyTorch tensor (advanced) indexing, pairing each index in B with its row number:

import torch

A = torch.tensor([[0.6301,  0.2310, -1.1964, -1.1293, -1.0428,  0.4011,  0.0519, -3.0591],
                  [1.3691, -0.7477, -0.9323, -0.3670, -0.1568, -0.1282, -0.4557,  0.1747]])

B = torch.tensor([[0, 2, 7, 7, 5],
                  [2, 4, 7, 2, 5]])

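# Row indices of shape (batch_size, 1) broadcast against B's (batch_size, seq_len),
# so output[i, j] = A[i, B[i, j]]
output = A[torch.arange(B.size(0)).unsqueeze(1), B]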

print(output)
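The indexing above works because the row index tensor broadcasts against B. The sketch below prints the shapes involved and, as a swapped-in alternative not taken from the original answer, shows that torch.gather along dim=1 (which computes out[i][j] = A[i][B[i][j]]) and torch.take_along_dim (available in newer PyTorch releases) give the same result. It assumes B is an int64 tensor, which torch.gather requires for its index.

```python
import torch

A = torch.tensor([[0.6301,  0.2310, -1.1964, -1.1293, -1.0428,  0.4011,  0.0519, -3.0591],
                  [1.3691, -0.7477, -0.9323, -0.3670, -0.1568, -0.1282, -0.4557,  0.1747]])
B = torch.tensor([[0, 2, 7, 7, 5],
                  [2, 4, 7, 2, 5]])  # integer literals give an int64 tensor

# Approach from the answer: a (batch_size, 1) row index broadcast against B.
rows = torch.arange(B.size(0), device=B.device).unsqueeze(1)
out_index = A[rows, B]

# Alternative: gather along the vocab dimension (dim=1).
# The index must have the same number of dims as A.
out_gather = torch.gather(A, 1, B)

# Alternative: take_along_dim, modeled on NumPy's take_along_axis.
out_take = torch.take_along_dim(A, B, dim=1)

print(rows.shape)       # torch.Size([2, 1])
print(out_index.shape)  # torch.Size([2, 5])
print(torch.equal(out_index, out_gather), torch.equal(out_index, out_take))
```

torch.gather avoids building the explicit row index and is a common choice for this "pick one entry per position" pattern; the advanced-indexing form generalizes more easily when you need to index several dimensions at once.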
© www.soinside.com 2019 - 2024. All rights reserved.