鉴于:
我想获取 A 在索引 B 处的值。 例如:
A = [[ 0.6301, 0.2310, -1.1964, -1.1293, -1.0428, 0.4011, 0.0519, -3.0591],
[ 1.3691, -0.7477, -0.9323, -0.3670, -0.1568, -0.1282, -0.4557, 0.1747]])
B = [[0, 2, 7, 7, 5],
[2, 4, 7, 2, 5]])
output = [[ 0.6301, -1.1964, -3.0591, -3.0591, 0.4011],
[-0.9323, -0.1568, 0.1747, -0.9323, -0.1282]]
谢谢你
使用 PyTorch 使用张量索引:
import torch
A = torch.tensor([[0.6301, 0.2310, -1.1964, -1.1293, -1.0428, 0.4011, 0.0519, -3.0591],
[1.3691, -0.7477, -0.9323, -0.3670, -0.1568, -0.1282, -0.4557, 0.1747]])
B = torch.tensor([[0, 2, 7, 7, 5],
[2, 4, 7, 2, 5]])
output = A[torch.arange(B.size(0)).unsqueeze(1), B]
print(output)