What is causing this numba code to slow down?


I have the following simple function, which sums values from the second row of an array:

@njit('float64(float64[:, ::1], uint64, uint64)', fastmath=True)
def fast_sum(array_2d, start, end):
    s = 0.0
    for i in range(start, end):
        s += array_2d[1][i]
    return s

I time it:

import numpy as np
from numba import njit
A = np.random.rand(2, 500)
%timeit fast_sum(A, 100, 300)

This gives me:

%timeit fast_sum(A, 100, 300)
304 ns ± 17.4 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

In fact, I want the value at index `end` to be included as well, so I changed the function to:

@njit('float64(float64[:, ::1], uint64, uint64)', fastmath=True)
def fast_sum_v2(array_2d, start, end):
    s = 0.0
    end = end + 1
    for i in range(start, end):
        s += array_2d[1][i]
    return s

Now the code runs about 40% slower!

%timeit fast_sum_v2(A, 100, 299)
423 ns ± 6.93 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

My guess is that adding 1 to `end` changes the type of `end`. But is that really what is going on?
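As a sanity check on the intended semantics (a plain-NumPy sketch, no Numba involved; `ref_sum` is just an illustrative name), the inclusive-end sum can be compared against slicing:

```python
import numpy as np

A = np.random.rand(2, 500)

# Inclusive-end reference: covers indices start..end, so the slice stops at end + 1.
def ref_sum(array_2d, start, end):
    return array_2d[1, start:end + 1].sum()

# ref_sum(A, 100, 299) covers the same 200 elements as fast_sum(A, 100, 300).
```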

python numba
1 Answer

First, we must not forget how NumPy promotes `uint64` when it is mixed with a Python `int`:

In [1]: import numpy as np

In [2]: i = np.uint64(2**64-1)

In [3]: type(i)
Out[3]: numpy.uint64

In [4]: type(i+1)
Out[4]: numpy.float64
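The same pitfall, and the way around it, in a standalone snippet. Note that the `float64` result above is the legacy NumPy 1.x behaviour; under NumPy 2.x (NEP 50) a Python `int` no longer forces the promotion. Keeping both operands `uint64` avoids the issue on either version:

```python
import numpy as np

i = np.uint64(2**64 - 1)

# Mixing uint64 with a Python int: under NumPy 1.x this promotes to float64
# (no common integer type holds both int64 and uint64);
# NumPy 2.x (NEP 50) keeps it uint64 instead.
mixed = i + 1

# Keeping both operands uint64 stays uint64 on every NumPy version:
same = np.uint64(5) + np.uint64(1)
```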

My guess is that, once `end + 1` has been promoted to `float64`, `fastmath` has to down-cast that floating-point bound back to an integer for the loop, and this extra conversion is where the time goes.

In [1]: import numpy as np
   ...: from numba import njit
   ...: A = np.random.rand(2, 500)

In [2]: @njit('float64(float64[:, ::1], uint64, uint64)', fastmath=True)
   ...: def fast_sum(array_2d, start, end):
   ...:     s = 0.0
   ...:     for i in range(start, end):
   ...:         s += array_2d[1][i]
   ...:     return s
   ...: 

In [3]: %timeit fast_sum(A, 100, 300)
171 ns ± 1.12 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

In [4]: @njit('float64(float64[:, ::1], uint64, uint64)', fastmath=True)
   ...: def fast_sum_v2(array_2d, start, end):
   ...:     end = end + 1
   ...:     s = 0.0
   ...:     for i in range(start, end):
   ...:         s += array_2d[1][i]
   ...:     return s
   ...: 

In [5]: %timeit fast_sum_v2(A, 100, 299)
228 ns ± 1.85 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

In [6]: @njit('float64(float64[:, ::1], uint32, uint32)', fastmath=True)
   ...: def fast_sum_v2_with32(array_2d, start, end):
   ...:     end = np.uint32(end + 1)
   ...:     s = 0.0
   ...:     for i in range(start, end):
   ...:         s += array_2d[1][i]
   ...:     return s
   ...: 

In [7]: %timeit fast_sum_v2_with32(A, 100, 299)
172 ns ± 0.687 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

In [8]: ## As you can see, this matches the fast_sum timing

In [9]: ## To fix v2 itself, give fastmath an explicit type for the `end` variable

In [10]: @njit('float64(float64[:, ::1], uint64, uint64)', fastmath=True)
    ...: def fast_sum_v2_fixed(array_2d, start, end):
    ...:     end = np.uint64(end + 1)
    ...:     s = 0.0
    ...:     for i in range(start, end):
    ...:         s += array_2d[1][i]
    ...:     return s
    ...: 

In [11]: %timeit fast_sum_v2_fixed(A, 100, 299)
173 ns ± 1.98 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
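The key line is the explicit cast: wrapping the increment in `np.uint64(...)` guarantees the loop bound is a machine integer again, whatever the intermediate promotion did. Outside Numba, the same pattern looks like this:

```python
import numpy as np

end = np.uint64(299)

# Without the cast, `end + 1` may be promoted to float64 (NumPy 1.x rules).
# Casting back immediately keeps the bound integral on any NumPy version.
bumped = np.uint64(end + 1)
```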

Some of this investigation and its pitfalls are taken from here: Numba fastmath does not improve speed

By the way: if you know the range of the arguments, it can be even faster:

In [12]: @njit('float64(float64[:, ::1], uint8, uint8)', fastmath=False)
    ...: def fast_sum_v3(array_2d, start, end):
    ...:     end = np.uint8(end + 1)
    ...:     s = 0.0
    ...:     for i in range(start, end):
    ...:         s += array_2d[1][i]
    ...:     return s
    ...: 

In [13]: %timeit fast_sum_v3(A, 100, 299)
135 ns ± 0.737 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
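One caveat with the `uint8` signature before adopting it: `uint8` only represents 0..255, so both `start` and `end + 1` must fit in that range, otherwise the arguments overflow:

```python
import numpy as np

# The valid range for uint8 arguments:
lo, hi = np.iinfo(np.uint8).min, np.iinfo(np.uint8).max

# Since the function computes end + 1 as a uint8, the largest safe `end` is 254.
```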