我有一个尺寸为np.array
的2D 1000 (rows) x 12 (columns)
。
我需要得到低于1.5
的那些值的指数。
如果一行包含多个满足此条件的值,那么我只需要保留最低的索引。
我对使用idx1,idx2=np.where(x < 1.5)
非常满意,但这有时会返回同一行中的几个索引。我当然可以在idx1
中遍历所有重复的行,并且只保留x
中值最低的索引,但我想知道是否有更多的pythonic方式。
谢谢
你可以这样做:
# First index is all rows
idx1 = np.arange(len(x))
# Second index is minimum values
idx2 = np.argmin(m, axis=1)
# Filter rows where minimum is not below threshold
valid = x[idx1, idx2] < 1.5
idx1 = idx1[valid]
idx2 = idx2[valid]
一种方法是使用numpy masked array。让我们定义以下随机ndarray
:
a = np.random.normal(1,2,(4,2))
print(a.round(2))
array([[ 1.41, -0.68],
[-1.53, 2.74],
[ 1.19, 2.66],
[ 2. , 1.26]])
我们可以用以下方法定义一个蒙版数组:
ma = np.ma.array(a, mask = a >= 1.5)
print(ma.round(2))
[[1.41 -0.68]
[-1.53 --]
[1.19 --]
[-- 1.26]]
为了处理没有低于阈值的值的行,您可以执行以下操作:
m = ma.mask.any(axis=1)
# array([ True, True, True, True])
如果给定行中没有有效值,则将包含False
。然后将np.argmin
放在蒙面数组上,以获得最小值低于1.5的列:
np.argmin(ma, axis=1)[m]
# array([1, 0, 0, 1])
对于你可以做的行:
np.flatnonzero(m)
# array([0, 1, 2, 3])