如何在 numba jitted 函数中使用类的对象而不抖动整个类?

问题描述 投票:0回答:1

我有一个类,它创建一个对象,该对象包含多个数组(numpy 数组)作为它们的子对象。该类包含构建这些数组的所有复杂逻辑。只有一个属性的简单模型如下所示:

class System():
    def __init__(self,backend='numpy'):
        if backend == 'numpy':
            self.D = np.ones((2,2))
        else:
            # Mock-up for other backend types
            self.D = [[1,1],[1,1]]

在实际代码中,这会向类添加多个属性,这些属性是在类初始化时在运行时确定的。如果后端是例如

'numpy'
,那么所有属性都将是 numpy 数组。为了简单起见,我在这里只添加一个属性
D

此类的目的是用户可以利用他提供的函数中的属性,因为它们很可能包含大量循环,因此使 njit 显着加快代码速度。我现在希望用户能够执行他的功能。再次,一个模拟示例:

### User code
a = System(backend='numpy')

@nb.njit()
def user_provided_function():
    result = a.D * 2
    return result

out = user_provided_function()
print(out)

显然,这不起作用,因为 numba 会抱怨该函数无法被 njitted,因为

a
的类型未定义/jitable。更清楚的是,如果用户避免使用全局变量的坏习惯并使用以下代码,则这将不起作用:

### User code
@nb.njit()
def user_provided_function(a):
    result = a.D * 2
    return result

b = System(backend='numpy')
out = user_provided_function(b)
print(out)

我的主要问题是,我无法从上面的类中创建 jitclass,因为有些后端与 numba 不兼容。然而,我喜欢这个类,我只有一段代码,可以提供多个后端,并且用户可以轻松更改后端,而无需重构整个代码。

有什么干净的方法可以避免这个问题?我喜欢在用户代码中输入

a.D
的美妙感觉,并希望为用户保持尽可能干净的界面。

python design-patterns numba
1个回答
0
投票

我建议不要将 python 对象传递给 numba 函数,而是传递 numpy 数组。

以 numba 编译函数的方式思考它们只是接受数组并返回新数组/修改这些数组,仅此而已。这将使您的程序保持简单:

import numba as nb
import numpy as np


class System:
    def __init__(self, backend="numpy"):
        if backend == "numpy":
            self.D = np.ones((2, 2), dtype=np.float32)
        else:
            # Mock-up for other backend types
            self.D = [[1, 1], [1, 1]]


@nb.njit("float32[:, :](float32[:, :])")
def user_provided_function(a):
    return a * 2


b = System(backend="numpy")
out = user_provided_function(b.D)  # <-- don't pass whole object, just array from this object
print(out)

打印:

[[2. 2.]
 [2. 2.]]
© www.soinside.com 2019 - 2024. All rights reserved.