多维张量的前 K 个索引

问题描述 投票:0回答:4

我有一个 2D 张量,我想获取前 k 个值的索引。我了解 pytorch 的 topk 函数。 pytorch 的 topk 函数的问题是,它计算某个维度上的 topk 值。我想获得两个维度上的 topk 值。

例如对于以下张量

a = torch.tensor([[4, 9, 7, 4, 0],
        [8, 1, 3, 1, 0],
        [9, 8, 4, 4, 8],
        [0, 9, 4, 7, 8],
        [8, 8, 0, 1, 4]])

pytorch 的 topk 函数会给我以下内容。

values, indices = torch.topk(a, 3)

print(indices)
# tensor([[1, 2, 0],
#        [0, 2, 1],
#        [0, 1, 4],
#        [1, 4, 3],
#        [1, 0, 4]])

但我想得到以下

tensor([[0, 1],
        [2, 0],
        [3, 1]])

这是 2D 张量中 9 的索引。

有什么方法可以使用 pytorch 来实现这一点吗?

python pytorch tensor matrix-indexing
4个回答
12
投票
v, i = torch.topk(a.flatten(), 3)
print (np.array(np.unravel_index(i.numpy(), a.shape)).T)

输出:

[[3 1]
 [2 0]
 [0 1]]
  1. 压平并找到顶部 k
  2. 使用
    unravel_index
  3. 将 1D 索引转换为 2D

1
投票

您可以

flatten
原始张量,应用
topk
,然后将结果标量索引转换回多维索引,如下所示:

def descalarization(idx, shape):
    res = []
    N = np.prod(shape)
    for n in shape:
        N //= n
        res.append(idx // N)
        idx %= N
    return tuple(res)

示例:

torch.tensor([descalarization(k, a.size()) for k in torch.topk(a.flatten(), 5).indices])
# Returns 
# tensor([[3, 1],
#         [2, 0],
#         [0, 1],
#         [3, 4],
#         [2, 4]])

0
投票

您可以根据自己的需要进行一些向量运算来进行过滤。在这种情况下不使用 topk。

print(a)
tensor([[4, 9, 7, 4, 0],
    [8, 1, 3, 1, 0],
    [9, 8, 4, 4, 8],
    [0, 9, 4, 7, 8],
    [8, 8, 0, 1, 4]])

values, indices = torch.max(a,1)   # get max values, indices
temp= torch.zeros_like(values)     # temporary
temp[values==9]=1                  # fill temp where values are 9 (wished value)
seq=torch.arange(values.shape[0])  # create a helper sequence
new_seq=seq[temp>0]                # filter sequence where values are 9
new_temp=indices[new_seq]          # filter indices with sequence where values are 9
final = torch.stack([new_seq, new_temp], dim=1)  # stack both to get result

print(final)
tensor([[0, 1],
        [2, 0],
        [3, 1]])

0
投票

PyTorch 2.2 及以上版本开始

torch.unravel_index
现在已成为库的一部分,因此不再需要转换为 @mujjiga 的答案引用的 NumPy。因此,借用该答案并在 PyTorch 中完全执行此操作:

v, i = torch.topk(a.flatten(), 3)
indices = torch.column_stack(torch.unravel_index(i, a.shape))
print(indices)

我们得到:

tensor([[2, 0],
        [0, 1],
        [3, 1]])
© www.soinside.com 2019 - 2024. All rights reserved.