我有一个来自 numba 官方文档的稍微修改的示例,如下所示:
from numba import njit
import numpy as np
@njit
def do_sum(A, lb, ub):
n = len(A)
acc = 0.0
for i in range(n):
a = 0.0 if np.isnan(A[i]) else A[i]
acc += abs(max(min(a, ub[i]), lb[i]))
return acc
基本上,任务是对一个 numpy 数组进行绝对有界求和,它可能有
np.nan
。 A
、lb
和ub
都是具有相同长度的1d
数组。
由于
A
可能包含np.nan
,看来我不能使用@njit(fastmath=True)
。然而,当 A
不包含 np.nan
时,我的基准测试结果表明,使用 @njit(fastmath=True)
明显快于 @njit
.
我的问题是,两者之间是否有一些甜蜜点,这样我就可以使代码与
fastmath
一起工作,并在上面的 do_sum
实现上获得加速?
或者实际上,任何可以使
do_sum
更快的方法/方法都将非常受欢迎。
为简单起见,我们可以假设
lb
和 ub
不包含 nan
值。但如果解决方案也能处理它们,那就更好了。
(顺便说一下,此代码示例与官方 numba 文档中说明
parallel=True
的示例非常相似,但是当我尝试在上面的 parallel=True
中添加 do_sum
时,我的速度明显变慢了。不知道为什么情况。)
感谢您的帮助。
您可以使用
np.nan_to_num
:
do_sum(np.nan_to_num(A), lb, ub)
或者使用向量化操作:
acc = np.sum(np.abs(np.maximum(np.minimum(np.nan_to_num(A), ub), lb)))