我想创建一个 numba jitclass,其属性包含任何 jitted 函数。
# simple jitted functions defined in another file
@njit
def my_function(x):
x = x + 1
return x
@njit
def another_function(x):
x = x * 2
return x
spec = [('attribute', ???),
('value', float32)]
@jitclass(spec)
class Myclass:
def __init__(self, fun):
self.attribute = fun
def class_fun(self, x):
value = self.attribute(x)
return value
当我将
???
替换为 numba.typeof(my_function)
并使用 an_object = Myclass(fun=my_function)
创建 Myclass 的实例时,一切正常。但是,那么我只能传递确切的函数my_function
。当我使用另一个 jitted 函数创建 Myclass 对象时,
new_object = Myclass(fun=another_function)
new_object.class_fun(2)
我收到以下错误:
Failed in nopython mode pipeline (step: nopython mode backend)
Cannot cast type(CPUDispatcher(<function another_function at 0x7fb1e3d8ce50>))
to type(CPUDispatcher(<function my_function at 0x7fb1e404de50>))
这对我来说很有意义,因为我为 my_function 定义了字段类型。我不知道如何以通用方式定义字段类型
attribute
,以便我可以传递任何函数。
有谁知道一种方法可以在创建 jitclass 的对象时传递任何 jitted 函数
Myclass
?
答复可能会延迟,似乎可以通过numba.cfunc
购买这是一个片段:
import numba as nb
@nb.cfunc("double(double)")
def a(x):
return x + 1.0
@nb.cfunc("double(double)")
def b(x):
return x + 2.0
@nb.experimental.jitclass
class TestFunc:
rnd_fun: nb.typeof(a)
def __init__(self, func):
self.rnd_fun = func
def call(self, x):
return self.rnd_fun(x)
> t1 = TestFunc(a)
> t1.call(10)
11.0
> t2 = TestFunc(b)
> t2.call(10)
12