Numba:使用工厂函数与 `cache=True`

问题描述 投票:0回答:1

寻找使用(并提高执行速度)具有一个或多个其他 jitted 函数作为参数的 jitted (numba) 函数,我可以在 numbas 的常见问题解答 中看到以下内容:

dispatching with arguments that are functions has extra overhead. If this matters for your application, you can also use a factory function to capture the function argument in a closure:
def make_f(g):
    # Note: a new f() is created each time make_f() is called!
    @jit(nopython=True)
    def f(x):
        return g(x) + g(-x)
    return f

f = make_f(jitted_g_function)
result = f(1)

拜托,我的问题是:

  • numba 内部会发生什么?
  • 使用
    @njit(cache=True)
    有什么区别?
  • 最终,是否建议同时使用工厂功能和
    cache=True
def make_f(g):
    @njit(cache=True)
    def f(x):
        return g(x) + g(-x)
    return f

为了更深入地了解我的实际用例,它的复杂性略有增加。关闭看起来像这样:

from numba import njit, literal_unroll
from numpy import zeros

def make_f(tup):
    @njit
    def f(x):
        # 'x' is a 2d array
        res = zeros((3,x.shape[1]), dtype=x.dtype)
        slice_start = 0
        third_of_x = int(len(x)/3)
        slice_ends = (third_of_x, third_of_x * 2, len(x))
        for i in range(3):
            slice_end = slice_ends[i]
            for item in literal_unroll(tup):
                cols, func = item
                res[i, cols] = func(x[slice_start:slice_end,cols])
            slice_start = slice_end
        return res
    return f

tup
是列索引和可调用元组的元组。例如:

tup=(
     (np.array([0,2], dtype="int64"), np.sum),
     (np.array([1,3], dtype="int64"), np.max)
    )

我知道在示例中,按行切入 3

x
没有多大意义。 但在我的实际用例中,每个块都是“特定类型”。该块的计算取决于前一个块的类型和当前块的类型。 所以我需要逐行chunk操作,一个接一个。

如果有必要查看“真实”代码,我说的是

jcsagg
项目 中的
oups
函数
(目前不在工厂函数中)。
jcsagg()
然后由非 jitted
cumsegagg
函数
调用。

谢谢你的建议!

python caching numba
1个回答
0
投票

numba 内部会发生什么?

make_f
创建一个函数
f
然后提供给 Numba 以构建一个对象来生成编译函数(但由于惰性编译,该函数不会直接编译)。然后,
make_f
返回 Numba 对象,
f(1)
实际编译基于对象
f
的函数。该函数根据其参数和编译标志进行编译。这里有一个整数参数,因此 Numba 将生成一个带有签名
ReturnType(int_)
的函数,其中
ReturnType
jitted_g_function
之一。编译后,目标函数将具有 one 签名,使用其他参数调用
make_f
不会影响 Numba 对象或目标函数。这可以在重复调用
f
时提高性能,因为 Numba 包装器需要检查输入类型并找到与提供的参数匹配的签名。 拥有更多签名通常会增加匹配时间。我希望这个开销相对较小(除了非常快速的函数调用)。请注意,使用不同的输入类型调用
f
将增加目标对象
f
.

的签名数量

如果您使用 1 个主 Numba 对象进行多种调用(使用通常的方法),那么签名的数量可能会增加,并且 如果函数是为许多不同的输入类型编译的,包装器的开销可能是一个问题

AFAIK,

g
对象被包装在 Numba 对象中,因为
f
是一个闭包,来自父函数(或全局函数)的对象被假定为常量,因此编译器可以生成更快的代码(专门用于目标对象)。使用调度技巧使用户能够使用 N 个不同的
f
函数生成 N 个不同的
g
函数,同时快速调用
g
(由于专业化,这也是昂贵的,因为
f
被独立重新编译 N 次) .

使用

@njit(cache=True)
有什么区别?

cache=True
用于存储 JIT 编译结果,以便稍后再次运行 Python 脚本时不会重新编译,即 解释器的多次执行之间。上面的方法没有提供这样的功能。

最终,是否建议同时使用工厂函数和 cache=True?

没有。至少,并非没有特定需求。调度技巧会显着增加编译时间,

cache=True
也有开销。对于大型应用程序,这种开销可能会变得非常巨大。如果编译时间是一个问题,那么使用 AOT 编译器可能是更好的解决方案(尽管到目前为止它有一些很强的限制,比如 AFAIK 不支持快速数学或并行代码)。在这种情况下,Cython 也可以作为替代方案。当您以迭代方式构建科学应用程序时,
cache=True
通常是一个好主意(您通常不希望每次测试模拟时都重新编译所有函数)。请注意,AOT 编译器不如 JIT 编译器灵活,尤其是在涉及动态操作时(所有内容都需要在应用程序实际运行之前编译)。

© www.soinside.com 2019 - 2024. All rights reserved.