过滤嵌套的 numpy 数组

问题描述 投票:0回答:2

嗨,我正在尝试过滤以下 numpy 数组,但遇到了问题。我想过滤所有等于 [('a','b','c'),1] 的数组。问题是我不知道如何组合每个数组中的元素,而不是 [('a','b','c'),1],我会 [('a','b ','c',1)],或者简单地过滤给定的原始结构。我尝试了 np.concatenate() 和 np.ravel() 的组合,但结果出乎意料。

a = np.array([[('a','b','c'), 1], [('b','c','a'), 1], [('a','b','c'), 2], [('a','b','c'), 1]])

Method:
Filter if 1st element = 'a', 2nd element = 'b', 3rd element ='c' and 4th element = 1

Desired Output:
output = np.array([[('a','b','c'), 1], [('a','b','c'), 1]])

编辑:我能够让它与 pandas 解决方案一起使用,但只能通过将它转换为数据框,这太昂贵了,因此我试图用 numpy 实现更优化的解决方案

python numpy
2个回答
1
投票

你可以这样做:

a[(a == [('a', 'b', 'c'), 1]).all(1)]

输出:

array([[('a', 'b', 'c'), 1],
       [('a', 'b', 'c'), 1]], dtype=object)

a == [('a', 'b', 'c'), 1]
的输出:

array([[ True,  True],
       [False,  True],
       [ True, False],
       [ True,  True]])

0
投票

在您的情况下,一种更快的方法是明确比较第一项和第二项:

[r for r in a if r[0] == ('a','b','c') and r[1] == 1]

In [400]: %timeit [r for r in a if r[0] == ('a','b','c') and r[1] == 1]
1.2 µs ± 6.69 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

In [402]: %timeit a[(a == [('a', 'b', 'c'), 1]).all(1)]
7.21 µs ± 85.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
© www.soinside.com 2019 - 2024. All rights reserved.