我最近编写了一个脚本,用于将 [0, 1] 浮点数的 BGR 数组转换为 HSL 并返回。我将其发布在Code Review上。目前有一个答案,但它不会提高性能。
我已经将我的代码与
cv2.cvtColor
进行了基准测试,发现我的代码效率低下,所以我想用Numba编译代码以使其运行得更快。
我尝试用
@nb.njit(cache=True, fastmath=True)
包装每个函数,但这不起作用。
因此我测试了我单独使用过的每个 NumPy 语法和 NumPy 函数,并发现了两个不能与 Numba 一起使用的函数。
我需要找到每个像素的最大通道 (
np.max(img, axis=-1)
) 和每个像素的最小通道 (np.max(img, axis=-1)
),并且 axis
参数不适用于 Numba。
我尝试在谷歌上搜索这个,但我发现的唯一远程相关的是this,但它只实现了
np.any
和np.all
,并且仅适用于二维数组,而这里的数组是三维的。
我可以编写一个基于 for 循环的解决方案,但我不会编写它,因为它必然效率低下,并且违背了使用 NumPy 和 Numba 的初衷。
最小可重现示例:
import numba as nb
import numpy as np
@nb.njit(cache=True, fastmath=True)
def max_per_cell(arr):
return np.max(arr, axis=-1)
@nb.njit(cache=True, fastmath=True)
def min_per_cell(arr):
return np.min(arr, axis=-1)
img = np.random.random((3, 4, 3))
max_per_cell(img)
min_per_cell(img)
例外:
In [2]: max_per_cell(img)
---------------------------------------------------------------------------
TypingError Traceback (most recent call last)
Cell In[2], line 1
----> 1 max_per_cell(img)
File C:\Python310\lib\site-packages\numba\core\dispatcher.py:468, in _DispatcherBase._compile_for_args(self, *args, **kws)
464 msg = (f"{str(e).rstrip()} \n\nThis error may have been caused "
465 f"by the following argument(s):\n{args_str}\n")
466 e.patch_message(msg)
--> 468 error_rewrite(e, 'typing')
469 except errors.UnsupportedError as e:
470 # Something unsupported is present in the user code, add help info
471 error_rewrite(e, 'unsupported_error')
File C:\Python310\lib\site-packages\numba\core\dispatcher.py:409, in _DispatcherBase._compile_for_args.<locals>.error_rewrite(e, issue_type)
407 raise e
408 else:
--> 409 raise e.with_traceback(None)
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<function amax at 0x0000014E306D3370>) found for signature:
>>> amax(array(float64, 3d, C), axis=Literal[int](-1))
There are 2 candidate implementations:
- Of which 2 did not match due to:
Overload in function 'npy_max': File: numba\np\arraymath.py: Line 541.
With argument(s): '(array(float64, 3d, C), axis=int64)':
Rejected as the implementation raised a specific error:
TypingError: got an unexpected keyword argument 'axis'
raised from C:\Python310\lib\site-packages\numba\core\typing\templates.py:784
During: resolving callee type: Function(<function amax at 0x0000014E306D3370>)
During: typing of call at <ipython-input-1-b3894b8b12b8> (10)
File "<ipython-input-1-b3894b8b12b8>", line 10:
def max_per_cell(arr):
return np.max(arr, axis=-1)
^
如何解决这个问题?
在没有
np.max()
的情况下实现这个相当简单,而是使用循环:
@nb.njit()
def max_per_cell_nb(arr):
ret = np.empty(arr.shape[:-1], dtype=arr.dtype)
n, m = ret.shape
for i in range(n):
for j in range(m):
max_ = arr[i, j, 0]
max_ = max(max_, arr[i, j, 1])
max_ = max(max_, arr[i, j, 2])
ret[i, j] = max_
return ret
对其进行基准测试,结果比
np.max(arr, axis=-1)
快约 16 倍。
%timeit max_per_cell_nb(img)
4.88 ms ± 163 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit max_per_cell(img)
81 ms ± 654 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
在对此进行基准测试时,我做出了以下假设:
np.max()
的速度更快,只需要15ms。请参阅检查 numpy 数组是否连续?了解如何判断数组是 C 顺序还是 Fortran 顺序。您的 np.random.random((3, 4, 3))
示例是 C 连续的。np.max(arr, axis=-1)
进行比较,因为它无法真正优化对 NumPy 函数的单个调用。