numba njit signature with an optional dictionary


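I want to write efficient function evaluations that depend on an optional argument. To create a signature involving a Python dict, I followed the answer to this question.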

Working example:

import numba
import numpy as np
from typing import Optional

from numba.core import types
from numba.typed import Dict

@numba.njit(
    [
        'float64[:](float64[:],DictType(unicode_type, float64))',
        'float64[:](float64[:],none)',
    ]
)
def f(x: np.ndarray, p: Optional[dict[str, float]] = None) -> np.ndarray:
    x_ = x.copy()
    if p is not None:
        for k, v in p.items():
            x_ = x_ + float(v)  # work-around for errors caused by x_ += v

    return x_

x = np.array([1,2,3], dtype=np.float64)
p = Dict.empty(
    key_type=types.unicode_type,
    value_type=types.float64,
)
p['a'] = 4.
p['b'] = 5.

print(f(x, p))
# [10. 11. 12.]
print(f(x))
# raises:
#   File "/example.py", line 31, in <module>
#     print(f(x))
# TypeError: No matching definition for argument type(s) array(float64, 1d, C), omitted(default=None)

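Calling the function without the optional argument raises the error above. The omitted argument is typed as omitted(default=None), which matches neither of the listed signatures (the second one expects none, the type of an explicitly passed None).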

I tried other signatures, for example

@numba.njit('float64[:](float64[:],optional(DictType(unicode_type, float64)))')
def f(x: np.ndarray, p: Optional[dict[str, float]] = None) -> np.ndarray:
    ...  # same body as above

@numba.njit('float64[:](float64[:],optional(DictType(unicode_type, float64)))')
def f(
    x: np.ndarray,
    p: dict[str, float] = Dict.empty(
        key_type=types.unicode_type,
        value_type=types.float64,
    ),
) -> np.ndarray:
    ...  # same body as above

Neither of these worked.

Python version: 3.11.4, numba version: 0.57.11

Answer:

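You can omit the signature and numba will compile and run the code correctly. It just won't compile the function ahead of time; instead, it compiles a specialization on the first call for each new combination of argument types, including the case where p is omitted: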

from typing import Optional

import numba
import numpy as np
from numba.core import types
from numba.typed import Dict


@numba.njit
def f(x: np.ndarray, p: Optional[dict[str, float]] = None) -> np.ndarray:
    x_ = x.copy()
    if p is not None:
        for k, v in p.items():
            x_ = x_ + float(v)  # work-around for errors caused by x_ += v
    return x_


x = np.array([1, 2, 3], dtype=np.float64)
p = Dict.empty(
    key_type=types.unicode_type,
    value_type=types.float64,
)
p["a"] = 4.0
p["b"] = 5.0

print(f(x, p))
print(f(x))

Prints:

[10. 11. 12.]
[1. 2. 3.]