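I want to write efficient function evaluations that depend on an optional argument. To build a signature involving a Python dict, I followed the answer to this question.
Working example: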
import numba
import numpy as np
from typing import Optional
from numba.core import types
from numba.typed import Dict
@numba.njit(
    [
        'float64[:](float64[:],DictType(unicode_type, float64))',
        'float64[:](float64[:],none)',
    ]
)
def f(x: np.ndarray, p: Optional[dict[str, float]] = None) -> np.ndarray:
    x_ = x.copy()
    if p is not None:
        for k, v in p.items():
            x_ = x_ + float(v)  # work-around for errors caused by x_ += v
    return x_
x = np.array([1,2,3], dtype=np.float64)
p = Dict.empty(
    key_type=types.unicode_type,
    value_type=types.float64,
)
p['a'] = 4.
p['b'] = 5.
print(f(x, p))
>>> [10., 11., 12.]
print(f(x))
>>> Exception has occurred: TypeError
No matching definition for argument type(s) array(float64, 1d, C), omitted(default=None)
File "/example.py", line 31, in <module>
print(f(x))
^^^^
TypeError: No matching definition for argument type(s) array(float64, 1d, C), omitted(default=None)
Calling the function without the optional argument raises the error above.
I tried other signatures, for example
@numba.njit('float64[:](float64[:],optional(DictType(unicode_type, float64)))'
)
def f(x: np.ndarray, p: Optional[dict[str, float]] = None) -> np.ndarray:
and
@numba.njit('float64[:](float64[:],optional(DictType(unicode_type, float64)))'
)
def f(
    x: np.ndarray,
    p: dict[str, float] = Dict.empty(
        key_type=types.unicode_type,
        value_type=types.float64,
    )
) -> np.ndarray:
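but without success.
Python version: 3.11.4, numba version: 0.57.11
You can omit the signature; numba will then compile and run the code correctly (it just will not compile the function ahead of time, but lazily on the first call with each new combination of argument types):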
from typing import Optional
import numba
import numpy as np  # needed for np.ndarray and np.array below
from numba.core import types
from numba.typed import Dict
@numba.njit
def f(x: np.ndarray, p: Optional[dict[str, float]] = None) -> np.ndarray:
    x_ = x.copy()
    if p is not None:
        for k, v in p.items():
            x_ = x_ + float(v)  # work-around for errors caused by x_ += v
    return x_
x = np.array([1, 2, 3], dtype=np.float64)
p = Dict.empty(
    key_type=types.unicode_type,
    value_type=types.float64,
)
p["a"] = 4.0
p["b"] = 5.0
print(f(x, p))
print(f(x))
This prints:
[10. 11. 12.]
[1. 2. 3.]