假设我有一个NumPy数组arr
,我想按元素进行过滤,例如我只想获取低于特定阈值k
的值。
有几种方法,例如:
np.fromiter((x for x in arr if x < k), dtype=arr.dtype)
arr[arr < k]
np.where()
:arr[np.where(arr < k)]
np.nonzero()
:arr[np.nonzero(arr < k)]
哪个最快?内存效率如何?
def filter_gen_np(arr, k):
return np.fromiter((x for x in arr if x < k), dtype=arr.dtype)
np.where()
:def filter_where(arr, k):
return arr[np.where(arr < k)]
def filter_mask(arr, k):
return arr[arr < k]
%%cython -c-O3 -c-march=native -a
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True, infer_types=True
cimport numpy as cnp
cimport cython as ccy
import numpy as np
import cython as cy
cdef size_t _filter_cy(long[:] arr, long[:] result, size_t size, long k):
cdef size_t j = 0
for i in range(size):
if arr[i] < k:
result[j] = arr[i]
j += 1
return j
cpdef filter_cy(arr, k):
result = np.empty_like(arr)
new_size = _filter_cy(arr, result, arr.size, k)
return result[:new_size]
@nb.jit
def filter_np_nb(arr, k):
result = np.empty_like(arr)
j = 0
for i in range(arr.size):
if arr[i] < k:
result[j] = arr[i]
j += 1
return result[:j]
方法1比其他方法慢得多(大约2个数量级,因此在图表中将其省略)。>>
时间将取决于输入数组的大小和已过滤项目的百分比。
第一张图将时序作为输入大小的函数(对于〜50%滤出的元素):
通常,基于Numba的方法是最快的,其次是Cython方法。在NumPy中,基于np.where()
的方法是最快的,除了很小的输入(约100个元素以下)(布尔掩码切片更快)之外。
第二张图根据通过过滤器的项目来确定时序(对于大约100万个元素的固定输入大小:]
[第一个观察结果是,所有方法在达到〜50%填充时最慢,而填充更少或更多时,它们则更快,朝着不填充的方向最快(滤出值的最高百分比,通过值的最低百分比为表示在图表的x轴上)。同样,Numba和Cython版本均明显比基于NumPy的版本快,其中Numba几乎总是最快,而Cython在图表的最右端胜过Numba。在基于NumPy的解决方案中观察到类似的情况,np.where()
几乎总是优于布尔蒙版切片,除了图的最右边部分。
(提供完整代码here)
基于生成器的方法仅需要最少的临时存储,而与输入的大小无关。在内存方面,这是最有效的方法。
在内存方面,Cython和Numba都需要输入大小的临时数组。因此,这些是内存效率最低的方法。
布尔型掩码切片解决方案需要一个输入大小但类型为bool
的临时数组,该数组在NumPy中为1位,因此比典型的64位上NumPy数组的默认大小小约64倍。位系统。
基于np.where()
的解决方案与第一步(在np.where()
内)的布尔蒙版切片具有相同的要求,后者会转换为一系列的int
s(通常在64-but系统上为int64
)在第二步(np.where()
的输出)。因此,此第二步具有可变的内存要求,具体取决于已过滤元素的数量。
arr = np.arange(100)
k = 50
print('`arr[arr > k]` is a copy: ', arr[arr > k].base is None)
# `arr[arr > k]` is a copy: True
print('`arr[np.where(arr > k)]` is a copy: ', arr[np.where(arr > k)].base is None)
# `arr[np.where(arr > k)]` is a copy: True
print('`arr[:k]` is a copy: ', arr[:k].base is None)
# `arr[:k]` is a copy: False