Numpy 2d数组,获取指定列索引等于1的行的索引。

问题描述 投票:1回答:1

我有一个2d的numpy数组,像这样的数组永远只有0,1的值。

a = np.array([[1, 0, 1, 0],  # Indexes 0 and 2 == 1
             [0, 1, 1, 0],   # Indexes 1 and 2 == 1
             [0, 1, 0, 1],   # Indexes 1 and 3 == 1
             [0, 1, 1, 1]])  # Indexes 1, 2, and 3 == 1

我想做的是得到每一行的索引,其中传递的一对列索引都等于1。

例如,如果做这件事的函数是 get_rows, get_rows(a, [1, 3])同理,返回[2,3],因为在索引2和3处的行的列索引1和3等于1,同理。 get_rows(a, [1, 2]) 应该返回 [1, 3]。

我知道如何在Pandas的数据框架中做这件事,但我想坚持使用纯numpy来做这件事。 我试着用 np.where 以某种形式

np.where( ((a[i1 - 1] == 1) & (a[i2 - 1] == 1) ))

但这似乎并没有给我我想要的东西,而且对于不同数量的传入指数也不适用。

python arrays numpy
1个回答
4
投票

我想你要找的是这个。

col_idx = [1, 2]
np.where(a[:,col_idx].all(axis=1))[0]

你可以使用任何你想传递给它的列索引。它是相当自明的,它提取列,并使用np.where搜索所有1在那里的行。

编辑:根据@疯狂物理学家的建议,这里有另一个类似的解决方案。

np.flatnonzero(a[:,col_idx].all(axis=1))

输出例子给你的输入。

[1 3]

0
投票

解决方案

试试这个。

代码 - 逻辑

  • 找出给定列所在行和列的索引(target_col_index)有 1.
  • 只选择那些列索引匹配的行。target_col_index.
  • 通过检查哪些行索引的结果与列数相同,来缩小结果范围。target_col_index.
import numpy as np

target_col_index = [1,2]
target_row_index = get_row_index(a, target_col_index)
print(target_row_index)

## Output
# [1,3]

## Other cases tested
# test_col_indexes = [ [0,1], [0,2], [0,3], [1,2], [1,3], [2,3], [0,1,3], [1,2,3] ]
# returned_row_indexes = [ [], [0], [], [1,3], [2,3], [3], [], [3] ]

代码 - 自定义功能

def get_row_index(arr, target_col_index=None):
    if target_col_index is None:
        return None
    else:
        row_index, col_index = np.where(arr==1)
        result = row_index[np.isin(col_index, target_col_index)]
        rows, counts = np.unique(result, return_counts=True)
        target_row_index = rows[counts==len(target_col_index)]
        return target_row_index

虚数据

a = np.array([[1, 0, 1, 0],  # Indexes 0 and 2 == 1
            [0, 1, 1, 0],   # Indexes 1 and 2 == 1
            [0, 1, 0, 1],   # Indexes 1 and 3 == 1
            [0, 1, 1, 1]])  # Indexes 1, 2, and 3 == 1
© www.soinside.com 2019 - 2024. All rights reserved.