如果最小值满足条件,则获取2D np.array中每行的最小值索引

问题描述 投票:2回答:2

我有一个尺寸为np.array的2D 1000 (rows) x 12 (columns)

我需要得到低于1.5的那些值的指数。 如果一行包含多个满足此条件的值,那么我只需要保留最低的索引。

我对使用idx1,idx2=np.where(x < 1.5)非常满意,但这有时会返回同一行中的几个索引。我当然可以在idx1中遍历所有重复的行,并且只保留x中值最低的索引,但我想知道是否有更多的pythonic方式。

谢谢

python arrays numpy unique where
2个回答
0
投票

你可以这样做:

# First index is all rows
idx1 = np.arange(len(x))
# Second index is minimum values
idx2 = np.argmin(m, axis=1)
# Filter rows where minimum is not below threshold
valid = x[idx1, idx2] < 1.5
idx1 = idx1[valid]
idx2 = idx2[valid]

0
投票

一种方法是使用numpy masked array。让我们定义以下随机ndarray

a = np.random.normal(1,2,(4,2))

print(a.round(2))
array([[ 1.41, -0.68],
       [-1.53,  2.74],
       [ 1.19,  2.66],
       [ 2.  ,  1.26]])

我们可以用以下方法定义一个蒙版数组:

ma = np.ma.array(a, mask = a >= 1.5)

print(ma.round(2))
[[1.41 -0.68]
 [-1.53 --]
 [1.19 --]
 [-- 1.26]]

为了处理没有低于阈值的值的行,您可以执行以下操作:

m = ma.mask.any(axis=1)
# array([ True,  True,  True,  True])

如果给定行中没有有效值,则将包含False。然后将np.argmin放在蒙面数组上,以获得最小值低于1.5的列:

np.argmin(ma, axis=1)[m]
# array([1, 0, 0, 1])

对于你可以做的行:

np.flatnonzero(m)
# array([0, 1, 2, 3])
© www.soinside.com 2019 - 2024. All rights reserved.