使用列表时为什么用numba计算总和?

问题描述 投票:1回答:1

这是我的代码:

@numba.jit( )
def dis4(x1,x2):
    s=0.0
    for i in range(len(x1)):
        s+=(x1[i]-x2[i])**2
    return math.sqrt(s)
x1=[random.random() for _ in range(m)]
x2=[random.random() for _ in range(m)]
%timeit dis4(x1,x2)

每循环3.32 ms±37.8μs(平均值±标准偏差,7次运行,每次100次循环)

相比之下,没有jit它会更快。

每回路137μs±1.62μs(平均值±标准偏差,7次运行,每次10000次循环)

python list performance jit numba
1个回答
1
投票

它的速度较慢,因为numba(默默地)复制了列表。

要了解为什么会发生这种情况,您需要知道numba具有对象模式和nopython模式。在对象模式下,它可以在Python数据结构上运行,但是它不会比普通的Python函数快得多,甚至更慢(至少在一般情况下,有非常罕见的例外)。在nopython模式下,numba无法在像list这样的Python数据结构上运行,因此为了使lists工作,它必须使用非Python列表。要从Python列表创建这样的非Python列表(它被称为反射列表),它必须复制和转换列表内容。

这种复制和转换使你的情况变得更慢。

这也是为什么人们通常应该避免使用非数组参数或使用numba函数返回的原因。数组的内容不需要转换,至少如果numba支持数组的dtype,那么这些是“安全的”。

如果这些数据结构(列表,元组,集合)被限制在numba中它们就可以了 - 但是当它们跨越numba⭤Python边界时,必须复制它们(几乎)总是使所有性能增益无效。


只是为了展示函数如何使用数组执行:

import math
import random
import numba as nb
import numpy as np

def dis4_plain(x1,x2):
    s=0.0
    for i in range(len(x1)):
        s+=(x1[i]-x2[i])**2
    return math.sqrt(s)

@nb.jit
def dis4(x1,x2):
    s=0.0
    for i in range(len(x1)):
        s+=(x1[i]-x2[i])**2
    return math.sqrt(s)

m = 10_000
x1 = [random.random() for _ in range(m)]
x2 = [random.random() for _ in range(m)]
a1 = np.array(x1)
a2 = np.array(x2)

timing:

dis4(x1, x2)
dis4(a1, a2)

%timeit dis4_plain(x1, x2)
# 2.71 ms ± 178 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit dis4(x1, x2)
# 24.1 ms ± 279 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit dis4(a1, a2)
# 14 µs ± 608 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

因此,虽然列表和numba.jit的速度慢了10倍,但使用数组的jitted函数几乎比带有列表的Python函数快200倍。

© www.soinside.com 2019 - 2024. All rights reserved.