给定
import torch a: torch.Tensor b: torch.Tensor assert a.shape[1:] == b.shape[1:] idx = torch.randint(b.shape[0], [a.shape[0]])
我想做
b[...] = a[idx]
但没有由
a[idx]
idx
您可以使用
torch.index_select
torch.index_select(a, 0, idx, out = b)