如何在 numba 中包装外部函数以使生成的函数可缓存?

问题描述 投票:0回答:1

我想用 numba 包装外部函数,但要求能够使用

njit(cache=True)
缓存生成的函数,就像我可以使用 numba 实现一样(我仅使用
dgesv
作为示例):

import numba as nb
import numpy as np

@nb.njit(cache=True)
def dgesv_numba(A, b):
    """Reference solver: delegate to numba's own ``np.linalg.solve`` support.

    Returns whatever ``np.linalg.solve(A, b)`` returns; the other wrappers in
    this file are compared against this result.
    """
    solution = np.linalg.solve(A, b)
    return solution

我尝试过使用ctypes

import ctypes as ct
from ctypes.util import find_library

from numba import types
from numba.core import cgutils
from numba.extending import intrinsic

@intrinsic
def ptr_from_val(typingctx, data):
    """Numba intrinsic that turns a value into a pointer to a copy of it.

    Typed as ``CPointer(data)(data)``. The generated code passes the first
    argument to ``cgutils.alloca_once_value`` and returns the pointer that
    call produces.
    """
    # from https://stackoverflow.com/questions/51541302/how-to-wrap-a-cffi-function-in-numba-taking-pointers
    def codegen(context, builder, signature, args):
        return cgutils.alloca_once_value(builder, args[0])

    return types.CPointer(data)(data), codegen

ptr_int = ct.POINTER(ct.c_int)
ptr_double = ct.POINTER(ct.c_double)

# (parameter name, ctypes pointer type) for each dgesv_ argument, in call
# order. Every argument is passed by pointer.
_DGESV_PARAMS = (
    ("n", ptr_int),
    ("nrhs", ptr_int),
    ("a", ptr_double),
    ("lda", ptr_int),
    ("ipiv", ptr_int),
    ("b", ptr_double),
    ("ldb", ptr_int),
    ("info", ptr_int),
)
argtypes = [ctype for _name, ctype in _DGESV_PARAMS]

# Locate the LAPACK shared library and bind its dgesv_ symbol through ctypes.
lapack_ctypes = ct.CDLL(find_library("lapack"))
_dgesv_ctypes = lapack_ctypes.dgesv_
_dgesv_ctypes.restype = None  # the routine has no return value
_dgesv_ctypes.argtypes = argtypes

# Or get it from scipy
# addr = nb.extending.get_cython_function_address(
#     "scipy.linalg.cython_lapack", "dgesv"
# )
# functype = ct.CFUNCTYPE(None, *argtypes)
# _dgesv_ctypes = functype(addr)

@nb.njit(cache=True)
def args(A, b):
    """Build the scalar and array arguments shared by the dgesv wrappers.

    Returns a tuple ``(_b, n, nrhs, ipiv, info)``:

    * ``_b``   -- 2-D right-hand side: ``b[:, None]`` when ``b`` is 1-D,
      otherwise a copy of ``b.T``.
    * ``n``    -- ``A.shape[0]`` as ``int32``.
    * ``nrhs`` -- number of right-hand sides as ``int32`` (1 for 1-D ``b``,
      else ``b.shape[1]``).
    * ``ipiv`` -- zero-filled ``int32`` array of length ``n``.
    * ``info`` -- ``int32`` zero.
    """
    n = np.int32(A.shape[0])

    if b.ndim == 1:
        nrhs = np.int32(1)
        _b = b[:, None]  # .reshape(-1, 1) # change to reshape numba < 0.57
    else:
        nrhs = np.int32(b.shape[1])
        _b = b.T.copy()  # copy of the transposed right-hand sides

    return _b, n, nrhs, np.zeros((n,), dtype=np.int32), np.int32(0)


@nb.njit(cache=True)
def dgesv_ctypes(A, b):
    """Solve ``A @ x = b`` by calling LAPACK ``dgesv_`` through ctypes.

    Parameters
    ----------
    A : 2-D float64 array
        Coefficient matrix.
    b : 1-D or 2-D float64 array
        Right-hand side(s); prepared by ``args``.

    Returns
    -------
    The transposed buffer that was handed to ``dgesv_`` as ``b``.

    Raises
    ------
    Exception
        If ``dgesv_`` reports a non-zero ``info``.
    """
    b, n, nrhs, ipiv, _ = args(A, b)
    # dgesv_ reports its status by writing through the `info` pointer.
    # ptr_from_val() copies a scalar into a fresh stack slot, so a status
    # written there would never be visible here. A one-element array gives
    # the routine memory we can read back after the call.
    info = np.zeros((1,), dtype=np.int32)
    _dgesv_ctypes(
        ptr_from_val(n),
        ptr_from_val(nrhs),
        A.T.copy().ctypes,  # contiguous copy of A.T, i.e. A laid out column by column
        ptr_from_val(n),  # lda
        ipiv.ctypes,
        b.ctypes,
        ptr_from_val(n),  # ldb
        info.ctypes,
    )
    if info[0]:
        raise Exception("something went wrong")
    return b.T

我也尝试过使用 cffi:

import cffi
ffi = cffi.FFI()

# C prototype handed to cffi for the dgesv_ routine.
_DGESV_CDEF = """
void dgesv_(int *n, int *nrhs, double *a, int *lda, int *ipiv, double *b, int *ldb,
    int *info);
"""
ffi.cdef(_DGESV_CDEF)

# Open the LAPACK shared library via cffi and grab the dgesv_ symbol.
lapack_cffi = ffi.dlopen(find_library("lapack"))
_dgesv_cffi = lapack_cffi.dgesv_


@nb.njit(cache=True)
def dgesv_cffi(A, b):
    """Solve ``A @ x = b`` by calling LAPACK ``dgesv_`` through cffi.

    Parameters
    ----------
    A : 2-D float64 array
        Coefficient matrix.
    b : 1-D or 2-D float64 array
        Right-hand side(s); prepared by ``args``.

    Returns
    -------
    The transposed buffer that was handed to ``dgesv_`` as ``b``.

    Raises
    ------
    Exception
        If ``dgesv_`` reports a non-zero ``info``.
    """
    b, n, nrhs, ipiv, _ = args(A, b)
    # dgesv_ reports its status by writing through the `info` pointer.
    # ptr_from_val() copies a scalar into a fresh stack slot, so a status
    # written there would never be visible here. A one-element array gives
    # the routine memory we can read back after the call.
    info = np.zeros((1,), dtype=np.int32)
    _dgesv_cffi(
        ptr_from_val(n),
        ptr_from_val(nrhs),
        ffi.from_buffer(A.T.copy()),  # contiguous copy of A.T
        ptr_from_val(n),  # lda
        ffi.from_buffer(ipiv),
        ffi.from_buffer(b),
        ptr_from_val(n),  # ldb
        ffi.from_buffer(info),
    )
    if info[0]:
        raise Exception("something went wrong")
    return b.T

但在这两种情况下,我都会收到警告,提示该函数无法缓存,因为我使用了 ctypes 指针:

/var/folders/v7/vq2l7f812yd450mn3wwmrhtc0000gn/T/ipykernel_4390/2568069903.py:79: NumbaWarning: Cannot cache compiled function "dgesv_ctypes" as it uses dynamic globals (such as ctypes pointers and large global arrays)
  @nb.njit(cache=True)
/var/folders/v7/vq2l7f812yd450mn3wwmrhtc0000gn/T/ipykernel_4390/2568069903.py:97: NumbaWarning: Cannot cache compiled function "dgesv_cffi" as it uses dynamic globals (such as ctypes pointers and large global arrays)
  @nb.njit(cache=True)

我已经通过 WAP(WrapperAddressProtocol)成功实现了缓存:

class Dgesv(nb.types.WrapperAddressProtocol):
    """WrapperAddressProtocol object exposing ``lapack_ctypes.dgesv_`` to numba."""

    def __wrapper_address__(self):
        """Return the address of ``dgesv_`` as a plain integer."""
        as_void_ptr = ct.cast(lapack_ctypes.dgesv_, ct.c_voidp)
        return as_void_ptr.value

    def signature(self):
        """Return the numba signature of ``dgesv_``: void, eight pointer arguments."""
        int_ptr = nb.types.CPointer(nb.int32)
        double_ptr = nb.types.CPointer(nb.float64)
        return nb.types.void(
            int_ptr,  # n
            int_ptr,  # nrhs
            double_ptr,  # a
            int_ptr,  # lda
            int_ptr,  # ipiv
            double_ptr,  # b
            int_ptr,  # ldb
            int_ptr,  # info
        )


@nb.njit(cache=True)
def dgesv_wap(f, A, b):
    """Solve ``A @ x = b`` by calling a ``dgesv_``-compatible function ``f``.

    Parameters
    ----------
    f : first-class function object (e.g. a ``Dgesv`` instance)
        Callable with the ``dgesv_`` signature; it is passed in as an
        argument rather than captured as a global.
    A : 2-D float64 array
        Coefficient matrix.
    b : 1-D or 2-D float64 array
        Right-hand side(s); prepared by ``args``.

    Returns
    -------
    The transposed buffer that was handed to ``f`` as ``b``.

    Raises
    ------
    Exception
        If ``f`` reports a non-zero ``info``.
    """
    b, n, nrhs, ipiv, _ = args(A, b)
    # The routine reports its status by writing through the `info` pointer.
    # ptr_from_val() copies a scalar into a fresh stack slot, so a status
    # written there would never be visible here. A one-element array gives
    # the routine memory we can read back after the call.
    info = np.zeros((1,), dtype=np.int32)
    f(
        ptr_from_val(n),
        ptr_from_val(nrhs),
        A.T.copy().ctypes,  # contiguous copy of A.T
        ptr_from_val(n),  # lda
        ipiv.ctypes,
        b.ctypes,
        ptr_from_val(n),  # ldb
        info.ctypes,
    )
    if info[0]:
        raise Exception("something went wrong")
    return b.T

但是生成的函数比其他方法慢得多,这并不是我真正想要的,因为您必须将函数作为参数传递才能使缓存工作:

rng = np.random.default_rng()

# Cross-check every wrapper against the numba reference solver (and against
# the known solution x) on small random systems of size 3 and 4.
for i in range(3, 5):
    N = i
    A, x = rng.random((N, N)), rng.random((N, 1000))
    b = np.matmul(A, x)

    # The wrappers receive copies; A and b themselves are reused below.
    _ctypes = dgesv_ctypes(A.copy(), b.copy())
    _cffi = dgesv_cffi(A.copy(), b.copy())
    _wap = dgesv_wap(Dgesv(), A.copy(), b.copy())
    _numba = dgesv_numba(A, b)

    assert all(
        np.allclose(candidate, _numba)
        for candidate in (_ctypes, _cffi, _wap, x)
    )

print("all good")

%timeit dgesv_ctypes(A.copy(), b.copy())
%timeit dgesv_cffi(A.copy(), b.copy())
%timeit dgesv_wap(Dgesv(), A.copy(), b.copy())
%timeit dgesv_numba(A, b)

输出:

all good
56.5 µs ± 1.62 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
55.8 µs ± 1.36 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
89.6 µs ± 2.57 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
59.7 µs ± 894 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

那么我该如何在保留其他实现的性能的同时做到这一点?

python numba
1个回答
0
投票

好吧,实际上是在文档中,你必须使用

types.ExternalFunction
(虽然文档字符串说它仅供内部使用)并使用 llvmlite 加载库。但要注意,这种实现只能在 njit 装饰的函数内部调用,不能在其外部调用:

from ctypes.util import find_library
from llvmlite import binding
from numba import types, njit

# Load LAPACK into the process through llvmlite so that the "dgesv_" symbol
# named below can be resolved.
binding.load_library_permanently(find_library("lapack"))

ptr_int = types.CPointer(types.int32)
ptr_double = types.CPointer(types.float64)

# dgesv_ is a subroutine: it returns nothing and reports through its pointer
# arguments. The return type is therefore void, matching the ctypes
# (restype = None) and WrapperAddressProtocol (nb.types.void) declarations
# of the same routine.
_dgesv = types.ExternalFunction("dgesv_", types.void(
    ptr_int,  # n
    ptr_int,  # nrhs
    ptr_double,  # a
    ptr_int,  # lda
    ptr_int,  # ipiv
    ptr_double,  # b
    ptr_int,  # ldb
    ptr_int,  # info
))

@njit(cache=True)
def dgesv_external_function(A, b):
    """Solve ``A @ x = b`` by calling ``dgesv_`` via ``types.ExternalFunction``.

    Parameters
    ----------
    A : 2-D float64 array
        Coefficient matrix.
    b : 1-D or 2-D float64 array
        Right-hand side(s); prepared by ``args``.

    Returns
    -------
    The transposed buffer that was handed to ``dgesv_`` as ``b``.

    Raises
    ------
    Exception
        If ``dgesv_`` reports a non-zero ``info``.
    """
    b, n, nrhs, ipiv, _ = args(A, b)
    # dgesv_ reports its status by writing through the `info` pointer.
    # ptr_from_val() copies a scalar into a fresh stack slot, so a status
    # written there would never be visible here. A one-element array gives
    # the routine memory we can read back after the call.
    info = np.zeros((1,), dtype=np.int32)
    _dgesv(
        ptr_from_val(n),
        ptr_from_val(nrhs),
        A.T.copy().ctypes,  # contiguous copy of A.T
        ptr_from_val(n),  # lda
        ipiv.ctypes,
        b.ctypes,
        ptr_from_val(n),  # ldb
        info.ctypes,
    )
    if info[0]:
        raise Exception("something went wrong")
    return b.T

计时结果:

# One 5x5 system with 1000 right-hand sides; every wrapper must agree with
# the numba reference solver and with the known solution x.
A, x = rng.random((5, 5)), rng.random((5, 1000))
b = np.matmul(A, x)

_ctypes = dgesv_ctypes(A.copy(), b.copy())
_cffi = dgesv_cffi(A.copy(), b.copy())
_wap = dgesv_wap(Dgesv(), A.copy(), b.copy())
_numba = dgesv_numba(A, b)
_ext = dgesv_external_function(A.copy(), b.copy())

assert all(
    np.allclose(candidate, _numba)
    for candidate in (_ctypes, _cffi, _wap, x, _ext)
)

%timeit dgesv_ctypes(A.copy(), b.copy())
%timeit dgesv_cffi(A.copy(), b.copy())
%timeit dgesv_wap(Dgesv(), A.copy(), b.copy())
%timeit dgesv_numba(A, b)
%timeit dgesv_external_function(A.copy(), b.copy())

输出:

43.3 µs ± 5.02 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
40.5 µs ± 2.69 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
115 µs ± 34.1 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
53.5 µs ± 5.41 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
17.4 µs ± 544 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
© www.soinside.com 2019 - 2024. All rights reserved.