当 numba jitclass 包含 jitted 函数时,如何指定它的字段?

问题描述 投票:0回答:1

我想创建一个 numba jitclass,其属性包含任何 jitted 函数。

# simple jitted functions defined in another file
@njit
def my_function(x):
    x = x + 1
    return x

@njit
def another_function(x):
    x = x * 2
    return x

spec = [('attribute', ???),
        ('value', float32)]

@jitclass(spec)
class Myclass:
    def __init__(self, fun):
        self.attribute = fun

    def class_fun(self, x):
       value = self.attribute(x)
       return value

当我将

???
替换为
numba.typeof(my_function)
并使用
an_object = Myclass(fun=my_function)
创建 Myclass 的实例时,一切正常。但是,那么我只能传递确切的函数
my_function
。当我使用另一个 jitted 函数创建 Myclass 对象时,

new_object = Myclass(fun=another_function)
new_object.class_fun(2)

我收到以下错误:

Failed in nopython mode pipeline (step: nopython mode backend)
Cannot cast type(CPUDispatcher(<function another_function at 0x7fb1e3d8ce50>))
to type(CPUDispatcher(<function my_function at 0x7fb1e404de50>))

这对我来说很有意义,因为我为 my_function 定义了字段类型。我不知道如何以通用方式定义字段类型

attribute
,以便我可以传递任何函数。

有谁知道一种方法可以在创建 jitclass 的对象时传递任何 jitted 函数

Myclass

python class types numba
1个回答
0
投票

答复可能会延迟,似乎可以通过numba.cfunc

购买

这是一个片段:

import numba as nb

@nb.cfunc("double(double)")
def a(x):
    return x + 1.0

@nb.cfunc("double(double)")
def b(x):
    return x + 2.0
    
    
@nb.experimental.jitclass
class TestFunc:
    rnd_fun: nb.typeof(a)

    def __init__(self, func):
        self.rnd_fun = func

    def call(self, x):
        return self.rnd_fun(x)

> t1 = TestFunc(a)
> t1.call(10)
11.0

> t2 = TestFunc(b)
> t2.call(10)
12
© www.soinside.com 2019 - 2024. All rights reserved.