为什么微小的变化会对我的 numba 并行函数的运行时间产生巨大影响?

问题描述 投票:0回答:0

我试图理解为什么我的并行化 numba 函数会按照它的方式运行。特别是,为什么它对数组的使用方式如此敏感。

我有以下功能:

@njit(parallel=True)
def f(n):
    g = lambda i,j: zeros(3) + sqrt(i*j)
    x = zeros((n,3))
    for i in prange(n):
        for j in range(n):
            tmp      = g(i,j)
            x[i] += tmp
    return x

相信 n 足够大,并行计算才有用。出于某种原因,这实际上用更少的内核运行得更快。现在,当我做一个小改动时 (

x[i]
->
x[i, :]
).

@njit(parallel=True)
def f(n):
    g = lambda i,j: zeros(3) + sqrt(i*j)
    x = zeros((n,3))
    for i in prange(n):
        for j in range(n):
            tmp      = g(i,j)
            x[i, :] += tmp
    return x

性能明显更好,并且可以随着内核数量适当扩展(即更多内核更快)。为什么切片会使性能更好?更进一步,另一个有很大不同的变化是将

lambda
函数转换为外部 njit 函数。

@njit
def g(i,j):
    x = zeros(3) + sqrt(i*j)
    return x

@njit(parallel=True)
def f(n):
    x = zeros((n,3))
    for i in prange(n):
        for j in range(n):
            tmp      = g(i,j)
            x[i, :] += tmp
    return x

这再次破坏了性能和缩放,恢复到等于或低于第一种情况的运行时间。为什么这个外部函数会破坏性能?可以使用下面显示的两个选项恢复性能。

@njit
def g(i,j):
    x = sqrt(i*j)
    return x

@njit(parallel=True)
def f(n):
    x = zeros((n,3))
    for i in prange(n):
        for j in range(n):
            tmp      = zeros(3) + g(i,j)
            x[i, :] += tmp
    return x
@njit(parallel=True)
def f(n):
    def g(i,j):
        x = zeros(3) + sqrt(i*j)
        return x
    x = zeros((n,3))
    for i in prange(n):
        for j in range(n):
            tmp      = g(i,j)
            x[i, :] += tmp
    return x

为什么

parallel=True
numba 装饰函数对数组的使用方式如此敏感?我知道数组不是简单的可并行化的,但这些变化中的每一个都会显着影响性能的确切原因对我来说并不明显。

python numba
© www.soinside.com 2019 - 2024. All rights reserved.