嗨,我正在尝试过滤以下 numpy 数组,但遇到了问题。我想过滤所有等于 [('a','b','c'),1] 的数组。问题是我不知道如何组合每个数组中的元素,而不是 [('a','b','c'),1],我会 [('a','b ','c',1)],或者简单地过滤给定的原始结构。我尝试了 np.concatenate() 和 np.ravel() 的组合,但结果出乎意料。
a = np.array([[('a','b','c'), 1], [('b','c','a'), 1], [('a','b','c'), 2], [('a','b','c'), 1]])
Method:
Filter if 1st element = 'a', 2nd element = 'b', 3rd element ='c' and 4th element = 1
Desired Output:
output = np.array([[('a','b','c'), 1], [('a','b','c'), 1]])
编辑:我能够让它与 pandas 解决方案一起使用,但只能通过将它转换为数据框,这太昂贵了,因此我试图用 numpy 实现更优化的解决方案
你可以这样做:
a[(a == [('a', 'b', 'c'), 1]).all(1)]
输出:
array([[('a', 'b', 'c'), 1],
[('a', 'b', 'c'), 1]], dtype=object)
a == [('a', 'b', 'c'), 1]
的输出:
array([[ True, True],
[False, True],
[ True, False],
[ True, True]])
在您的情况下,一种更快的方法是明确比较第一项和第二项:
[r for r in a if r[0] == ('a','b','c') and r[1] == 1]
In [400]: %timeit [r for r in a if r[0] == ('a','b','c') and r[1] == 1]
1.2 µs ± 6.69 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [402]: %timeit a[(a == [('a', 'b', 'c'), 1]).all(1)]
7.21 µs ± 85.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)