Pytorch 高级索引,以列表列表作为索引

问题描述 投票:0回答:1

这里有一些Python代码来重现我的问题:

import torch

n, m = 9, 4

x = torch.arange(0, n * m).reshape(n, m)
print(x.shape)
print(x)
# torch.Size([9, 4])
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11],
#         [12, 13, 14, 15],
#         [16, 17, 18, 19],
#         [20, 21, 22, 23],
#         [24, 25, 26, 27],
#         [28, 29, 30, 31],
#         [32, 33, 34, 35]])

list_of_indices = [
    [],
    [2, 3],
    [1],
    [],
    [],
    [],
    [0, 1, 2, 3],
    [],
    [0, 3],
]
print(list_of_indices)

for i, indices in enumerate(list_of_indices):
    x[i, indices] = -1

print(x)
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5, -1, -1],
#         [ 8, -1, 10, 11],
#         [12, 13, 14, 15],
#         [16, 17, 18, 19],
#         [20, 21, 22, 23],
#         [-1, -1, -1, -1],
#         [28, 29, 30, 31],
#         [-1, 33, 34, -1]])

我有一个索引列表。我想使用

x
中的索引将
-1
中的索引设置为特定值(此处为
list_of_indices
)。在此列表中,每个子列表对应一行
x
,包含要设置为该行
-1
的索引。这可以使用 for 循环轻松完成,但我觉得 pytorch 可以更有效地做到这一点。

我尝试了以下方法:

x[torch.arange(len(list_of_indices)), list_of_indices] = -1

但结果是

IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [9], [9, 0]

我试图找到有同样问题的人,但是有关索引张量的问题数量如此之多,以至于我可能错过了。

python indexing tensor torch
1个回答
0
投票

这是因为

list_of_indices
是一个参差不齐的
list
(即它包含空嵌套
[]
),所以如果我们包含一个返回
tensor
的函数,则与
shape
相同的
x
,其中
1
s 是来自
indices
list_of_indices
0
s 是不在
list_of_indices
中的索引),那么我们可以将其输入到
torch.where
索引中
x
:

def get_indices_from_list(list_of_indices):
    def fill_list(f):
        _f = torch.zeros(4).long(); _f[f] = 1
        return _f
    return torch.stack([fill_list(i) for i in list_of_indices])

x[torch.where(get_indices_from_list(list_of_indices) == 1)] = -1
print(x)

输出:

tensor([[ 0,  1,  2,  3],
        [ 4,  5, -1, -1],
        [ 8, -1, 10, 11],
        [12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23],
        [-1, -1, -1, -1],
        [28, 29, 30, 31],
        [-1, 33, 34, -1]])
© www.soinside.com 2019 - 2024. All rights reserved.