我有一个 3-D 张量
x_k
。我还有另外两个 3-D 张量 h_att
和 i_att
,它们用作索引。我想知道q = x_k[h_att,i_att]
的详细工作原理
import torch
import numpy as np
import torch.nn as nn
x_k = torch.arange(1,25).reshape(2,4,3)
x_k
tensor([[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12]],
[[13, 14, 15],
[16, 17, 18],
[19, 20, 21],
[22, 23, 24]]])
i_att = torch.tensor([[[0,1],[1,2],[2,3],[3,0]]])
i_att
tensor([[[0, 1],
[1, 2],
[2, 3],
[3, 0]]])
h_att = torch.arange(2).reshape(2, 1, 1).long()
h_att
tensor([[[0]],
[[1]]])
q = x_k[h_att,i_att]
q
tensor([[[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 4, 5, 6],
[ 7, 8, 9]],
[[ 7, 8, 9],
[10, 11, 12]],
[[10, 11, 12],
[ 1, 2, 3]]],
[[[13, 14, 15],
[16, 17, 18]],
[[16, 17, 18],
[19, 20, 21]],
[[19, 20, 21],
[22, 23, 24]],
[[22, 23, 24],
[13, 14, 15]]]])
为什么输出维度是(2,4,2,3)?
此外,如果我将
i_att
更改为二维张量 i_att = torch.tensor([[0,1],[1,2],[2,3],[3,0]])
,结果与之前的 q 相同。为什么?
在代码片段中,x_k 是形状为 (2, 4, 3) 的 3-D 张量。 h_att 也是形状为 (2, 1, 1) 的 3-D 张量,i_att 是形状为 (1, 4, 2) 的 3-D 张量。
当我们使用索引操作 x_k[h_att, i_att] 时,意味着我们要从 x_k 中选择 h_att 和 i_att 中索引指定位置的值。生成的张量将具有与 i_att 相同的形状,并具有大小为 3 的附加维度。
让我们详细看看它是如何工作的。 h_att的第一个元素是0,i_att的第一个元素是[0, 1]。这意味着我们要从 x_k 中选择位置 (0, 0, 0) 的元素和位置 (0, 1, 1) 的元素。这两个元素分别是1和5。
h_att的下一个元素是1,i_att的下一个元素是[1, 2]。这意味着我们要从 x_k 中选择位置 (1, 1, 0) 的元素和位置 (1, 2, 1) 的元素。这两个元素分别是16和20。
我们对 h_att 和 i_att 的所有元素重复这个过程,并将结果收集在形状为 (2, 4, 2, 3) 的张量中。前两个维度对应h_att和i_att的形状,后面两个维度对应i_att的大小
如果我们将 i_att 更改为形状为 (4, 2) 的二维张量,结果将是相同的,因为索引操作仍然按元素应用于两个张量。换句话说,h_att和i_att的每个元素仍然用于从x_k中选择一个对应的元素。生成的张量将具有与 i_att 相同的形状,即 (4, 2),具有大小为 3 的附加维度。