我有一个2d的numpy数组,像这样的数组永远只有0,1的值。
a = np.array([[1, 0, 1, 0], # Indexes 0 and 2 == 1
[0, 1, 1, 0], # Indexes 1 and 2 == 1
[0, 1, 0, 1], # Indexes 1 and 3 == 1
[0, 1, 1, 1]]) # Indexes 1, 2, and 3 == 1
我想做的是得到每一行的索引,其中传递的一对列索引都等于1。
例如,如果做这件事的函数是 get_rows
, get_rows(a, [1, 3])
同理,返回[2,3],因为在索引2和3处的行的列索引1和3等于1,同理。 get_rows(a, [1, 2])
应该返回 [1, 3]。
我知道如何在Pandas的数据框架中做这件事,但我想坚持使用纯numpy来做这件事。 我试着用 np.where
以某种形式
np.where( ((a[i1 - 1] == 1) & (a[i2 - 1] == 1) ))
但这似乎并没有给我我想要的东西,而且对于不同数量的传入指数也不适用。
我想你要找的是这个。
col_idx = [1, 2]
np.where(a[:,col_idx].all(axis=1))[0]
你可以使用任何你想传递给它的列索引。它是相当自明的,它提取列,并使用np.where搜索所有1在那里的行。
编辑:根据@疯狂物理学家的建议,这里有另一个类似的解决方案。
np.flatnonzero(a[:,col_idx].all(axis=1))
输出例子给你的输入。
[1 3]
试试这个。
target_col_index
)有 1
.target_col_index
.target_col_index
.import numpy as np
target_col_index = [1,2]
target_row_index = get_row_index(a, target_col_index)
print(target_row_index)
## Output
# [1,3]
## Other cases tested
# test_col_indexes = [ [0,1], [0,2], [0,3], [1,2], [1,3], [2,3], [0,1,3], [1,2,3] ]
# returned_row_indexes = [ [], [0], [], [1,3], [2,3], [3], [], [3] ]
def get_row_index(arr, target_col_index=None):
if target_col_index is None:
return None
else:
row_index, col_index = np.where(arr==1)
result = row_index[np.isin(col_index, target_col_index)]
rows, counts = np.unique(result, return_counts=True)
target_row_index = rows[counts==len(target_col_index)]
return target_row_index
a = np.array([[1, 0, 1, 0], # Indexes 0 and 2 == 1
[0, 1, 1, 0], # Indexes 1 and 2 == 1
[0, 1, 0, 1], # Indexes 1 and 3 == 1
[0, 1, 1, 1]]) # Indexes 1, 2, and 3 == 1