在pytorch中,有没有一种内置的方法来提取给定索引的行?

问题描述 投票:0回答:1

假设我有一个 torch tensor

import torch
a = torch.tensor([[1,2,3],
                  [4,5,6],
                  [7,8,9]])

和一份名单

b = [0,2]

有没有一种内置的方法来提取0和2行,并将它们放入一个新的张量中。

tensor([[1,2,3],
        [7,8,9]])

特别是,有没有一个函数是这样的:

extract_rows(a,b) -> c

其中 c 包含所需的行。当然,这可以通过for循环来完成,但一般来说,内置的方法更快。

请注意,这个例子只是一个例子,列表中可能有几十个索引,张量中可能有几百行。

python pytorch tensor
1个回答
1
投票

看看Torch的内置方法 index_select() 方法。这将会对你有所帮助.或者你可以使用slicing来完成。

tensor = [[1,2,3],
            [4,5,6],
            [7,8,9]]

new_tensor = tensor[0::2]
print(new_tensor)

输出。

[[1, 2, 3], [7, 8, 9]]

0
投票

简单地 a[b] 行得通

import torch
a = torch.tensor([[1,2,3],
                  [4,5,6],
                  [7,8,9]])
b = [0,2]
a[b]
tensor([[1, 2, 3],
        [7, 8, 9]])
© www.soinside.com 2019 - 2024. All rights reserved.