将函数对象用作numba njit函数的参数

问题描述 投票:3回答:1

我想制作一个通用函数,该函数以一个函数对象作为参数。

最简单的情况之一:

import numpy as np
import numba as nb
@nb.njit()
def test(a, f=np.median):
    return f(a)

test(np.arange(10), np.mean)

给出错误,尽管test(np.arange(10))正常工作。

错误:

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
non-precise type pyobject
[1] During: typing of argument at <ipython-input-54-52cead0f097d> (5)

File "<ipython-input-54-52cead0f097d>", line 5:
def test(a, f=np.median):
    return f(a)
    ^

This error may have been caused by the following argument(s):
- argument 1: cannot determine Numba type of <class 'function'>

This is not usually a problem with Numba itself but instead often caused by
the use of unsupported features or an issue in resolving types.

这是不允许的,还是我缺少什么?

python numba
1个回答
2
投票

使用函数作为参数在使用numba时比较棘手,而且非常昂贵。在Frequently Asked Questions: "1.18.1.1. Can I pass a function as an argument to a jitted function?"中提到:

1.18.1.1。我可以将函数作为实参传递给参数吗?

从Numba 0.39开始,只要函数参数也已JIT编译,就可以:

@jit(nopython=True)
def f(g, x):
    return g(x) + g(-x)
result = f(jitted_g_function, 1)

但是,使用参数作为函数进行分派会产生额外的开销。如果这对您的应用程序很重要,您还可以使用工厂函数来捕获闭包中的函数参数:

def make_f(g):
    # Note: a new f() is created each time make_f() is called!
    @jit(nopython=True)
    def f(x):
        return g(x) + g(-x)
    return f
f = make_f(jitted_g_function)
result = f(1)

提高Numba中功能的调度性能是一项持续的任务。

这意味着您可以选择使用函数工厂:

import numpy as np
import numba as nb

def test(a, func=np.median):
    @nb.njit
    def _test(a):
        return func(a)
    return _test(a)

>>> test(np.arange(10))
4.5
>>> test(np.arange(10), np.min)
0
>>> test(np.arange(10), np.mean)
4.5

或者在将函数参数作为参数传递之前将其包装为jitted-function:

import numpy as np
import numba as nb

@nb.njit()
def test(a, f=np.median):
    return f(a)

@nb.njit
def wrapped_mean(a):
    return np.mean(a)

@nb.njit
def wrapped_median(a):
    return np.median(a)

>>> test(np.arange(10))
4.5
>>> test(np.arange(10), wrapped_mean)
4.5
>>> test(np.arange(10), wrapped_median)
4.5

两种选择都有很多样板,并且没有人们希望的那样简单明了。

函数工厂方法还反复创建和编译函数,因此,如果经常使用与参数相同的函数来调用它,则可以使用字典来存储已知的编译函数:

import numpy as np
import numba as nb

_precompiled_funcs = {}

def test(a, func=np.median):
    if func not in _precompiled_funcs:
        @nb.njit
        def _test(arr):
            return func(arr)
        result = _test(a)
        _precompiled_funcs[func] = _test
        return result
    return _precompiled_funcs[func](a)

[另一种方法(使用包装的和抖动的函数)也有一些开销,但是,只要传入的数组具有大量元素(> 1000),它就不会真正引起注意。

如果您显示的功能确实是您要使用的功能,那么我根本不会使用numba。使用Python + NumPy的简单任务无法发挥numba的优势(索引和迭代数组或繁重的运算),它应该更快(或更快)并且更容易调试和理解:

import numba as nb

def test(a, f=np.median):
    return f(a)
© www.soinside.com 2019 - 2024. All rights reserved.