我正在使用 pytorch,我想在不使用循环的情况下将简单的 torch.where(array > 0) 应用于一批数组,我如何使用 torch 函数来执行此代码?
def batch_node_indices(states_batch):
batch_indices = []
for state in states_batch:
node_indices = torch.where(state > 0)[0].detach().cpu().numpy()
batch_indices.append(node_indices)
return batch_indices
我尝试了不同的手电筒功能,但没有成功。我希望该方法返回一批数组,每个数组包含状态数组大于 0 的索引。
“一批索引”到底是什么意思?
类似
torch.where(condition)
的问题是批次中的每个项目都有不同数量的元素,其中 condition=True
。这意味着您不能批量应用 where
,因为批次中每个项目的输出大小都不同。
where
的默认行为是输出一组张量,每个轴一个,显示condition=True
所在的所有索引元组。输出指数被扁平化以处理不规则大小的问题。如果需要,您可以使用输出来批量获取索引。
x = torch.randn(16, 32, 64)
indices = torch.where(x>0)
print(indices)
> (tensor([ 0, 0, 0, ..., 15, 15, 15]),
> tensor([ 0, 0, 0, ..., 31, 31, 31]),
> tensor([ 4, 5, 7, ..., 61, 62, 63]))
index_tensor = torch.stack(indices)
# for example, select outputs from the first item in the batch
index_tensor[:, index_tensor[0] == 0]
您还可以使用
torch.where
中的附加参数来返回与输入形状相同的张量。例如
x = torch.randn(16, 32, 64)
x1 = torch.where(x>0, 1, 0) # fills 1 where x>0, 0 elsewhere
# shape is retained
x.shape == x1.shape
> True
x2 = torch.where(x>0, x, float('-inf')) # returns `x` with -inf where x<0