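Here is a toy njit function that takes a distance matrix, loops over each row of the matrix, and records the minimum value in each column together with the row that minimum came from. However, IIUC, using prange may cause a race condition (especially for larger input arrays):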

from numba import njit, prange
import numpy as np


@njit
def some_transformation_func(D, row_i):
    """
    This function applies some transformation to the ith row (`row_i`) in the `D` matrix in place.
    However, the transformation time is random (but all less than a second), which means
    that the rows can take
    """
    # Apply some inplace transformation on the ith row of D


@njit(parallel=True)
def some_func(D):
    P = np.empty((D.shape[1]))
    I = np.empty((D.shape[1]), np.int64)
    P[:] = np.inf
    I[:] = -1

    for row in prange(D.shape[0]):
        some_transformation_func(D, row)
        for col in range(D.shape[1]):
            if P[col] > D[row, col]:
                P[col] = D[row, col]
                I[col] = row

    return P, I


if __name__ == "__main__":
    D = np.array([[4,1,6,9,9],
                  [1,3,8,2,7],
                  [2,8,0,0,1],
                  [3,7,4,6,5]
                 ])
    P, I = some_func(D)
    print(P)
    print(I)
    # [1. 1. 0. 0. 1.]
    # [1 0 2 2 2]

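How can I confirm whether a race condition exists (especially when D is very large, with many more rows and columns)? And, more importantly, if there is a race condition, how can I avoid it?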
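In situations like this, rather than setting the prange to the size of the array, it is better to manually chunk the data into n_threads chunks, distribute the processing accordingly, and finally perform a reduction. So, something like this: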

from numba import njit, prange, config
import numpy as np


@njit
def wrapper_func(thread_idx, start_indices, stop_indices, D, P, I):
    # Each thread only touches its own row of P and I, so no two threads write to the same element
    for row in range(start_indices[thread_idx], stop_indices[thread_idx]):
        some_transformation_func(D, row)
        for col in range(D.shape[1]):
            if P[thread_idx, col] > D[row, col]:
                P[thread_idx, col] = D[row, col]
                I[thread_idx, col] = row


@njit
def some_transformation_func(D, row_i):
    """
    This function applies some transformation to the ith row (`row_i`) in the `D` matrix in place.
    However, the transformation time is random (but all less than a second), which means
    that the rows can take
    """
    # Apply some inplace transformation on the ith row of D


@njit(parallel=True)
def some_func(D):
    n_threads = config.NUMBA_NUM_THREADS
    n_rows = D.shape[0]
    P = np.empty((n_threads, D.shape[1]))
    I = np.empty((n_threads, D.shape[1]), np.int64)
    P[:, :] = np.inf
    I[:, :] = -1

    # Split the rows into n_threads contiguous chunks.
    # With 2 threads and 4 rows this gives start_indices = [0, 2] and stop_indices = [2, 4].
    chunk_size = (n_rows + n_threads - 1) // n_threads
    start_indices = np.empty(n_threads, np.int64)
    stop_indices = np.empty(n_threads, np.int64)  # Note that these are exclusive
    for t in range(n_threads):
        start_indices[t] = min(t * chunk_size, n_rows)
        stop_indices[t] = min((t + 1) * chunk_size, n_rows)

    for thread_idx in prange(n_threads):
        wrapper_func(thread_idx, start_indices, stop_indices, D, P, I)

    # Perform reduction from all threads and store results in P[0]
    for thread_idx in range(1, n_threads):
        for i in prange(D.shape[1]):
            if P[0, i] > P[thread_idx, i]:
                P[0, i] = P[thread_idx, i]
                I[0, i] = I[thread_idx, i]

    return P[0], I[0]


if __name__ == "__main__":
    D = np.array([[4,1,6,9,9],
                  [1,3,8,2,7],
                  [2,8,0,0,1],
                  [3,7,4,6,5]
                 ])
    P, I = some_func(D)
    print(P)
    print(I)
    # [1. 1. 0. 0. 1.]
    # [1 0 2 2 2]

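
As for confirming whether a race exists: one rough approach is to compare the parallel result against a serial NumPy reference over many repeated runs. The sketch below is only an illustration under a few assumptions: the helper names (`reference_min_per_column`, `matches_reference`, `serial_stand_in`) are hypothetical, the transformation is assumed to leave `D` unchanged (as in the toy example), and a pure-Python stand-in is used in place of the njit function so the snippet runs without Numba. Passing such a check does not prove the absence of a race; a single mismatch, however, is strong evidence of one.

```python
import numpy as np


def reference_min_per_column(D):
    # Serial reference: column minima and the first row index where each occurs.
    # np.argmin returns the first occurrence, matching the strict ">" comparison.
    return D.min(axis=0).astype(np.float64), D.argmin(axis=0).astype(np.int64)


def matches_reference(func, D, n_trials=100):
    # Call func repeatedly on fresh copies of D and compare with the reference.
    P_ref, I_ref = reference_min_per_column(D)
    for _ in range(n_trials):
        P, I = func(D.copy())
        if not (np.array_equal(P, P_ref) and np.array_equal(I, I_ref)):
            return False
    return True


def serial_stand_in(D):
    # Hypothetical stand-in for the function under test (same loop, no prange).
    P = np.full(D.shape[1], np.inf)
    I = np.full(D.shape[1], -1, dtype=np.int64)
    for row in range(D.shape[0]):
        for col in range(D.shape[1]):
            if P[col] > D[row, col]:
                P[col] = D[row, col]
                I[col] = row
    return P, I


D = np.array([[4, 1, 6, 9, 9],
              [1, 3, 8, 2, 7],
              [2, 8, 0, 0, 1],
              [3, 7, 4, 6, 5]])
print(matches_reference(serial_stand_in, D))
```

In practice you would pass the njit `some_func` instead of `serial_stand_in`, and use a large random `D` with many more rows than threads, since small inputs rarely expose timing-dependent bugs.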
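Note that this will cost you more memory (P and I take exactly n_threads times as much), but you will benefit from parallelization. In addition, the code becomes cleaner and easier to maintain. All one needs to do is figure out the best way to chunk the data and determine the start_row and stop_row (exclusive) indices.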
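
For the chunking step, here is a minimal sketch of one way to compute the start and (exclusive) stop row indices for an arbitrary number of rows and chunks. The helper name `chunk_bounds` is hypothetical and not part of Numba; it simply spreads any remainder rows over the first few chunks so that chunk sizes differ by at most one.

```python
def chunk_bounds(n_rows, n_chunks):
    # Split range(n_rows) into n_chunks contiguous pieces of near-equal size.
    # Returns two lists: start indices and exclusive stop indices.
    base, extra = divmod(n_rows, n_chunks)
    starts, stops = [], []
    start = 0
    for k in range(n_chunks):
        # The first `extra` chunks each receive one additional row
        stop = start + base + (1 if k < extra else 0)
        starts.append(start)
        stops.append(stop)
        start = stop
    return starts, stops


print(chunk_bounds(4, 2))   # ([0, 2], [2, 4])
print(chunk_bounds(10, 3))  # ([0, 4, 7], [4, 7, 10])
```

The resulting lists can be converted with `np.array(..., np.int64)` before being passed into an njit function. When there are fewer rows than chunks, the trailing chunks come out empty (start equals stop), so the corresponding threads simply do no work.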