为什么依赖numba jitt'ed函数的排序很重要?

问题描述 投票:3回答:1

在python中,您可以定义多个以任何顺序相互调用的函数,并且在运行时将调用函数。一旦存在,这些函数在脚本中定义的顺序无关紧要。例如,以下内容有效且有效

import numpy as np

def func1(arr):
    out = np.empty_like(arr)
    for i in range(arr.shape[0]):
        out[i] = func2(arr[i])  # calling func2 here which is defined below
    return out

def func2(a):
    out = a + 1
    return out

func1可以调用func2,尽管func2是在func1之后定义的。

但是,如果我用numba装饰这些函数,我会收到错误

import numpy as np
import numba as nb


@nb.jit("f8[:](f8[:])", nopython=True)
def func1(arr):
    out = np.empty_like(arr)
    for i in range(arr.shape[0]):
        out[i] = func2(arr[i])
    return out

@nb.jit("f8(f8)", nopython=True)
def func2(a):
    out = a + 1
    return out

>>> TypingError: Failed in nopython mode pipeline (step: nopython frontend)
    Untyped global name 'func2': cannot determine Numba type of <class 
    'numba.ir.UndefinedType'>

所以当使用JIT编译func2时,numba不知道func1是什么。简单地切换这些功能的顺序是有效的,所以func2func1之前出现

@nb.jit("f8(f8)", nopython=True)
def func2(a):
    out = a + 1
    return out

@nb.jit("f8[:](f8[:])", nopython=True)
def func1(arr):
    out = np.empty_like(arr)
    for i in range(arr.shape[0]):
        out[i] = func2(arr[i])
    return out

为什么是这样?我感觉纯python模式有效,因为python是动态类型而不是编译的,而numba,使用JIT,按定义编译函数(因此可能需要完全了解每个函数中发生的所有事情?)。但我不明白为什么numba不会在所有函数的范围内搜索,如果遇到它没有看到的函数。

python jit numba
1个回答
3
投票

简短版本 - 删除"f8[:](f8[:])"

你的直觉是正确的。在调用时查找Python函数,这就是为什么它们可以不按顺序定义的原因。使用dis(反汇编)模块查看python字节码可以清楚地看到 - 每次调用函数b时,名称a都会被视为全局。

def a():
    return b()

def b():
    return 2

import dis
dis.dis(a)
#  2           0 LOAD_GLOBAL              0 (b)
#              2 CALL_FUNCTION            0
#              4 RETURN_VALUE

在nopython模式下,numba需要静态地知道被调用的每个函数的地址 - 这使得代码快速(不再进行运行时查找),并且还打开了其他优化的大门,如内联。

也就是说,numba可以处理这种情况。通过指定类型签名("f8[:](f8[:])"),您可以提前编译。省略它,一个数字将推迟到第一个函数调用它,它将工作。

© www.soinside.com 2019 - 2024. All rights reserved.