我正在构建一个“编译器”类,从现有代码生成 numba njit-ed 函数
import numba
from typing import Callable
class Foo:
def __init__(self) -> None:
self.compiled: dict[str, Callable] = dict()
self.compile_bar()
def compile_bar(self):
@numba.njit
def func1(a, b):
return a * b
@numba.njit
def func2(a, b):
return a / b
l_c: list[Callable] = [func1, func2]
@numba.njit
def wrapper(a, b):
for func in l_c:
print(func(a, b))
self.compiled.update({'bar': wrapper})
F = Foo()
print(F.compiled['bar'](1.,2.))
最后一行抛出错误
Exception has occurred: NumbaNotImplementedError X
Failed in nopython mode pipeline (step: native lowering)
<numba.core.base.OverloadSelector object at 0x7f996cd00c50>, (List(type(CPUDispatcher(<function Foo.compile_bar.<locals>.func1 at 0x7f996bcbcea0>)), True),)
During: lowering "$6load_deref.0 = freevar(l_c: [CPUDispatcher(<function Foo.compile_bar.<locals>.func1 at 0x7f996bcbcea0>), CPUDispatcher(<function Foo.compile_bar.<locals>.func2 at 0x7f99264c2de0>)])" at /workdir/porepy/scripts_vl/playground.py (33)
File "/workdir/scripts/playground.py", line 40, in <module>
print(F.compiled['bar'](1.,2.))
^^^^^^^^^^^^^^^^^^^^^^^^
numba.core.errors.NumbaNotImplementedError: Failed in nopython mode pipeline (step: native lowering)
<numba.core.base.OverloadSelector object at 0x7f996cd00c50>, (List(type(CPUDispatcher(<function Foo.compile_bar.<locals>.func1 at 0x7f996bcbcea0>)), True),)
During: lowering "$6load_deref.0 = freevar(l_c: [CPUDispatcher(<function Foo.compile_bar.<locals>.func1 at 0x7f996bcbcea0>), CPUDispatcher(<function Foo.compile_bar.<locals>.func2 at 0x7f99264c2de0>)])" at /workdir/scripts/playground.py (33)
现在,
func1
和func2
只是保持代码紧凑的示例。一般来说,它是基于一些输入参数创建的 njit 函数列表 Foo
我尝试制作一个打字列表
l_c = numba.typed.List.empty_list(numba.float64(numba.float64, numba.float64).as_type())
l_c.append(func1)
l_c.append(func2)
但结果并没有改变。
我认为问题在于通过列表索引访问函数,因为它只发生在那里。
访问不在列表中的函数似乎工作正常。
错误消息对我来说相当神秘。
有人有解决方案和解释吗?
有两个不同的问题:首先,列表通常被认为是异体,这与函数指针不能很好地配合。使用元组代替:
l_c: tuple[Callable, ...] = (func1, func2)
第二个问题是 numba 在迭代函数之前无法完全弄清楚函数的类型。要解决此问题,您可以为至少一个函数提供明确的签名:
@numba.njit("double(double, double)")
def func1(a, b):
return a * b
或使用
literal_unroll
:
@numba.njit
def wrapper(a, b):
for func in literal_unroll(l_c):
print(func(a, b))
所有这些选项都会产生很多关于第一类函数类型是实验性功能的警告,但是,在这种情况下它似乎有效。如果要动态构建函数列表,则需要在编译函数之前执行此操作。只需将其构建为列表,然后将其转换为元组即可。
要消除警告,请在代码中的某处添加
warnings.simplefilter("ignore", NumbaExperimentalFeatureWarning)
。