如何并行化/加速令人尴尬的并行 numba 代码?

问题描述 投票:0回答:1

我有以下代码:

@nb.njit(cache=True)
def find_two_largest(arr):
    """
    Return the largest and second-largest entries of a 1-D array in one pass.

    Duplicates count separately: if the maximum occurs more than once, both
    returned values are equal to it.

    Parameters
    ----------
    arr : 1-D array
        Must contain at least two elements.

    Returns
    -------
    (largest, second_largest) : tuple of two elements taken from ``arr``.

    Raises
    ------
    ValueError
        If ``arr`` has fewer than two elements.
    """
    # Numba skips bounds checks by default, so reading arr[1] on a shorter
    # array could read past the end of the buffer instead of failing.
    if arr.shape[0] < 2:
        raise ValueError("find_two_largest requires at least two elements")

    # Seed the pair from the first two elements.
    if arr[0] >= arr[1]:
        largest = arr[0]
        second_largest = arr[1]
    else:
        largest = arr[1]
        second_largest = arr[0]

    # Scan the rest. A value equal to ``largest`` falls through to the
    # ``elif`` and can become ``second_largest``; that is how a duplicated
    # maximum is reported.
    for num in arr[2:]:
        if num > largest:
            second_largest = largest
            largest = num
        elif num > second_largest:
            second_largest = num
    return largest, second_largest


@nb.njit(cache=True)
def max_bar_one(arr):
    """
    For each position ``i``, return the maximum of ``arr`` with element ``i``
    left out.

    Positions holding the maximum get the runner-up (which equals the maximum
    when it is duplicated). Every other position gets the maximum itself.
    The output has the same shape and dtype as ``arr``.
    """
    top, runner_up = find_two_largest(arr)
    # The value written at the maximum's positions does not depend on the
    # index, so decide it once up front.
    at_top = runner_up if top != runner_up else top
    out = np.empty_like(arr)
    for idx in range(arr.shape[0]):
        out[idx] = at_top if arr[idx] == top else top
    return out


@nb.njit(cache=True)
def replace_max_row_wise_add_first_delete_last(d):
    """
    Run max_bar_one on each row but the last, prepend an all -inf row.

    Row ``i + 1`` of the result holds ``max_bar_one(d[i])``. Row 0 is filled
    with ``-inf``. The last row of ``d`` is therefore never read.

    Parameters
    ----------
    d : 2-D array of shape ``(m, n)``

    Returns
    -------
    result : float64 array of shape ``(m, n)``
        Empty (``m == 0``) input yields an empty result.
    """
    m, n = d.shape
    result = np.empty((m, n))
    # Numba does not bounds-check indexing by default; writing row 0 of a
    # (0, n) array would touch memory outside the allocation.
    if m == 0:
        return result
    result[0] = -np.inf
    for i in range(m - 1):
        result[i + 1, :] = max_bar_one(d[i, :])
    return result


@nb.njit(cache=True)
def main_function(d, subcusum, j):
    """
    Combine the shifted "max excluding self" matrix with ``d`` and add offsets.

    Each cell takes the larger of the shifted value (from
    ``replace_max_row_wise_add_first_delete_last``) and ``d`` at the same
    position, then adds ``subcusum[j, column]``. The shifted matrix is
    updated in place and returned.
    """
    out = replace_max_row_wise_add_first_delete_last(d)
    offsets = subcusum[j]
    n_rows, n_cols = out.shape
    for r in range(n_rows):
        for c in range(n_cols):
            out[r, c] = max(out[r, c], d[r, c]) + offsets[c]
    return out

然后我设置数据:

n = 5000
# Row-wise cumulative sums of a random integer matrix (stored as float).
A = np.random.randint(-3, 4, (n, n)).astype(float)
cusum_rows = np.cumsum(A, axis=1)
rowseq = np.arange(n)
# Size d from n as well, so changing the problem size keeps d and cusum_rows
# the same shape.
d = np.random.randint(-3, 4, (n, n))

然后我们可以用以下方法计时:

%timeit main_function(d, cusum_rows, 0)
166 ms ± 1.87 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

是否可以并行化 for 循环或一般代码以加快速度?我尝试在replace_max_row_wise_add_first_delete_last中使用parallel=True 但它并没有加快代码速度,只报告:

Instruction hoisting:
loop #1:
  Failed to hoist the following:
    dependency: $value_var.73 = getitem(value=_72call__function_11, index=$parfor__index_72.90, fn=<built-in function getitem>)

这令人惊讶,因为 for 循环中的所有调用都是独立的。

这段代码可以加速和/或并行化吗?

python performance numba
1个回答
1
投票

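仅设置 `parallel=True` 并不会让普通的 `range` 循环并行执行,上面的 "Instruction hoisting" 信息也只是诊断输出;需要把该循环改写为 `nb.prange`。当我在 `replace_max_row_wise_add_first_delete_last()` 中这样使用 `nb.prange` 进行并行化时,速度提高了约 70%: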

@nb.njit(parallel=True)
def replace_max_row_wise_add_first_delete_last_parallel(d):
    """
    Run max_bar_one on each row but the last, prepend an all -inf row.

    Parallel version: the rows are distributed with ``nb.prange``. Each
    iteration writes a distinct row ``i + 1`` of ``result``, so iterations
    do not write to the same memory.

    Parameters
    ----------
    d : 2-D array of shape ``(m, n)``

    Returns
    -------
    result : float64 array of shape ``(m, n)``
        Empty (``m == 0``) input yields an empty result.
    """
    m, n = d.shape
    result = np.empty((m, n))
    # Numba does not bounds-check indexing by default; writing row 0 of a
    # (0, n) array would touch memory outside the allocation.
    if m == 0:
        return result
    result[0] = -np.inf
    for i in nb.prange(0, m - 1):  # prange (not range) is what gets parallelized
        result[i + 1, :] = max_bar_one(d[i, :])
    return result


@nb.njit
def main_function_parallel(d, subcusum, j):
    """
    Same computation as ``main_function``, but built on the prange-based
    ``replace_max_row_wise_add_first_delete_last_parallel``.

    Each cell takes the larger of the shifted value and ``d`` at the same
    position, then adds ``subcusum[j, column]``. The shifted matrix is
    updated in place and returned.
    """
    # Only the row-wise "max bar one" step runs in parallel; the combine
    # loop below is serial.
    shifted = replace_max_row_wise_add_first_delete_last_parallel(d)
    offsets = subcusum[j]
    n_rows, n_cols = shifted.shape
    for r in range(n_rows):
        for c in range(n_cols):
            shifted[r, c] = max(shifted[r, c], d[r, c]) + offsets[c]
    return shifted

编辑:进一步的加速来自去掉 `missing_maxes = np.empty_like(arr)` 这一每行一次的临时分配,并把后续的 `max(...)` 与加法融合进同一个循环。这样速度约为原始版本的 3 倍(见下方基准:约 7.00 s 对 2.28 s):

@nb.njit
def max_bar_one2(arr, result, to_compare, to_add):
    """
    Fused, in-place variant of ``max_bar_one`` followed by the combine step.

    For each index ``k`` this writes
    ``max(<max of arr without element k>, to_compare[k]) + to_add[k]``
    into ``result[k]`` without allocating a temporary. Returns ``result``.
    """
    top, runner_up = find_two_largest(arr)
    for k in range(arr.shape[0]):
        # Max of ``arr`` with position k left out.
        if arr[k] != top:
            result[k] = top
        elif top != runner_up:
            result[k] = runner_up
        else:
            result[k] = top  # duplicated maximum
        # Read back through ``result`` so the comparison uses its dtype.
        result[k] = max(result[k], to_compare[k]) + to_add[k]
    return result


@nb.njit(parallel=True)
def replace_max_row_wise_add_first_delete_last_parallel2(d, to_add):
    """
    Fused, parallel equivalent of ``main_function`` for one offset row.

    No -inf row is materialized. Row 0 of the result is ``d[0] + to_add``,
    which is what ``max(-inf, d[0]) + to_add`` reduces to for finite entries.
    For ``0 <= i < m - 1``, row ``i + 1`` is, elementwise,
    ``max(max_bar_one(d[i]), d[i + 1]) + to_add``, computed in place by
    ``max_bar_one2``.

    Parameters
    ----------
    d : 2-D array of shape ``(m, n)``
    to_add : 1-D array of length ``n``, added to every row.

    Returns
    -------
    result : float64 array of shape ``(m, n)``
        Empty (``m == 0``) input yields an empty result.
    """
    m, n = d.shape
    result = np.empty((m, n))
    # Numba does not bounds-check indexing by default; reading d[0] / writing
    # result[0] on a (0, n) array would touch memory outside the allocation.
    if m == 0:
        return result
    result[0] = d[0] + to_add
    # Each iteration writes only row i + 1 of ``result``, so rows are independent.
    for i in nb.prange(0, m - 1):
        max_bar_one2(d[i], result[i + 1], d[i + 1], to_add)
    return result


@nb.njit
def main_function_parallel2(d, subcusum, j):
    """
    Fused, parallel version of ``main_function``: select offset row ``j``
    and delegate everything to
    ``replace_max_row_wise_add_first_delete_last_parallel2``.
    """
    offsets = subcusum[j]
    return replace_max_row_wise_add_first_delete_last_parallel2(d, offsets)

基准:

from timeit import timeit

import numba as nb
import numpy as np


@nb.njit(cache=True)
def find_two_largest(arr):
    """
    Return the largest and second-largest entries of a 1-D array in one pass.

    Duplicates count separately: if the maximum occurs more than once, both
    returned values are equal to it.

    Parameters
    ----------
    arr : 1-D array
        Must contain at least two elements.

    Returns
    -------
    (largest, second_largest) : tuple of two elements taken from ``arr``.

    Raises
    ------
    ValueError
        If ``arr`` has fewer than two elements.
    """
    # Numba skips bounds checks by default, so reading arr[1] on a shorter
    # array could read past the end of the buffer instead of failing.
    if arr.shape[0] < 2:
        raise ValueError("find_two_largest requires at least two elements")

    # Seed the pair from the first two elements.
    if arr[0] >= arr[1]:
        largest = arr[0]
        second_largest = arr[1]
    else:
        largest = arr[1]
        second_largest = arr[0]

    # Scan the rest. A value equal to ``largest`` falls through to the
    # ``elif`` and can become ``second_largest``; that is how a duplicated
    # maximum is reported.
    for num in arr[2:]:
        if num > largest:
            second_largest = largest
            largest = num
        elif num > second_largest:
            second_largest = num
    return largest, second_largest


@nb.njit(cache=True)
def max_bar_one(arr):
    """
    For each position ``i``, return the maximum of ``arr`` with element ``i``
    left out.

    Positions holding the maximum get the runner-up (which equals the maximum
    when it is duplicated). Every other position gets the maximum itself.
    The output has the same shape and dtype as ``arr``.
    """
    top, runner_up = find_two_largest(arr)
    # The value written at the maximum's positions does not depend on the
    # index, so decide it once up front.
    at_top = runner_up if top != runner_up else top
    out = np.empty_like(arr)
    for idx in range(arr.shape[0]):
        out[idx] = at_top if arr[idx] == top else top
    return out


@nb.njit(cache=True)
def replace_max_row_wise_add_first_delete_last(d):
    """
    Run max_bar_one on each row but the last, prepend an all -inf row.

    Row ``i + 1`` of the result holds ``max_bar_one(d[i])``. Row 0 is filled
    with ``-inf``. The last row of ``d`` is therefore never read.

    Parameters
    ----------
    d : 2-D array of shape ``(m, n)``

    Returns
    -------
    result : float64 array of shape ``(m, n)``
        Empty (``m == 0``) input yields an empty result.
    """
    m, n = d.shape
    result = np.empty((m, n))
    # Numba does not bounds-check indexing by default; writing row 0 of a
    # (0, n) array would touch memory outside the allocation.
    if m == 0:
        return result
    result[0] = -np.inf
    for i in range(m - 1):
        result[i + 1, :] = max_bar_one(d[i])
    return result


@nb.njit(cache=True)
def main_function(d, subcusum, j):
    """
    Combine the shifted "max excluding self" matrix with ``d`` and add offsets.

    Each cell takes the larger of the shifted value (from
    ``replace_max_row_wise_add_first_delete_last``) and ``d`` at the same
    position, then adds ``subcusum[j, column]``. The shifted matrix is
    updated in place and returned.
    """
    out = replace_max_row_wise_add_first_delete_last(d)
    offsets = subcusum[j]
    n_rows, n_cols = out.shape
    for r in range(n_rows):
        for c in range(n_cols):
            out[r, c] = max(out[r, c], d[r, c]) + offsets[c]
    return out


@nb.njit(parallel=True)
def replace_max_row_wise_add_first_delete_last_parallel(d):
    """
    Run max_bar_one on each row but the last, prepend an all -inf row.

    Parallel version: the rows are distributed with ``nb.prange``. Each
    iteration writes a distinct row ``i + 1`` of ``result``, so iterations
    do not write to the same memory.

    Parameters
    ----------
    d : 2-D array of shape ``(m, n)``

    Returns
    -------
    result : float64 array of shape ``(m, n)``
        Empty (``m == 0``) input yields an empty result.
    """
    m, n = d.shape
    result = np.empty((m, n))
    # Numba does not bounds-check indexing by default; writing row 0 of a
    # (0, n) array would touch memory outside the allocation.
    if m == 0:
        return result
    result[0] = -np.inf
    for i in nb.prange(0, m - 1):
        result[i + 1, :] = max_bar_one(d[i, :])
    return result


@nb.njit
def main_function_parallel(d, subcusum, j):
    """
    Same computation as ``main_function``, but built on the prange-based
    ``replace_max_row_wise_add_first_delete_last_parallel``.

    Each cell takes the larger of the shifted value and ``d`` at the same
    position, then adds ``subcusum[j, column]``. The shifted matrix is
    updated in place and returned.
    """
    # Only the row-wise "max bar one" step runs in parallel; the combine
    # loop below is serial.
    shifted = replace_max_row_wise_add_first_delete_last_parallel(d)
    offsets = subcusum[j]
    n_rows, n_cols = shifted.shape
    for r in range(n_rows):
        for c in range(n_cols):
            shifted[r, c] = max(shifted[r, c], d[r, c]) + offsets[c]
    return shifted


@nb.njit
def max_bar_one2(arr, result, to_compare, to_add):
    """
    Fused, in-place variant of ``max_bar_one`` followed by the combine step.

    For each index ``k`` this writes
    ``max(<max of arr without element k>, to_compare[k]) + to_add[k]``
    into ``result[k]`` without allocating a temporary. Returns ``result``.
    """
    top, runner_up = find_two_largest(arr)
    for k in range(arr.shape[0]):
        # Max of ``arr`` with position k left out.
        if arr[k] != top:
            result[k] = top
        elif top != runner_up:
            result[k] = runner_up
        else:
            result[k] = top  # duplicated maximum
        # Read back through ``result`` so the comparison uses its dtype.
        result[k] = max(result[k], to_compare[k]) + to_add[k]
    return result


@nb.njit(parallel=True)
def replace_max_row_wise_add_first_delete_last_parallel2(d, to_add):
    """
    Fused, parallel equivalent of ``main_function`` for one offset row.

    No -inf row is materialized. Row 0 of the result is ``d[0] + to_add``,
    which is what ``max(-inf, d[0]) + to_add`` reduces to for finite entries.
    For ``0 <= i < m - 1``, row ``i + 1`` is, elementwise,
    ``max(max_bar_one(d[i]), d[i + 1]) + to_add``, computed in place by
    ``max_bar_one2``.

    Parameters
    ----------
    d : 2-D array of shape ``(m, n)``
    to_add : 1-D array of length ``n``, added to every row.

    Returns
    -------
    result : float64 array of shape ``(m, n)``
        Empty (``m == 0``) input yields an empty result.
    """
    m, n = d.shape
    result = np.empty((m, n))
    # Numba does not bounds-check indexing by default; reading d[0] / writing
    # result[0] on a (0, n) array would touch memory outside the allocation.
    if m == 0:
        return result
    result[0] = d[0] + to_add
    # Each iteration writes only row i + 1 of ``result``, so rows are independent.
    for i in nb.prange(0, m - 1):
        max_bar_one2(d[i], result[i + 1], d[i + 1], to_add)
    return result


@nb.njit
def main_function_parallel2(d, subcusum, j):
    """
    Fused, parallel version of ``main_function``: select offset row ``j``
    and delegate everything to
    ``replace_max_row_wise_add_first_delete_last_parallel2``.
    """
    offsets = subcusum[j]
    return replace_max_row_wise_add_first_delete_last_parallel2(d, offsets)


def get_d_cumsum_rows(n, low=-300, high=400):
    """
    Build random ``n x n`` benchmark inputs.

    Draws an integer matrix in ``[low, high)``, converts it to float and takes
    its row-wise cumulative sum. Then draws an independent integer matrix
    ``d`` from the same range. Both draws use NumPy's global random state,
    so calling ``np.random.seed`` beforehand makes the output reproducible.

    Parameters
    ----------
    n : int
        Number of rows and columns.
    low, high : int, optional
        Half-open range ``[low, high)`` for the random integers. The defaults
        reproduce the original benchmark data.

    Returns
    -------
    d : (n, n) integer ndarray
    cusum_rows : (n, n) float ndarray
    """
    A = np.random.randint(low, high, (n, n)).astype(float)
    cusum_rows = np.cumsum(A, axis=1)
    d = np.random.randint(low, high, (n, n))

    return d, cusum_rows


# Correctness check on a small problem: all three implementations must agree.
# None of them modifies its inputs, so one seeded draw can be shared.
n = 10
np.random.seed(42)
d_small, cusum_small = get_d_cumsum_rows(n)

out1 = main_function(d_small, cusum_small, 0)
out2 = main_function_parallel(d_small, cusum_small, 0)
out3 = main_function_parallel2(d_small, cusum_small, 0)

assert np.allclose(out1, out2)
assert np.allclose(out1, out3)

# Each value is the total time, in seconds, of 100 calls on a 5000 x 5000
# problem; the setup string runs once per timeit() call.
t1, t2, t3 = (
    timeit(
        stmt,
        setup="n=5000;a,b=get_d_cumsum_rows(n)",
        globals=globals(),
        number=100,
    )
    for stmt in (
        "main_function(a, b, 0)",
        "main_function_parallel(a, b, 0)",
        "main_function_parallel2(a, b, 0)",
    )
)

print(t1, t2, t3, sep="\n")

在我的计算机(AMD Ryzen 7 5700X)上输出如下(每个值是 100 次调用的总耗时,单位为秒):

7.003944834927097
4.12014868715778
2.2788363839499652
© www.soinside.com 2019 - 2024. All rights reserved.