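I am learning numba and am working through the tutorial in the official documentation (https://numba.readthedocs.io/en/stable/user/performance-tips.html#performance-tips).
Setting fastmath=True should sometimes improve performance, as the tutorial shows, but whenever I use fastmath=True it is always slower than with fastmath=False. For example, the parallel=True section of the tutorial shows: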
Function name | Execution time |
---|---|
do_sum_parallel | 9.81 ms |
do_sum_parallel_fast | 5.37 ms |
On my machine, however, I get:
Function name | Execution time |
---|---|
do_sum_parallel | 4.67 ms |
do_sum_parallel_fast | 4.96 ms |
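I observe similar results for the other benchmarks in the tutorial: whenever I use fastmath=True, the code gets slower. Why is that?
Here are the two functions from the tutorial that I used to get the numbers above: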
from numba import njit, prange
import numpy as np

@njit(parallel=True)
def do_sum_parallel(A):
    # each thread can accumulate its own partial sum, and then a cross
    # thread reduction is performed to obtain the result to return
    n = len(A)
    acc = 0.
    for i in prange(n):
        acc += np.sqrt(A[i])
    return acc

@njit(parallel=True, fastmath=True)
def do_sum_parallel_fast(A):
    n = len(A)
    acc = 0.
    for i in prange(n):
        acc += np.sqrt(A[i])
    return acc
The fastmath optimizations in Numba mainly benefit floating-point arithmetic, so I suspect you may be testing with integer data:
from numba import njit, prange
import numpy as np
import timeit

@njit(parallel=True)
def do_sum_parallel(A):
    # each thread can accumulate its own partial sum, and then a cross
    # thread reduction is performed to obtain the result to return
    n = len(A)
    acc = 0.
    for i in prange(n):
        acc += np.sqrt(A[i])
    return acc

@njit(parallel=True, fastmath=True)
def do_sum_parallel_fast(A):
    n = len(A)
    acc = 0.
    for i in prange(n):
        acc += np.sqrt(A[i])
    return acc

# Note: timeit.timeit returns the total time in seconds across all `number`
# calls, and the first call of each function includes JIT compilation.

# Floating-point input
float_data = np.random.rand(10000)
float_time_do_sum_parallel = timeit.timeit(lambda: do_sum_parallel(float_data), number=100)
float_time_do_sum_parallel_fast = timeit.timeit(lambda: do_sum_parallel_fast(float_data), number=100)
print(f'{float_time_do_sum_parallel = :.2f} ms')
print(f'{float_time_do_sum_parallel_fast = :.2f} ms')

# Integer input
int_data = np.random.randint(0, 100, size=10000)
int_time_do_sum_parallel = timeit.timeit(lambda: do_sum_parallel(int_data), number=100)
int_time_do_sum_parallel_fast = timeit.timeit(lambda: do_sum_parallel_fast(int_data), number=100)
print(f'{int_time_do_sum_parallel = :.2f} ms')
print(f'{int_time_do_sum_parallel_fast = :.2f} ms')
Output on repl.it:
float_time_do_sum_parallel = 13.79 ms
float_time_do_sum_parallel_fast = 10.49 ms
int_time_do_sum_parallel = 9.97 ms
int_time_do_sum_parallel_fast = 10.41 ms
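One caveat when comparing timings like these: timeit.timeit reports the total number of seconds for all `number` calls, and the first call to a lazily compiled Numba function also pays for compilation, so small differences can easily be noise. Below is a minimal sketch of a steadier measurement. The helper name time_per_call_ms is hypothetical (it is not part of Numba or the tutorial), and a pure-Python function stands in for the jitted ones so the sketch runs without Numba installed.

```python
import timeit


def time_per_call_ms(func, *args, number=100, repeat=5):
    """Return the best observed time per call, in milliseconds.

    Hypothetical helper for illustration; not part of Numba or timeit.
    """
    # Warm-up call: for a Numba-jitted function this triggers compilation,
    # which keeps compile time out of the measurement below.
    func(*args)
    # timeit.repeat returns a list of total times in seconds, one per round.
    totals = timeit.repeat(lambda: func(*args), number=number, repeat=repeat)
    # Take the fastest round, then convert total seconds to ms per call.
    return min(totals) / number * 1000.0


def py_sqrt_sum(values):
    # Pure-Python stand-in for do_sum_parallel, used only for this demo.
    return sum(v ** 0.5 for v in values)


data = [float(i) for i in range(10_000)]
elapsed = time_per_call_ms(py_sqrt_sum, data, number=20)
print(f"py_sqrt_sum: {elapsed:.3f} ms per call")
```

With Numba available, the same helper could be called as time_per_call_ms(do_sum_parallel, float_data) and time_per_call_ms(do_sum_parallel_fast, float_data). Taking the minimum over several rounds is a design choice that reduces the influence of scheduler noise, which tends to matter for short parallel loops.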