Numba 无效使用带有类型参数的Function。

问题描述 投票:1回答:1

我使用Numba non-python模式和一些NumPy函数。

@njit
def invert(W, copy=True):
    '''
    Inverts elementwise the weights in an input connection matrix.
    In other words, change the from the matrix of internode strengths to the
    matrix of internode distances.

    If copy is not set, this function will *modify W in place.*

    Parameters
    ----------
    W : np.ndarray
        weighted connectivity matrix
    copy : bool

    Returns
    -------
    W : np.ndarray
        inverted connectivity matrix
    '''

    if copy:
        W = W.copy()
    E = np.where(W)
    W[E] = 1. / W[E]
    return W

在这个函数中。W 是一个矩阵。但我得到了以下错误。这可能与 W[E] = 1. / W[E] 行。

File "/Users/xxx/anaconda3/lib/python3.7/site-packages/numba/dispatcher.py", line 317, in error_rewrite
    reraise(type(e), e, None)
  File "/Users/xxx/anaconda3/lib/python3.7/site-packages/numba/six.py", line 658, in reraise
    raise value.with_traceback(tb)
numba.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<built-in function getitem>) with argument(s) of type(s): (array(float64, 2d, A), tuple(array(int64, 1d, C) x 2))

那么NumPy和Numba的正确使用方法是什么呢?我知道NumPy在矩阵计算上做得很好。在这种情况下,NumPy是否足够快,Numba是否提供了更多的加速?

python numpy numba
1个回答
3
投票

正如 FBruzzesi 在评论中提到的,代码不能编译的原因是你使用了 "花哨的索引",因为你的代码中的 EW[E] 的输出。np.where 并且是一个数组的元组。(这就解释了为什么会出现略微神秘的错误信息。Numba不知道如何使用... ... getitem即当其中一个输入是元组时,它不知道如何找到括号中的东西。)

Numba 实际上是支持单一维度上的花式索引(也叫 "高级索引")。,只是不能多维。在您的情况下,可以进行简单的修改:首先使用 ravel 来几乎不费吹灰之力地使你的数组成为一维数组,然后应用变换,然后用廉价的 reshape 返回。

@njit
def invert2(W, copy=True):
    if copy:
        W = W.copy()
    Z = W.ravel()
    E = np.where(Z)
    Z[E] = 1. / Z[E]
    return Z.reshape(W.shape)

但这仍然比它需要的慢,因为将计算传递给不必要的中间数组,而不是在遇到非零值时立即修改数组。简单的做一个循环会更快。

@njit 
def invert3(W, copy=True): 
    if copy: 
        W = W.copy() 
    Z = W.ravel() 
    for i in range(len(Z)): 
        if Z[i] != 0: 
            Z[i] = 1/Z[i] 
    return Z.reshape(W.shape) 

这段代码可以工作,而不管数组的尺寸是多少 W. 如果我们知道 W 是二维的,那么我们可以直接在这两个维度上进行迭代,但由于二者的性能相近,我还是选择更通用的路线。

在我的电脑上,假设一个300乘300的阵列,时序是 W 其中约有一半的条目是0,而其中 invert 是你在没有Numba编译的情况下的原始函数,是。

In [80]: %timeit invert(W)                                                                   
2.67 ms ± 49.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [81]: %timeit invert2(W)                                                                  
519 µs ± 24.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In [82]: %timeit invert3(W)                                                                  
186 µs ± 11.1 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

所以Numba给我们带来了相当可观的加速(在它已经被运行过一次以消除编译时间之后),尤其是在代码被重写成Numba可以利用的高效循环风格之后。

© www.soinside.com 2019 - 2024. All rights reserved.