在 Numba 中使用 `numpy.random.normal()` 时出错

问题描述 投票:0回答:1

我正在探索一些 Numba 来优化一些信号处理代码。根据 Numba 的文档,即时编译器很好地支持

numpy.random
包中的函数。然而,当我跑步时

import numpy as np
from numba import jit

@jit(nopython=True)
def numba():
    noise = np.random.normal(size=100)

# ...

if __name__ == "__main__":
    numba()

我收到以下错误:

Traceback (most recent call last):
  File ".../test.py", line 89, in <module>
    numba()
  File ".../venv/lib/python3.9/site-packages/numba/core/dispatcher.py", line 468, in _compile_for_args
    error_rewrite(e, 'typing')
  File ".../venv/lib/python3.9/site-packages/numba/core/dispatcher.py", line 409, in error_rewrite
    raise e.with_traceback(None)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in method normal of numpy.random.mtrand.RandomState object at 0x104355740>) found for signature:
 
 >>> normal(size=Literal[int](100000))
 
There are 4 candidate implementations:
  - Of which 4 did not match due to:
  Overload in function '_OverloadWrapper._build.<locals>.ol_generated': File: numba/core/overload_glue.py: Line 129.
    With argument(s): '(size=int64)':
   Rejected as the implementation raised a specific error:
     TypingError: unsupported call signature
  raised from .../venv/lib/python3.9/site-packages/numba/core/typing/templates.py:439

During: resolving callee type: Function(<built-in method normal of numpy.random.mtrand.RandomState object at 0x104355740>)
During: typing of call at .../test.py (65)


File "test.py", line 65:
def numba():
    noise = np.random.normal(size=SIZE)
    ^

我在做一些明显愚蠢的事情吗?

python numpy numba
1个回答
1
投票

如果您检查文档的当前状态,则尚不支持大小参数。

由于 numba 将其编译为机器代码,因此在速度方面是等效的。

@jit(nopython=True)
def numba():
    noise = np.empty(100,dtype=np.float64)
    for i in range(100):
        noise[i] = np.random.normal()
    return noise

编辑: numba 版本实际上快两倍......可能是因为它不解析输入。

© www.soinside.com 2019 - 2024. All rights reserved.