github上的用户使用numba no python模式报告了以下代码的错误:
from numba import njit
import numpy as np
@njit
def foo():
a = np.ones(1, np.bool_)
if a > 0:
print('truebr')
else:
print('falsebr')
foo()
他被告知表达式a > 0
不是谓词,而是有条件的。为了解决这个问题,他将“用条件包裹条件以创建谓词”。
这是否意味着(a > 0) == True
将修复numba或其他内容中出现的错误。
https://github.com/numba/numba/pull/3901/commits/598cdd1707fdeb11b8f1d70aef2d3e36ef37bd34。这是numba中这些类型的错误的解决方法吗?
In [412]: def foo():
...: a = np.ones(1, np.bool_)
...: if a > 0:
...: print('truebr')
...: else:
...: print('falsebr')
...:
In [413]: foo()
truebr
但是如果a
是具有更多值的数组:
In [414]: def foo():
...: a = np.ones(2, np.bool_)
...: if a > 0:
...: print('truebr')
...: else:
...: print('falsebr')
...:
In [415]: foo()
...
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
如果尝试在njit
中使用您的功能,则会得到很长的追溯;太长,无法显示或分析,但是从本质上讲,它无法在njit
模式下完成。鉴于上述价值错误,我并不感到惊讶。 njit
不允许使用“仅一个”真值数组。
通常,使用numba
时应进行迭代。这是它的主要目的-运行numpy/python
问题,否则这些问题迭代起来将非常昂贵。如果我更改功能以测试
a
的每个元素,它将起作用:
In [421]: @numba.njit ...: def foo(): ...: a = np.array([True]) ...: for i in a: ...: if i > 0: ...: print('truebr') ...: else: ...: print('falsebr') ...: In [422]: foo() truebr