如何使用 Numba 查找 3 维数组中每个单元格的极值?

问题描述 投票:0回答:1

我最近编写了一个脚本,用于将 [0, 1] 浮点数的 BGR 数组转换为 HSL 并返回。我将其发布在Code Review上。目前有一个答案,但它不会提高性能。

我已经将我的代码与

cv2.cvtColor
进行了基准测试,发现我的代码效率低下,所以我想用Numba编译代码以使其运行得更快。

我尝试用

@nb.njit(cache=True, fastmath=True)
包装每个函数,但这不起作用。

因此我测试了我单独使用过的每个 NumPy 语法和 NumPy 函数,并发现了两个不能与 Numba 一起使用的函数。

我需要找到每个像素的最大通道 (

np.max(img, axis=-1)
) 和每个像素的最小通道 (
np.max(img, axis=-1)
),并且
axis
参数不适用于 Numba。

我尝试在谷歌上搜索这个,但我发现的唯一远程相关的是this,但它只实现了

np.any
np.all
,并且仅适用于二维数组,而这里的数组是三维的。

我可以编写一个基于 for 循环的解决方案,但我不会编写它,因为它必然效率低下,并且违背了使用 NumPy 和 Numba 的初衷。

最小可重现示例:

import numba as nb
import numpy as np

@nb.njit(cache=True, fastmath=True)
def max_per_cell(arr):
    return np.max(arr, axis=-1)

@nb.njit(cache=True, fastmath=True)
def min_per_cell(arr):
    return np.min(arr, axis=-1)

img = np.random.random((3, 4, 3))
max_per_cell(img)
min_per_cell(img)

例外:

In [2]: max_per_cell(img)
---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
Cell In[2], line 1
----> 1 max_per_cell(img)

File C:\Python310\lib\site-packages\numba\core\dispatcher.py:468, in _DispatcherBase._compile_for_args(self, *args, **kws)
    464         msg = (f"{str(e).rstrip()} \n\nThis error may have been caused "
    465                f"by the following argument(s):\n{args_str}\n")
    466         e.patch_message(msg)
--> 468     error_rewrite(e, 'typing')
    469 except errors.UnsupportedError as e:
    470     # Something unsupported is present in the user code, add help info
    471     error_rewrite(e, 'unsupported_error')

File C:\Python310\lib\site-packages\numba\core\dispatcher.py:409, in _DispatcherBase._compile_for_args.<locals>.error_rewrite(e, issue_type)
    407     raise e
    408 else:
--> 409     raise e.with_traceback(None)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<function amax at 0x0000014E306D3370>) found for signature:

 >>> amax(array(float64, 3d, C), axis=Literal[int](-1))

There are 2 candidate implementations:
  - Of which 2 did not match due to:
  Overload in function 'npy_max': File: numba\np\arraymath.py: Line 541.
    With argument(s): '(array(float64, 3d, C), axis=int64)':
   Rejected as the implementation raised a specific error:
     TypingError: got an unexpected keyword argument 'axis'
  raised from C:\Python310\lib\site-packages\numba\core\typing\templates.py:784

During: resolving callee type: Function(<function amax at 0x0000014E306D3370>)
During: typing of call at <ipython-input-1-b3894b8b12b8> (10)


File "<ipython-input-1-b3894b8b12b8>", line 10:
def max_per_cell(arr):
    return np.max(arr, axis=-1)
    ^

如何解决这个问题?

python numpy numba jit
1个回答
0
投票

在没有

np.max()
的情况下实现这个相当简单,而是使用循环:

@nb.njit()
def max_per_cell_nb(arr):
    ret = np.empty(arr.shape[:-1], dtype=arr.dtype)
    n, m = ret.shape
    for i in range(n):
        for j in range(m):
            max_ = arr[i, j, 0]
            max_ = max(max_, arr[i, j, 1])
            max_ = max(max_, arr[i, j, 2])
            ret[i, j] = max_
    return ret

对其进行基准测试,结果比

np.max(arr, axis=-1)
快约 16 倍。

%timeit max_per_cell_nb(img)
4.88 ms ± 163 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit max_per_cell(img)
81 ms ± 654 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

在对此进行基准测试时,我做出了以下假设:

  • 图像为 1920x1080x3。 (换句话说,这是一个大图像。)
  • 图像数组是 C 顺序而不是 Fortran 顺序。如果是Fortran顺序,我的方法速度下降到7ms,而
    np.max()
    的速度更快,只需要15ms。请参阅检查 numpy 数组是否连续?了解如何判断数组是 C 顺序还是 Fortran 顺序。您的
    np.random.random((3, 4, 3))
    示例是 C 连续的。
  • 我将此函数与关闭 Numba JIT 的
    np.max(arr, axis=-1)
    进行比较,因为它无法真正优化对 NumPy 函数的单个调用。
© www.soinside.com 2019 - 2024. All rights reserved.