我想创建一个由numba编译的python可调用(一个我可以在另一个由Numba编译的函数中使用的函数),它具有一个可以调整以影响函数调用结果的内部数组。在纯python中,这将对应于具有__call__
方法的类:
class Test:
def __init__(self, arr):
self.arr = arr
def __call__(self, idx):
res = 0
for i in idx:
res += self.arr[i]
return res
t = Test([0, 1, 2])
print(t([1, 2]))
t.arr = [1, 2, 3]
print(t([1, 2]))
分别打印3
和5
,因此在修改内部数组arr
后结果是不同的。
[使用jitclass
和numpy数组的文字转换为Numba的样子,
import numpy as np
import numba as nb
@nb.jitclass([('arr', nb.double[:])])
class Test:
def __init__(self, arr):
self.arr = arr.astype(np.double)
def __call__(self, idx):
res = 0
for i in idx:
res += self.arr[i]
return res
t = Test(np.arange(3))
print(t(np.array([1, 2])))
t.arr = np.arange(3) + 1
print(t(np.array([1, 2])))
不幸的是,这在TypeError: 'Test' object is not callable
上失败,因为Numba似乎还不支持__call__
。
然后我尝试使用闭包来解决问题
import numpy as np
import numba as nb
arr = np.arange(5)
@nb.jit
def call(idx):
res = 0
for i in idx:
res += arr[i]
return res
print(call(np.array([1, 2])))
arr += 1
print(call(np.array([1, 2])))
但是这会打印两次3
,因为闭包将arr
中的数据复制到一个内部表示中,所以我不能(轻松地)从外部更改它。我什至尝试通过在与ctypes
组合的Numpy数组上使用numba.carray
指针来欺骗Numba,但Numba似乎仍在复制数据,因此我无法对其进行操作。
我了解Numba希望控制内存并避免访问可能不再使用的内存区域。但是,我有一个特定的用例,我想避免传递多余的数组arr
,而是以某种方式调整内部副本。有什么办法可以做到这一点?
import numpy as np
import numba
@numba.jit
def calc(arr, idx):
res = 0
for i in idx:
res += arr[i]
return res
class Test:
def __init__(self, arr):
self.arr = arr.astype(np.double)
def __call__(self, idx):
return calc(self.arr, idx)
t = Test(np.arange(3))
print(t(np.array([1, 2])))
t.arr = np.arange(3) + 1
print(t(np.array([1, 2])))