使用numpy.where来防止越界

问题描述 投票:1回答:2

我试图根据索引数组查找数组中的值。这个索引数组可以包含可能超出范围的索引。在这种情况下,我想返回一个特定的值(这里是0)。

(我可以使用for循环,但这太慢了。)

所以我这样做:

data = np.arange(1000).reshape(10, 10, 10)
i = np.arange(9).reshape(3, 3)
i[0, 0] = 10
condition = (i[:, 0] < 10) & (i[:, 1] < 10) & (i[:, 2] < 10)
values = np.where(condition, data[i[:, 0], i[:, 1], i[:, 2]], 0)

但是我仍然得到一个出界的错误:

IndexError: index 10 is out of bounds for axis 0 with size 10

我想这是因为第二个参数没有延迟评估,并且在函数调用之前进行了求值。

numpy是否有基于条件访问数组但仍保留订单的解决方案?通过保留顺序,我的意思是我不能首先过滤数组,因为我可以在最终结果中松开顺序。最后,在那个特定的例子中,当索引超出范围时,我仍然希望values数组包含0。所以最终的结果是:

array([ 0, 345, 678])
python arrays numpy where
2个回答
1
投票

索引数组的每一列都存储每个维度的索引。我们可以生成有效的掩码(通过边界)并将其中的无效掩码设置为零。即,界外情况将被[0,0,0]索引,然后让数组被这个修改版本索引,最后再次使用掩码重置无效的,如此 -

shp = data.shape
valid_mask = (i < shp).all(1)
i[~valid_mask] = 0
out = np.where(valid_mask,data[tuple(i.T)],0)

不改变i的修改后的紧凑版本将是 -

out = np.where(valid_mask,data[tuple(np.where(valid_mask,i.T,0))],0)

1
投票

第一个索引,然后应用修正来更正值。

shp = np.array(data.shape)
j = i % shp 
res = data[j.T.tolist()]
res[(i >= shp).nonzero()[0]] = 0

print(res)
array([  0, 345, 678])
© www.soinside.com 2019 - 2024. All rights reserved.