Numba - 是否有可能在没有 nan 的情况下更快地进行绝对求和?

问题描述 投票:0回答:1

我有一个来自 numba 官方文档的稍微修改的示例,如下所示:

from numba import njit
import numpy as np

@njit
def do_sum(A, lb, ub):
    n = len(A)
    acc = 0.0
    for i in range(n):
        a = 0.0 if np.isnan(A[i]) else A[i]
        acc += abs(max(min(a, ub[i]), lb[i]))
    return acc

基本上,任务是对一个 numpy 数组进行绝对有界求和,它可能有

np.nan
A
lb
ub
都是具有相同长度的
1d
数组。

由于

A
可能包含
np.nan
,看来我不能使用
@njit(fastmath=True)
。然而,当
A
不包含
np.nan
时,我的基准测试结果表明,使用
@njit(fastmath=True)
明显快于
@njit
.

我的问题是,两者之间是否有一些甜蜜点,这样我就可以使代码与

fastmath
一起工作,并在上面的
do_sum
实现上获得加速?

或者实际上,任何可以使

do_sum
更快的方法/方法都将非常受欢迎。

为简单起见,我们可以假设

lb
ub
不包含
nan
值。但如果解决方案也能处理它们,那就更好了。

(顺便说一下,此代码示例与官方 numba 文档中说明

parallel=True
的示例非常相似,但是当我尝试在上面的
parallel=True
中添加
do_sum
时,我的速度明显变慢了。不知道为什么情况。)

感谢您的帮助。

python numpy numba
1个回答
0
投票

您可以使用

np.nan_to_num

do_sum(np.nan_to_num(A), lb, ub)

或者使用向量化操作:

acc = np.sum(np.abs(np.maximum(np.minimum(np.nan_to_num(A), ub), lb)))
© www.soinside.com 2019 - 2024. All rights reserved.