有没有办法使用numpy的数组索引对列执行多次检查?

问题描述 投票:1回答:1

我有一个2D数组数据,我正在尝试从这些数据中有效地修剪坏列。我正在尝试删除任何包含值0的列,它们在最小值和最大值之间的绝对差值大于12,或者包含大于9.5的值。

我工作的代码,但它很慢。根据我的理解,在后台,每个代码行都有一个循环遍历我的数组。我想知道是否有办法将其减少到一个循环。

import numpy as np

data_array = data_array[:,abs(data_array).min(0)!=0]
data_array = data_array[:,abs(data_array.min(0)-data_array.max(0)) < 12]
data_array = data_array[:,abs(data_array).max(0) < 9.5]
python arrays numpy
1个回答
0
投票

我认为不可能在一个循环中执行这三个检查。

您可以通过正确订购修剪操作来提高性能。实际上,您应该首先检查删除大多数列的条件,以便传递给第二个过滤器的数组尽可能小。相同的标准适用于其余的过滤器。

根据评论,您的数据范围从-3030。可以预期,最常见的无效列是包含大于9.5的值的列。我还猜测丢弃列的最常见原因是存在零值。如果这些假设不正确,您应该相应地更改过滤器的顺序。通过删除不必要的函数调用(例如abs)可以实现进一步的改进。

以下函数以不同的顺序实现相同的过滤操作,如上所述:

import numpy as np

def trim(x, low=0, high=9.5, diff=12):
    x = x[:, np.all(x != 0, axis=0)]
    x = x[:, np.ptp(x, axis=0) <= diff]
    x = x[:, np.all(x <= high, axis=0)]
    return x

def trim_reordered(x, low=0, high=9.5, diff=12):
    x = x[:, np.all(x <= high, axis=0)]
    x = x[:, np.ptp(x, axis=0) <= diff]
    x = x[:, np.all(x != 0, axis=0)]
    return x

Demo

In [205]: np.random.seed(213)

In [206]: small_arr = np.random.randint(low=-30, high=30, size=(3, 10))

In [207]: small_arr
Out[207]: 
array([[ 13,   6,   2, -29,  13,  11, -12, -24,   5,   9],
       [ 29,  24,  16, -21, -27,  -5,  -5, -16,  21, -29],
       [-10,  10, -24, -10,   4,   0,  -8, -23,   0,   4]])

In [208]: trim(small_arr)
Out[208]: 
array([[-12, -24],
       [ -5, -16],
       [ -8, -23]])

In [209]: large_arr = np.random.randint(low=-30, high=30, size=(10, 10**6))

In [210]: %timeit trim(large_arr)
77.3 ms ± 470 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [211]: %timeit trim_reordered(large_arr)
16.1 ms ± 174 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [212]: np.all(trim(large_arr) == trim_reordered(large_arr))
Out[212]: True
© www.soinside.com 2019 - 2024. All rights reserved.