我有一个大小为mask
的torch.Size([20, 1, 199])
和张量,reconstruct_output
和inputs
都大小为torch.Size([20, 1, 161, 199])
。
我想将reconstruct_output
设置为inputs
,其中mask
为0
。我尝试过:
reconstruct_output[mask == 0] = inputs[mask == 0]
但出现错误:
IndexError: The shape of the mask [20, 1, 199] at index 2 does not match the shape of the indexed tensor [20, 1, 161, 199] at index 2
我们可以在这里使用advanced indexing
。为了获得我们要用于索引advanced indexing
和reconstruct_output
的索引数组,我们需要沿inputs
的轴进行索引。为此,我们可以使用m==0
,并使用结果索引将np.where
更新为: