Numba 无法使用接受 Numpy 数组参数的构造函数来编译 `jitclass`

问题描述 投票:0回答:1

以下 `evaluate` 的实现可以正确编译：

import numpy as np
import numpy.typing as npt
from numba import njit
from numba.experimental import jitclass

独立函数

@njit
def evaluate(x: npt.NDArray[np.float64], m: float, b: float) -> npt.NDArray[np.float64]:
    """Evaluate the line ``m * x + b`` element-wise over the array ``x``.

    This is a plain ``@njit`` function: it holds no state, so numba only has
    to type its arguments at call time.
    """
    scaled = m * x
    return scaled + b


x = np.linspace(0, 100, num=50)
y = evaluate(x, 2, 3)

`@jitclass` + 独立函数（构造函数不带参数）

@jitclass
class LineEvaluator:
    """Stateless jitclass: the line parameters are passed to ``evaluate``.

    No fields are declared and ``__init__`` assigns nothing, so there is no
    attribute whose type numba would need to know.
    """

    def __init__(self):
        pass

    def evaluate(self, x: npt.NDArray[np.float64], m: float, b: float) -> npt.NDArray[np.float64]:
        """Return ``m * x + b`` computed element-wise over ``x``."""
        scaled = m * x
        return scaled + b


x = np.linspace(0, 100, num=50)
y = LineEvaluator().evaluate(x, 2, 3)

但是以下实现无法编译并出现错误:

`@jitclass`，构造函数带参数

# Fails to compile (see the error that follows). This jitclass has neither a
# ``spec=`` argument nor class-level annotations, so it declares no fields.
# The error shows the instance type as ``LineEvaluator#...<>`` with an empty
# member list, and the first attribute assignment, ``self.x = x``, cannot be
# typed ("Cannot resolve setattr").
# The annotations on the ``__init__`` parameters did not declare any fields.
# Declaring the field types with ``@jitclass(spec=[...])`` compiles.
@jitclass
class LineEvaluator:
    def __init__(self, x : npt.NDArray[np.float64], m : float, b : float):
        # Assigning ``self.<name>`` requires ``<name>`` to be a declared field
        # of the jitclass; none are declared here, so this line fails to type.
        self.x = x
        self.m = m
        self.b = b

    def evaluate(self) -> npt.NDArray[np.float64]:
        # Intended result: m * x + b element-wise over the stored array.
        return self.m * self.x + self.b

x = np.linspace(0, 100)
y = LineEvaluator(x, 2, 3).evaluate()
Failed in nopython mode pipeline (step: nopython frontend)
Cannot resolve setattr: (instance.jitclass.LineEvaluator#117caf610<>).x = array(float64, 1d, C)

File "test.py", line 9:
    def __init__(self, x : npt.NDArray[np.float64], m : float, b : float):
        self.x = x
        ^

During: typing of set attribute 'x' at /private/tmp/test.py (9)

File "test.py", line 9:
    def __init__(self, x : npt.NDArray[np.float64], m : float, b : float):
        self.x = x
        ^

During: resolving callee type: jitclass.LineEvaluator#117caf610<>
During: typing of call at <string> (3)

During: resolving callee type: jitclass.LineEvaluator#117caf610<>
During: typing of call at <string> (3)


File "<string>", line 3:
<source missing, REPL/exec in use?>

`@jitclass` + 成员类型注解

# Fails while the class is being defined: the traceback points at
# ``class LineEvaluator``, not at a method call. ``@jitclass`` converts each
# class-level annotation to a numba type via ``as_numba_type``
# (``spec[attr] = as_numba_type(py_type)`` in the traceback). That conversion
# raises ``TypingError`` for ``numpy.ndarray[typing.Any,
# numpy.dtype[numpy.float64]]``, i.e. the ``numpy.typing`` alias used for
# ``x``.
# NOTE(review): annotating ``x`` with a numba array type such as
# ``nb.float64[:]`` (the type the ``spec=`` answer uses) presumably also
# works. Confirm against the numba jitclass docs.
@jitclass
class LineEvaluator:
    x : npt.NDArray[np.float64]  # as_numba_type cannot infer this (TypingError)
    m : float
    b : float

    def __init__(self, x : npt.NDArray[np.float64], m : float, b : float):
        # Store the array and the line parameters as jitclass fields.
        self.x = x
        self.m = m
        self.b = b

    def evaluate(self) -> npt.NDArray[np.float64]:
        # Intended result: m * x + b element-wise over the stored array.
        return self.m * self.x + self.b

x = np.linspace(0, 100)
y = LineEvaluator(x, 2, 3).evaluate()
Traceback (most recent call last):
  File "/private/tmp/test.py", line 7, in <module>
    class LineEvaluator:
  File "/Users/xx/miniconda3/envs/xx/lib/python3.10/site-packages/numba/experimental/jitclass/decorators.py", line 88, in jitclass
    return wrap(cls_or_spec)
  File "/Users/xx/miniconda3/envs/xx/lib/python3.10/site-packages/numba/experimental/jitclass/decorators.py", line 77, in wrap
    cls_jitted = register_class_type(cls, spec, types.ClassType,
  File "/Users/xx/miniconda3/envs/xx/lib/python3.10/site-packages/numba/experimental/jitclass/base.py", line 180, in register_class_type
    spec[attr] = as_numba_type(py_type)
  File "/Users/xx/miniconda3/envs/xx/lib/python3.10/site-packages/numba/core/typing/asnumbatype.py", line 121, in __call__
    return self.infer(py_type)
  File "/Users/xx/miniconda3/envs/xx/lib/python3.10/site-packages/numba/core/typing/asnumbatype.py", line 115, in infer
    raise errors.TypingError(
numba.core.errors.TypingError: Cannot infer numba type of python type numpy.ndarray[typing.Any, numpy.dtype[numpy.float64]]

错误消息非常晦涩，为什么在这两种情况下编译会失败？

谢谢!

python numpy jit
1个回答
0
投票

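您可以在 `@jitclass` 的 `spec=` 参数中显式声明成员类型。原因如下：

- 既不提供 `spec` 也没有类级别注解时，jitclass 不会声明任何成员字段。报错中的 `LineEvaluator#...<>` 即为空成员列表，因此 `self.x = x` 无法确定类型。
- 类级别注解会通过 `as_numba_type` 转换为 numba 类型，而它无法识别 `numpy.typing` 的 `npt.NDArray[np.float64]`。

直接使用 numba 类型（如 `nb.float64[:]`）即可：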

import numba as nb
import numpy as np
from numba.experimental import jitclass


# Field name -> numba type for every attribute that __init__ assigns:
# a 1-D float64 array plus two float64 scalars.
LINE_EVALUATOR_SPEC = [
    ("x", nb.float64[:]),
    ("m", nb.float64),
    ("b", nb.float64),
]


@jitclass(spec=LINE_EVALUATOR_SPEC)
class LineEvaluator:
    """Evaluate the line ``m * x + b`` over a stored array ``x``.

    The field types are supplied through ``spec=``, so numba knows the types
    of ``self.x``, ``self.m`` and ``self.b`` when it compiles ``__init__``.
    """

    def __init__(self, x, m, b):
        # Store the sample points, then the slope and the intercept.
        self.x = x
        self.m = m
        self.b = b

    def evaluate(self):
        """Return ``m * x + b`` element-wise as a new array."""
        scaled = self.m * self.x
        return scaled + self.b


x = np.linspace(0, 100, num=50)
y = LineEvaluator(x, 2, 3).evaluate()
print(y)

打印:

[  3.           7.08163265  11.16326531  15.24489796  19.32653061
  23.40816327  27.48979592  31.57142857  35.65306122  39.73469388
  43.81632653  47.89795918  51.97959184  56.06122449  60.14285714
  64.2244898   68.30612245  72.3877551   76.46938776  80.55102041
  84.63265306  88.71428571  92.79591837  96.87755102 100.95918367
 105.04081633 109.12244898 113.20408163 117.28571429 121.36734694
 125.44897959 129.53061224 133.6122449  137.69387755 141.7755102
 145.85714286 149.93877551 154.02040816 158.10204082 162.18367347
 166.26530612 170.34693878 174.42857143 178.51020408 182.59183673
 186.67346939 190.75510204 194.83673469 198.91836735 203.        ]
© www.soinside.com 2019 - 2024. All rights reserved.