如何将给定索引处的一个 pytorch 张量的元素复制到另一个张量中,而无需中间分配或循环

问题描述 投票:0回答:1

给定

import torch

a: torch.Tensor
b: torch.Tensor
assert a.shape[1:] == b.shape[1:]
idx = torch.randint(b.shape[0], [a.shape[0]])

我想做

b[...] = a[idx]

但没有由

a[idx]
产生的中间缓冲区或循环
idx
。我该怎么做?

python pytorch slice tensor
1个回答
0
投票

您可以使用

torch.index_select

torch.index_select(a, 0, idx, out = b)
© www.soinside.com 2019 - 2024. All rights reserved.