如何将包含函数的数组作为参数传递给njit?

问题描述 投票:0回答:1

我想将一个包含 numba 编译函数列表的数组作为参数传递到

njit
方法中。在尝试这样做时,我遇到了以下错误:
non-precise type array(pyobject, 1d, C)

我可以将 numba 编译函数作为参数传递到

njit
方法中,但如果 numba 编译函数存储在数组中,我就无法这样做。

有没有办法将函数数组作为参数传递给 njit 函数?或者类似类型转换数组的东西,这样 numba 就知道数组的内容

np.array([function1, function2], dtype=function)

这是我的代码示例

import numpy as np
from numba import njit

@njit()
def function1(x, y):
    return x > y

@njit()
def function2(x, y):
    return x < y

@njit()
def main(inputArray):
    print(inputArray[0](1,2))
    print(inputArray[1](1,2))

functionArray = np.array([function1, function2])
main(functionArray)

错误信息

numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
non-precise type array(pyobject, 1d, C)
python numba
1个回答
0
投票

您可以尝试将可变数量的参数传递给函数(但您仍然会收到警告消息):

import numba

T = numba.float64(
    numba.float64,
    numba.float64,
)


@numba.njit(T)
def function1(x, y):
    return x > y


@numba.njit(T)
def function2(x, y):
    return x < y


@numba.njit
def main(*inputArray):
    print(inputArray[0](1.0, 2.0))
    print(inputArray[1](1.0, 2.0))


l = [function1, function2]
main(*l)

打印:

0.0
1.0
© www.soinside.com 2019 - 2024. All rights reserved.