Numba 异常:无法确定 <class 'type'>

问题描述 投票:0回答:1

出于性能原因,我想将函数转换为 Numba。我的 MWE 示例如下。如果我删除

@njit
装饰器,代码可以工作,但是对于
@njit
,我会收到运行时异常。异常很可能是由于
dtype=object
来定义
result_arr
但我也尝试使用
dtype=float64
,但我得到了类似的异常。

import numpy as np
from numba import njit
from timeit import timeit

######-----------Required NUMBA function----------###
#@njit #<----without this, the code works
def required_numba_function():
    nRows = 151
    nCols = 151
    nFrames = 24
    result_arr = np.empty((151* 151 * 24), dtype=object)

    for frame in range(nFrames):
        for row in range(nRows):
            for col in range(nCols):        
                size_rows = np.random.randint(8, 15) 
                size_cols = np.random.randint(2, 6)            
                args = np.random.normal(3, 2.5, size=(size_rows, size_cols)) # size is random
                flat_idx = frame * (nRows * nCols) + (row * nCols + col)
                result_arr[flat_idx] = args

    return result_arr

######------------------main()-------##################
if __name__ == "__main__":
    required_numba_function()

    print() 

如何解决 Numba 异常?

python numba
1个回答
0
投票

正如您所说,数组列表很好,您可以将

result_array
的分配替换为
dtype=object
的空数组,并在每次迭代时附加到空列表 - 这与 numba 兼容:

@nb.njit
def required_numba_function2():
    np.random.seed(0) # Just for testing, you seem to have to set the seed within the function for numba to be aware of it
    nRows = 151
    nCols = 151
    nFrames = 24
    result_arr = []

    for frame in range(nFrames):
        for row in range(nRows):
            for col in range(nCols):        
                size_rows = np.random.randint(8, 15) 
                size_cols = np.random.randint(2, 6)            
                args = np.random.normal(3, 2.5, size=(size_rows, size_cols)) # size is random
                result_arr.append(args)

    return result_arr

测试

np.random.seed(0)
result = required_numba_function()
result2 = required_numba_function2()
for i, j in zip(result, result2):
    assert np.allclose(i, j)

时间:

%timeit required_numba_function()
%timeit required_numba_function2()
2.08 s ± 37.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
606 ms ± 3.43 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
© www.soinside.com 2019 - 2024. All rights reserved.