使用 torch.where() 处理一批数组

问题描述 投票:0回答:1

我正在使用 pytorch,我想在不使用循环的情况下将简单的 torch.where(array > 0) 应用于一批数组,我如何使用 torch 函数来执行此代码?

def batch_node_indices(states_batch):
    batch_indices = []
    for state in states_batch:
        node_indices = torch.where(state > 0)[0].detach().cpu().numpy()
        batch_indices.append(node_indices)
    return batch_indices

我尝试了不同的手电筒功能,但没有成功。我希望该方法返回一批数组,每个数组包含状态数组大于 0 的索引。

python pytorch
1个回答
0
投票

“一批索引”到底是什么意思?

类似

torch.where(condition)
的问题是批次中的每个项目都有不同数量的元素,其中
condition=True
。这意味着您不能批量应用
where
,因为批次中每个项目的输出大小都不同。

where
的默认行为是输出一组张量,每个轴一个,显示
condition=True
所在的所有索引元组。输出指数被扁平化以处理不规则大小的问题。如果需要,您可以使用输出来批量获取索引。

x = torch.randn(16, 32, 64)
indices = torch.where(x>0)
print(indices)
> (tensor([ 0,  0,  0,  ..., 15, 15, 15]),
>  tensor([ 0,  0,  0,  ..., 31, 31, 31]),
>  tensor([ 4,  5,  7,  ..., 61, 62, 63]))

index_tensor = torch.stack(indices)
# for example, select outputs from the first item in the batch
index_tensor[:, index_tensor[0] == 0]

您还可以使用

torch.where
中的附加参数来返回与输入形状相同的张量。例如

x = torch.randn(16, 32, 64)
x1 = torch.where(x>0, 1, 0) # fills 1 where x>0, 0 elsewhere

# shape is retained 
x.shape == x1.shape
> True

x2 = torch.where(x>0, x, float('-inf')) # returns `x` with -inf where x<0
© www.soinside.com 2019 - 2024. All rights reserved.