我想用 numba 包装外部函数,并且要求能用 `njit(cache=True)` 缓存生成的函数,就像使用 numba 自带的实现时一样(这里仅以 dgesv 为例):
import numba as nb
import numpy as np
@nb.njit(cache=True)
def dgesv_numba(A, b):
    """Solve ``A @ x = b`` with numba's built-in ``np.linalg.solve``.

    This is the reference implementation the hand-written LAPACK wrappers
    are compared against; it compiles and caches with ``cache=True``.
    """
    solution = np.linalg.solve(A, b)
    return solution
我尝试过使用 ctypes:
import ctypes as ct
from ctypes.util import find_library
from numba import types
from numba.core import cgutils
from numba.extending import intrinsic
@intrinsic
def ptr_from_val(typingctx, data):
    """Return a pointer to a copy of the scalar ``data`` (usable inside njit code).

    The value is stored into a freshly allocated stack slot (``alloca``) and
    the slot's address is returned.  This is how pass-by-reference arguments
    such as ``n``/``lda`` are produced.  Because the pointer targets a *copy*,
    anything the callee writes through it is NOT visible in the original
    variable, so only use it for input-only arguments.

    Adapted from
    https://stackoverflow.com/questions/51541302/how-to-wrap-a-cffi-function-in-numba-taking-pointers
    """
    # Typing: takes a value of type ``data`` and yields a ``data*``.
    result_sig = types.CPointer(data)(data)

    def codegen(context, builder, sig, llvm_args):
        # Spill the incoming LLVM value to a stack slot; return its address.
        value = llvm_args[0]
        return cgutils.alloca_once_value(builder, value)

    return result_sig, codegen
# ctypes pointer types: Fortran LAPACK takes every argument by reference.
ptr_int = ct.POINTER(ct.c_int)
ptr_double = ct.POINTER(ct.c_double)

# Argument types of dgesv_, in the order the routine expects them.
argtypes = [
    ptr_int,  # n
    ptr_int,  # nrhs
    ptr_double,  # a
    ptr_int,  # lda
    ptr_int,  # ipiv
    ptr_double,  # b
    ptr_int,  # ldb
    ptr_int,  # info
]

_lapack_ctypes_path = find_library("lapack")
if _lapack_ctypes_path is None:
    # ct.CDLL(None) would not load LAPACK; the problem would only surface
    # later as a confusing AttributeError on ``dgesv_``.  Fail clearly instead.
    raise OSError(
        "could not find a LAPACK shared library: find_library('lapack') returned None"
    )
lapack_ctypes = ct.CDLL(_lapack_ctypes_path)
_dgesv_ctypes = lapack_ctypes.dgesv_
_dgesv_ctypes.argtypes = argtypes
_dgesv_ctypes.restype = None  # Fortran subroutine: no return value
# Or get it from scipy
# addr = nb.extending.get_cython_function_address(
# "scipy.linalg.cython_lapack", "dgesv"
# )
# functype = ct.CFUNCTYPE(None, *argtypes)
# _dgesv_ctypes = functype(addr)
@nb.njit(cache=True)
def args(A, b):
    """Prepare the work arguments for LAPACK ``dgesv_`` to solve ``A @ x = b``.

    Parameters
    ----------
    A : 2-D array, expected shape (n, n).
    b : 1-D array of shape (n,) or 2-D array of shape (n, nrhs).
        Assumed to be float64 (not converted here) -- TODO confirm callers.

    Returns
    -------
    _b : a private 2-D copy of ``b`` whose memory is in column-major order,
        as LAPACK expects.  Shape (n, 1) for 1-D ``b``; shape (nrhs, n)
        (``b`` transposed, C-contiguous) for 2-D ``b``.  dgesv_ overwrites
        it with the solution.
    n, nrhs : int32 scalars.
    ipiv : int32 array of length n for the pivot indices.
    info : int32 scalar 0.

    Raises
    ------
    ValueError
        If ``A`` is not square or ``b`` does not have as many rows as ``A``.
        Such inputs would make dgesv_ read/write outside the buffers.
    """
    if A.shape[0] != A.shape[1]:
        raise ValueError("A must be a square matrix")
    if b.shape[0] != A.shape[0]:
        raise ValueError("b must have as many rows as A")
    if b.ndim == 1:
        # Copy first: ``b[:, None]`` would be a view, so dgesv_ would overwrite
        # the caller's array (the 2-D branch already copies).  The copy is also
        # contiguous even if ``b`` is strided.  An explicit reshape works on
        # numba versions that lack ``[:, None]`` indexing.
        _b = b.copy().reshape((b.shape[0], 1))
        nrhs = np.int32(1)
    else:
        # Copying the transpose into C order yields exactly the column-major
        # layout of the (n, nrhs) matrix that LAPACK expects.
        _b = b.T.copy()
        nrhs = np.int32(b.shape[1])
    n = np.int32(A.shape[0])
    info = np.int32(0)
    ipiv = np.zeros((n,), dtype=np.int32)
    return _b, n, nrhs, ipiv, info
@nb.njit(cache=True)
def dgesv_ctypes(A, b):
    """Solve ``A @ x = b`` by calling LAPACK ``dgesv_`` through ctypes.

    numba cannot cache this function: ``_dgesv_ctypes`` is a ctypes function
    pointer, i.e. a dynamic global, so ``cache=True`` only emits a
    NumbaWarning.

    Returns the solution as ``b.T`` of the work buffer: shape (n, nrhs) for
    2-D ``b`` and (1, n) for 1-D ``b``.

    Raises
    ------
    Exception
        If dgesv_ reports a non-zero ``info`` (see the LAPACK dgesv docs).
    """
    b, n, nrhs, ipiv, _ = args(A, b)
    # dgesv_ overwrites its matrix argument with the LU factors, so hand it a
    # column-major copy rather than the caller's A.
    a_fortran = A.T.copy()
    # ``info`` is an OUTPUT argument and must live in memory we can read back.
    # ``ptr_from_val(info)`` would point LAPACK at a temporary stack copy,
    # leaving the local ``info`` at 0 and hiding every failure.
    info = np.zeros(1, dtype=np.int32)
    _dgesv_ctypes(
        ptr_from_val(n),
        ptr_from_val(nrhs),
        a_fortran.ctypes,
        ptr_from_val(n),  # lda
        ipiv.ctypes,
        b.ctypes,
        ptr_from_val(n),  # ldb
        info.ctypes,
    )
    if info[0] != 0:
        raise Exception("something went wrong")
    return b.T
以及使用 cffi:
import cffi
ffi = cffi.FFI()
# C prototype of the Fortran routine: every argument is passed by reference.
ffi.cdef(
    """
void dgesv_(int *n, int *nrhs, double *a, int *lda, int *ipiv, double *b, int *ldb,
int *info);
"""
)
_lapack_cffi_path = find_library("lapack")
if _lapack_cffi_path is None:
    # ffi.dlopen(None) would not load LAPACK; the problem would only surface
    # later as a confusing error on ``dgesv_``.  Fail clearly instead.
    raise OSError(
        "could not find a LAPACK shared library: find_library('lapack') returned None"
    )
lapack_cffi = ffi.dlopen(_lapack_cffi_path)
_dgesv_cffi = lapack_cffi.dgesv_
@nb.njit(cache=True)
def dgesv_cffi(A, b):
    """Solve ``A @ x = b`` by calling LAPACK ``dgesv_`` through cffi.

    numba cannot cache this function: ``_dgesv_cffi`` is a dynamic global,
    so ``cache=True`` only emits a NumbaWarning.

    Returns the solution as ``b.T`` of the work buffer: shape (n, nrhs) for
    2-D ``b`` and (1, n) for 1-D ``b``.

    Raises
    ------
    Exception
        If dgesv_ reports a non-zero ``info`` (see the LAPACK dgesv docs).
    """
    b, n, nrhs, ipiv, _ = args(A, b)
    # dgesv_ overwrites its matrix argument with the LU factors, so hand it a
    # column-major copy rather than the caller's A.
    a_fortran = A.T.copy()
    # ``info`` is an OUTPUT argument and must live in memory we can read back.
    # ``ptr_from_val(info)`` would point LAPACK at a temporary stack copy,
    # leaving the local ``info`` at 0 and hiding every failure.
    info = np.zeros(1, dtype=np.int32)
    _dgesv_cffi(
        ptr_from_val(n),
        ptr_from_val(nrhs),
        ffi.from_buffer(a_fortran),
        ptr_from_val(n),  # lda
        ffi.from_buffer(ipiv),
        ffi.from_buffer(b),
        ptr_from_val(n),  # ldb
        ffi.from_buffer(info),
    )
    if info[0] != 0:
        raise Exception("something went wrong")
    return b.T
但在这两种情况下都会收到警告,提示该函数无法缓存,因为它使用了动态全局变量(例如 ctypes 指针):
/var/folders/v7/vq2l7f812yd450mn3wwmrhtc0000gn/T/ipykernel_4390/2568069903.py:79: NumbaWarning: Cannot cache compiled function "dgesv_ctypes" as it uses dynamic globals (such as ctypes pointers and large global arrays)
@nb.njit(cache=True)
/var/folders/v7/vq2l7f812yd450mn3wwmrhtc0000gn/T/ipykernel_4390/2568069903.py:97: NumbaWarning: Cannot cache compiled function "dgesv_cffi" as it uses dynamic globals (such as ctypes pointers and large global arrays)
@nb.njit(cache=True)
我已经通过 WAP(Wrapper Address Protocol,包装地址协议)成功实现了缓存:
class Dgesv(nb.types.WrapperAddressProtocol):
    """Expose LAPACK ``dgesv_`` to numba through the Wrapper Address Protocol.

    An instance is passed as a first-class function argument to jitted code
    (see ``dgesv_wap``).  The function address is therefore supplied at call
    time instead of being baked in as a global, which lets the caller be
    cached.
    """

    def __wrapper_address__(self):
        # Raw address of the ctypes-loaded symbol, as a Python int.
        symbol = lapack_ctypes.dgesv_
        return ct.cast(symbol, ct.c_voidp).value

    def signature(self):
        int_ptr = nb.types.CPointer(nb.int32)
        double_ptr = nb.types.CPointer(nb.float64)
        # Parameter order: n, nrhs, a, lda, ipiv, b, ldb, info.
        return nb.types.void(
            int_ptr,
            int_ptr,
            double_ptr,
            int_ptr,
            int_ptr,
            double_ptr,
            int_ptr,
            int_ptr,
        )
@nb.njit(cache=True)
def dgesv_wap(f, A, b):
    """Solve ``A @ x = b`` with a ``dgesv_`` passed in as a first-class function.

    Parameters
    ----------
    f : a ``Dgesv()`` instance, or presumably any first-class function with
        the same dgesv_ signature -- TODO confirm.
    A, b : see ``args``.

    Returns the solution as ``b.T`` of the work buffer: shape (n, nrhs) for
    2-D ``b`` and (1, n) for 1-D ``b``.

    Raises
    ------
    Exception
        If dgesv_ reports a non-zero ``info`` (see the LAPACK dgesv docs).
    """
    b, n, nrhs, ipiv, _ = args(A, b)
    # dgesv_ overwrites its matrix argument with the LU factors, so hand it a
    # column-major copy rather than the caller's A.
    a_fortran = A.T.copy()
    # ``info`` is an OUTPUT argument and must live in memory we can read back.
    # ``ptr_from_val(info)`` would point LAPACK at a temporary stack copy,
    # leaving the local ``info`` at 0 and hiding every failure.
    info = np.zeros(1, dtype=np.int32)
    f(
        ptr_from_val(n),
        ptr_from_val(nrhs),
        a_fortran.ctypes,
        ptr_from_val(n),  # lda
        ipiv.ctypes,
        b.ctypes,
        ptr_from_val(n),  # ldb
        info.ctypes,
    )
    if info[0] != 0:
        raise Exception("something went wrong")
    return b.T
但这样生成的函数比其他方法慢得多;而且必须把函数作为参数传入才能让缓存生效,这并不是我真正想要的:
# Sanity check: every wrapper should reproduce numba's built-in solver, and
# the built-in solver should recover the known solution ``x``.
rng = np.random.default_rng()
for i in range(3, 5):
    N = i
    # N x N system with 1000 right-hand sides.
    A = rng.random((N, N))
    x = rng.random((N, 1000))
    b = A @ x
    # Inputs are passed as copies so a wrapper that writes into its
    # arguments cannot affect later calls.
    _ctypes = dgesv_ctypes(A.copy(), b.copy())
    _cffi = dgesv_cffi(A.copy(), b.copy())
    # WAP variant: the LAPACK function object is passed as an argument.
    _wap = dgesv_wap(Dgesv(), A.copy(), b.copy())
    _numba = dgesv_numba(A, b)
    assert np.allclose(_ctypes, _numba)
    assert np.allclose(_cffi, _numba)
    assert np.allclose(_wap, _numba)
    assert np.allclose(x, _numba)
print("all good")
# Benchmarks: ``%timeit`` is an IPython/Jupyter magic, not plain Python.
# They reuse A and b from the last loop iteration (N = 4).
%timeit dgesv_ctypes(A.copy(), b.copy())
%timeit dgesv_cffi(A.copy(), b.copy())
%timeit dgesv_wap(Dgesv(), A.copy(), b.copy())
%timeit dgesv_numba(A, b)
输出:
all good
56.5 µs ± 1.62 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
55.8 µs ± 1.36 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
89.6 µs ± 2.57 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
59.7 µs ± 894 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
那么,如何在实现缓存的同时保持其他实现的性能?
好吧,答案其实就在文档里:需要使用 `types.ExternalFunction`(尽管其文档字符串说它仅供内部使用),并用 llvmlite 加载库。不过,这种实现只能在 `njit` 装饰的函数内部调用,不能在其外部调用:
from ctypes.util import find_library
from llvmlite import binding
from numba import types, njit
# Load LAPACK into the process so LLVM can resolve ``dgesv_`` by symbol name.
# ExternalFunction refers to the symbol by name rather than through a ctypes
# pointer global, which is presumably what lets callers use cache=True.
_lapack_llvm_path = find_library("lapack")
if _lapack_llvm_path is None:
    raise OSError(
        "could not find a LAPACK shared library: find_library('lapack') returned None"
    )
binding.load_library_permanently(_lapack_llvm_path)

# NOTE: if run in the same namespace as the ctypes example, this rebinds
# ptr_int / ptr_double (previously ctypes types) to numba types.
ptr_int = types.CPointer(types.int32)
ptr_double = types.CPointer(types.float64)
# dgesv_ is a Fortran subroutine with no return value, so declare it void.
# The previous float64 return type described a value that does not exist.
_dgesv = types.ExternalFunction("dgesv_", types.void(
    ptr_int,  # n
    ptr_int,  # nrhs
    ptr_double,  # a
    ptr_int,  # lda
    ptr_int,  # ipiv
    ptr_double,  # b
    ptr_int,  # ldb
    ptr_int,  # info
))
@njit(cache=True)
def dgesv_external_function(A, b):
    """Solve ``A @ x = b`` by calling ``dgesv_`` declared as a numba ExternalFunction.

    ``_dgesv`` can only be called from inside njit code, but this wrapper
    itself is cacheable.

    Returns the solution as ``b.T`` of the work buffer: shape (n, nrhs) for
    2-D ``b`` and (1, n) for 1-D ``b``.

    Raises
    ------
    Exception
        If dgesv_ reports a non-zero ``info`` (see the LAPACK dgesv docs).
    """
    b, n, nrhs, ipiv, _ = args(A, b)
    # dgesv_ overwrites its matrix argument with the LU factors, so hand it a
    # column-major copy rather than the caller's A.
    a_fortran = A.T.copy()
    # ``info`` is an OUTPUT argument and must live in memory we can read back.
    # ``ptr_from_val(info)`` would point LAPACK at a temporary stack copy,
    # leaving the local ``info`` at 0 and hiding every failure.
    info = np.zeros(1, dtype=np.int32)
    _dgesv(
        ptr_from_val(n),
        ptr_from_val(nrhs),
        a_fortran.ctypes,
        ptr_from_val(n),  # lda
        ipiv.ctypes,
        b.ctypes,
        ptr_from_val(n),  # ldb
        info.ctypes,
    )
    if info[0] != 0:
        raise Exception("something went wrong")
    return b.T
计时:
# Compare all wrappers, including the ExternalFunction one, on a 5 x 5 system
# with 1000 right-hand sides.  Reuses ``rng`` from the previous cell.
A = rng.random((5, 5))
x = rng.random((5, 1000))
b = A @ x
# Inputs are passed as copies so a wrapper that writes into its arguments
# cannot affect later calls.
_ctypes = dgesv_ctypes(A.copy(), b.copy())
_cffi = dgesv_cffi(A.copy(), b.copy())
_wap = dgesv_wap(Dgesv(), A.copy(), b.copy())
_numba = dgesv_numba(A, b)
_ext = dgesv_external_function(A.copy(), b.copy())
assert np.allclose(_ctypes, _numba)
assert np.allclose(_cffi, _numba)
assert np.allclose(_wap, _numba)
assert np.allclose(x, _numba)
assert np.allclose(_ext, _numba)
# Benchmarks: ``%timeit`` is an IPython/Jupyter magic, not plain Python.
%timeit dgesv_ctypes(A.copy(), b.copy())
%timeit dgesv_cffi(A.copy(), b.copy())
%timeit dgesv_wap(Dgesv(), A.copy(), b.copy())
%timeit dgesv_numba(A, b)
%timeit dgesv_external_function(A.copy(), b.copy())
输出:
43.3 µs ± 5.02 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
40.5 µs ± 2.69 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
115 µs ± 34.1 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
53.5 µs ± 5.41 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
17.4 µs ± 544 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)