当我尝试访问类型化的 jitted 函数列表中的元素时,我遇到了 Numba 的问题。我收到的错误消息是“LLVM IR 解析错误”,其中包含行
ret i8* null
。
这是重现该问题的简化代码片段:
import numba
from numba import njit, types
import numpy as np
@njit
def f1(vars, opts):
x, y = vars
opt1 = opts[0]
return (x + y) * opt1
@njit
def f2(vars, opts):
x, y = vars
opt1, opt2 = opts
return (2 * x + y) * opt1 * opt2
# Define the function type
func_type = types.float64(types.float64[:], types.float64[:]).as_type()
# Create a typed list to hold the function pointers
f_list = numba.typed.List.empty_list(func_type)
f_list.append(f1)
f_list.append(f2)
@njit
def dump(l):
for i, f in enumerate(l):
print(i, f)
dump(f_list)
当我运行
dump
函数时,出现以下错误:
LLVM IR parsing error
<string>:457:7: error: value doesn't match function result type 'i32'
ret i8* null
同时在python模式下
for i, f in enumerate(f_list):
print(i,f)
它按预期工作。
我已经将 Numba 更新到最新版本,但问题仍然存在。
Python 3.9.12
llvmlite-0.40.1
numba-0.57.1
任何有关可能导致此问题的原因以及如何解决此问题的见解将不胜感激。
您可以通过打印函数类型的字符串表示来解决此问题(例如,不要将函数直接传递给 python 函数
print()
,而是先将其转换为字符串):
import numba
import numpy as np
from numba import njit, types
@njit
def f1(vars, opts):
x, y = vars
opt1 = opts[0]
return (x + y) * opt1
@njit
def f2(vars, opts):
x, y = vars
opt1, opt2 = opts
return (2 * x + y) * opt1 * opt2
# Define the function type
func_type = types.float64(types.float64[:], types.float64[:]).as_type()
# Create a typed list to hold the function pointers
f_list = numba.typed.List.empty_list(func_type)
f_list.append(f1)
f_list.append(f2)
@njit
def dump(l):
a = np.array([1, 2], dtype=np.float64)
b = np.array([3, 4], dtype=np.float64)
for i, f in enumerate(l):
print(i, f"{f}")
print(f(a, b))
dump(f_list)
打印:
0 <object type:FunctionType[float64(array(float64, 1d, A), array(float64, 1d, A))]>
9.0
1 <object type:FunctionType[float64(array(float64, 1d, A), array(float64, 1d, A))]>
48.0