Python numba fastmath 优化

问题描述 投票:0回答:1

我正在学习

numba
,现在使用官方文档中的教程:https://numba.readthedocs.io/en/stable/user/performance-tips.html#performance-tips

设置

fastmath=True
有时应该会提高性能,如教程中所示,但每当我使用
fastmath=True
时,它总是比我设置
fastmath=False
时慢。例如,在教程的
parallel=True
部分,它显示:

函数名称 执行时间
do_sum_parallel
9.81 毫秒
do_sum_parallel_fast
5.37 毫秒

在我的机器上却是:

函数名称 执行时间
do_sum_parallel
4.67 毫秒
do_sum_parallel_fast
4.96 毫秒

我在教程中的其他基准测试中观察到类似的结果,每当我使用

fastmath=True
时,速度都会变慢。这是为什么?

以下是教程中我用来获取上述数字的两个函数:

@njit(parallel=True)
def do_sum_parallel(A):
    # each thread can accumulate its own partial sum, and then a cross
    # thread reduction is performed to obtain the result to return
    n = len(A)
    acc = 0.
    for i in prange(n):
        acc += np.sqrt(A[i])
    return acc

@njit(parallel=True, fastmath=True)
def do_sum_parallel_fast(A):
    n = len(A)
    acc = 0.
    for i in prange(n):
        acc += np.sqrt(A[i])
    return acc
python numba
1个回答
0
投票

fastmath
Numba 中的优化主要有利于浮点运算,我想也许您正在使用整数来测试:

from numba import njit, prange
import numpy as np
import timeit


@njit(parallel=True)
def do_sum_parallel(A):
  # each thread can accumulate its own partial sum, and then a cross
  # thread reduction is performed to obtain the result to return
  n = len(A)
  acc = 0.
  for i in prange(n):
    acc += np.sqrt(A[i])
  return acc


@njit(parallel=True, fastmath=True)
def do_sum_parallel_fast(A):
  n = len(A)
  acc = 0.
  for i in prange(n):
    acc += np.sqrt(A[i])
  return acc


float_data = np.random.rand(10000)
float_time_do_sum_parallel = timeit.timeit(lambda: do_sum_parallel(float_data), number=100)
float_time_do_sum_parallel_fast = timeit.timeit(lambda: do_sum_parallel_fast(float_data), number=100)
print(f'{float_time_do_sum_parallel = :.2f} ms')
print(f'{float_time_do_sum_parallel_fast = :.2f} ms')

int_data = np.random.randint(0, 100, size=10000)
int_time_do_sum_parallel = timeit.timeit(lambda: do_sum_parallel(int_data), number=100)
int_time_do_sum_parallel_fast = timeit.timeit(lambda: do_sum_parallel_fast(int_data), number=100)
print(f'{int_time_do_sum_parallel = :.2f} ms')
print(f'{int_time_do_sum_parallel_fast = :.2f} ms')

使用repl.it输出:

float_time_do_sum_parallel = 13.79 ms
float_time_do_sum_parallel_fast = 10.49 ms
int_time_do_sum_parallel = 9.97 ms
int_time_do_sum_parallel_fast = 10.41 ms
© www.soinside.com 2019 - 2024. All rights reserved.