I have the following simple function, which sums values in the second row of an array:
@njit('float64(float64[:, ::1], uint64, uint64)', fastmath=True)
def fast_sum(array_2d, start, end):
    s = 0.0
    for i in range(start, end):
        s += array_2d[1][i]
    return s
I time it:
import numpy as np
from numba import njit
A = np.random.rand(2, 500)
%timeit fast_sum(A, 100, 300)
This gives me:
%timeit fast_sum(A, 100, 300)
304 ns ± 17.4 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
In fact, I want to include the value at index end as well, so I changed the function to:
@njit('float64(float64[:, ::1], uint64, uint64)', fastmath=True)
def fast_sum_v2(array_2d, start, end):
    s = 0.0
    end = end + 1
    for i in range(start, end):
        s += array_2d[1][i]
    return s
Now the code runs 40% slower!
%timeit fast_sum_v2(A, 100, 299)
423 ns ± 6.93 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
My guess is that adding 1 to end changed the type of end. But is that really correct?
First, we must not forget:
In [1]: import numpy as np
In [2]: i = np.uint64(2**64-1)
In [3]: type(i)
Out[3]: numpy.uint64
In [4]: type(i+1)
Out[4]: numpy.float64
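The same promotion rule can be checked directly (my own illustration, not part of the original session): NumPy has no integer type wide enough to hold both the uint64 range and a signed integer, so the pair is promoted to float64.

```python
import numpy as np

# uint64 and int64 have no common integer type covering both
# ranges, so NumPy promotes the combination to float64.
print(np.promote_types(np.uint64, np.int64))
```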
My guess is that, because of this promotion, fastmath ends up doing some downcasting, spending extra time instead of speeding things up.
In [1]: import numpy as np
   ...: from numba import njit
   ...: A = np.random.rand(2, 500)
In [2]: @njit('float64(float64[:, ::1], uint64, uint64)', fastmath=True)
   ...: def fast_sum(array_2d, start, end):
   ...:     s = 0.0
   ...:     for i in range(start, end):
   ...:         s += array_2d[1][i]
   ...:     return s
   ...:
In [3]: %timeit fast_sum(A, 100, 300)
171 ns ± 1.12 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
In [4]: @njit('float64(float64[:, ::1], uint64, uint64)', fastmath=True)
   ...: def fast_sum_v2(array_2d, start, end):
   ...:     end = end + 1
   ...:     s = 0.0
   ...:     for i in range(start, end):
   ...:         s += array_2d[1][i]
   ...:     return s
   ...:
In [5]: %timeit fast_sum_v2(A, 100, 299)
228 ns ± 1.85 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
In [6]: @njit('float64(float64[:, ::1], uint32, uint32)', fastmath=True)
   ...: def fast_sum_v2_with32(array_2d, start, end):
   ...:     end = np.uint32(end + 1)
   ...:     s = 0.0
   ...:     for i in range(start, end):
   ...:         s += array_2d[1][i]
   ...:     return s
   ...:
In [7]: %timeit fast_sum_v2_with32(A, 100, 299)
172 ns ± 0.687 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
In [8]: ## As you can see, this matches the time of fast_sum
In [9]: ## To fix the v2 function, we hint the type of the end variable for fastmath
In [10]: @njit('float64(float64[:, ::1], uint64, uint64)', fastmath=True)
    ...: def fast_sum_v2_fixed(array_2d, start, end):
    ...:     end = np.uint64(end + 1)
    ...:     s = 0.0
    ...:     for i in range(start, end):
    ...:         s += array_2d[1][i]
    ...:     return s
    ...:
In [11]: %timeit fast_sum_v2_fixed(A, 100, 299)
173 ns ± 1.98 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
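As a sanity check that the inclusive-end versions compute the intended value, here is a pure-NumPy sketch (no numba required; the array and bounds are just my own example values matching the ones used above):

```python
import numpy as np

# Reference computation: an inclusive-end loop sum over start..end
# should match a slice sum that goes up to end + 1.
A = np.random.rand(2, 500)
start, end = 100, 299
loop_sum = sum(A[1][i] for i in range(start, end + 1))
slice_sum = A[1, start:end + 1].sum()
assert np.isclose(loop_sum, slice_sum)
```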
Some of this investigation and its pitfalls were taken from here: Numba fast math does not improve speed
BTW: if you know the size of the arguments, it can be even faster:
In [12]: @njit('float64(float64[:, ::1], uint8, uint8)', fastmath=False)
    ...: def fast_sum_v3(array_2d, start, end):
    ...:     end = np.uint8(end + 1)
    ...:     s = 0.0
    ...:     for i in range(start, end):
    ...:         s += array_2d[1][i]
    ...:     return s
    ...:
In [13]: %timeit fast_sum_v3(A, 100, 299)
135 ns ± 0.737 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
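One caveat with the uint8 variant (my own note, not covered by the timings above): np.uint8 arithmetic wraps around at 256, so an inclusive end of 255 would make end + 1 wrap to 0 and the loop body would silently never run.

```python
import numpy as np

# uint8 arithmetic wraps modulo 256: 255 + 1 becomes 0, so
# range(start, end) would be empty and the sum silently wrong.
end = np.uint8(255)
print(np.uint8(end + 1))  # 0
```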