我如何创建具有可调参数的numba可调用项?

问题描述 投票:1回答:1

我想创建一个由numba编译的python可调用(一个我可以在另一个由Numba编译的函数中使用的函数),它具有一个可以调整以影响函数调用结果的内部数组。在纯python中,这将对应于具有__call__方法的类:

class Test:
    def __init__(self, arr):
        self.arr = arr
    def __call__(self, idx):
        res = 0
        for i in idx:
            res += self.arr[i]
        return res

t = Test([0, 1, 2])
print(t([1, 2]))
t.arr = [1, 2, 3]
print(t([1, 2]))

分别打印35,因此在修改内部数组arr后结果是不同的。

[使用jitclass和numpy数组的文字转换为Numba的样子,

import numpy as np
import numba as nb

@nb.jitclass([('arr', nb.double[:])])
class Test:
    def __init__(self, arr):
        self.arr = arr.astype(np.double)
    def __call__(self, idx):
        res = 0
        for i in idx:
            res += self.arr[i]
        return res

t = Test(np.arange(3))
print(t(np.array([1, 2])))
t.arr = np.arange(3) + 1
print(t(np.array([1, 2])))

不幸的是,这在TypeError: 'Test' object is not callable上失败,因为Numba似乎还不支持__call__

然后我尝试使用闭包来解决问题

import numpy as np
import numba as nb

arr = np.arange(5)

@nb.jit
def call(idx):
    res = 0
    for i in idx:
        res += arr[i]
    return res


print(call(np.array([1, 2])))
arr += 1
print(call(np.array([1, 2])))

但是这会打印两次3,因为闭包将arr中的数据复制到一个内部表示中,所以我不能(轻松地)从外部更改它。我什至尝试通过在与ctypes组合的Numpy数组上使用numba.carray指针来欺骗Numba,但Numba似乎仍在复制数据,因此我无法对其进行操作。

我了解Numba希望控制内存并避免访问可能不再使用的内存区域。但是,我有一个特定的用例,我想避免传递多余的数组arr,而是以某种方式调整内部副本。有什么办法可以做到这一点?

python numba
1个回答
0
投票
将问题分为两部分,一个numba函数和一个纯python类:

import numpy as np import numba @numba.jit def calc(arr, idx): res = 0 for i in idx: res += arr[i] return res class Test: def __init__(self, arr): self.arr = arr.astype(np.double) def __call__(self, idx): return calc(self.arr, idx) t = Test(np.arange(3)) print(t(np.array([1, 2]))) t.arr = np.arange(3) + 1 print(t(np.array([1, 2])))

© www.soinside.com 2019 - 2024. All rights reserved.