一个 3D 张量如何被两个 3D 张量索引?

问题描述 投票:0回答:1

我有一个 3-D 张量

x_k
。我还有另外两个 3-D 张量
h_att
i_att
,它们用作索引。我想知道
q = x_k[h_att,i_att]
的详细工作原理

import torch
import numpy as np
import torch.nn as nn
x_k = torch.arange(1,25).reshape(2,4,3)
x_k
tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9],
         [10, 11, 12]],

        [[13, 14, 15],
         [16, 17, 18],
         [19, 20, 21],
         [22, 23, 24]]])
i_att = torch.tensor([[[0,1],[1,2],[2,3],[3,0]]])
i_att
tensor([[[0, 1],
         [1, 2],
         [2, 3],
         [3, 0]]])
h_att = torch.arange(2).reshape(2, 1, 1).long()
h_att
tensor([[[0]],

        [[1]]])
q = x_k[h_att,i_att]
q
tensor([[[[ 1,  2,  3],
          [ 4,  5,  6]],

         [[ 4,  5,  6],
          [ 7,  8,  9]],

         [[ 7,  8,  9],
          [10, 11, 12]],

         [[10, 11, 12],
          [ 1,  2,  3]]],


        [[[13, 14, 15],
          [16, 17, 18]],

         [[16, 17, 18],
          [19, 20, 21]],

         [[19, 20, 21],
          [22, 23, 24]],

         [[22, 23, 24],
          [13, 14, 15]]]])

为什么输出维度是(2,4,2,3)?

此外,如果我将

i_att
更改为二维张量
i_att = torch.tensor([[0,1],[1,2],[2,3],[3,0]])
,结果与之前的 q 相同。为什么?

indexing pytorch tensor
1个回答
0
投票

在代码片段中,x_k 是形状为 (2, 4, 3) 的 3-D 张量。 h_att 也是形状为 (2, 1, 1) 的 3-D 张量,i_att 是形状为 (1, 4, 2) 的 3-D 张量。

当我们使用索引操作 x_k[h_att, i_att] 时,意味着我们要从 x_k 中选择 h_att 和 i_att 中索引指定位置的值。生成的张量将具有与 i_att 相同的形状,并具有大小为 3 的附加维度。

让我们详细看看它是如何工作的。 h_att的第一个元素是0,i_att的第一个元素是[0, 1]。这意味着我们要从 x_k 中选择位置 (0, 0, 0) 的元素和位置 (0, 1, 1) 的元素。这两个元素分别是1和5。

h_att的下一个元素是1,i_att的下一个元素是[1, 2]。这意味着我们要从 x_k 中选择位置 (1, 1, 0) 的元素和位置 (1, 2, 1) 的元素。这两个元素分别是16和20。

我们对 h_att 和 i_att 的所有元素重复这个过程,并将结果收集在形状为 (2, 4, 2, 3) 的张量中。前两个维度对应h_att和i_att的形状,后面两个维度对应i_att的大小

如果我们将 i_att 更改为形状为 (4, 2) 的二维张量,结果将是相同的,因为索引操作仍然按元素应用于两个张量。换句话说,h_att和i_att的每个元素仍然用于从x_k中选择一个对应的元素。生成的张量将具有与 i_att 相同的形状,即 (4, 2),具有大小为 3 的附加维度。

© www.soinside.com 2019 - 2024. All rights reserved.