如何在pytorch中动态索引张量?

问题描述 投票:2回答:2

例如,我得到了一个张量:

tensor = torch.rand(12, 512, 768)

我得到了一个索引列表,说它是:

[0,2,3,400,5,32,7,8,321,107,100,511]

我希望在给定索引列表的情况下从维度2上的512个元素中选择1个元素。然后张量的大小将成为(12, 1, 768)

有办法吗?

python deep-learning pytorch torch tensor
2个回答
3
投票

还有一种方法只使用PyTorch并使用索引和torch.split避免循环:

tensor = torch.rand(12, 512, 768)

# create tensor with idx
idx_list = [0,2,3,400,5,32,7,8,321,107,100,511]
# convert list to tensor
idx_tensor = torch.tensor(idx_list) 

# indexing and splitting
list_of_tensors = tensor[:, idx_tensor, :].split(1, dim=1)

当你打电话给tensor[:, idx_tensor, :]时,你会得到一个形状的张量: (12, len_of_idx_list, 768)。 第二个维度取决于您的索引数量。

使用torch.split这个张量被分成一个形状张量列表:(12, 1, 768)

所以最后list_of_tensors包含形状的张量:

[torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768])]

0
投票

是的,您可以使用索引直接对其进行切片,然后使用torch.unsqueeze()将2D张量提升为3D:

# inputs
In [6]: tensor = torch.rand(12, 512, 768)
In [7]: idx_list = [0,2,3,400,5,32,7,8,321,107,100,511]

# slice using the index and then put a singleton dimension along axis 1
In [8]: for idx in idx_list:
   ...:     sampled_tensor = torch.unsqueeze(tensor[:, idx, :], 1)
   ...:     print(sampled_tensor.shape)
   ...:     
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])

或者,如果您想要更简洁的代码并且不想使用torch.unsqueeze(),那么使用:

In [11]: for idx in idx_list:
    ...:     sampled_tensor = tensor[:, [idx], :]
    ...:     print(sampled_tensor.shape)
    ...:     
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])

注意:如果你只想从for切换一个idx,就没有必要使用idx_list循环

© www.soinside.com 2019 - 2024. All rights reserved.