我正在尝试使用
numpy
执行以下操作。由于 aa
的尺寸较大,因此使用 numpy
的速度很慢。我正在尝试使用 numba
来加速它,有一些改进,但我想进一步加速它,因为它是另一个循环的一部分。非常感谢任何建议!
使用
numpy
:
def get_prob(aa):
allmax = aa.max(axis=1)[:, None]
findmax = aa - allmax
mask = ((findmax[:,1,:]==0)&(findmax[:,2,:]==0))
findmax[:, 1, :][mask] = -1
mask = ((findmax[:, 0, :] == 0) & (findmax[:, 1, :] == 0))
findmax[:, 0, :][mask] = -1
mask = ((findmax[:, 0, :] == 0) & (findmax[:, 1, :] == 0) & (findmax[:, 2, :] == 0))
findmax[:, 0, :][mask] = -1
findmax[:, 1, :][mask] = -1
p = np.where(findmax < 0, 0.0, 1.0).transpose(0,2,1)
return p
使用
numba
:
@numba.jit(nopython=True)
def get_prob_nb(aa,num_params,num_action):
p=np.zeros_like(aa)
for i in range(num_params):
for j in range(num_action):
a1 = aa[i, 0, j]
a2 = aa[i, 1, j]
a3 = aa[i, 2, j]
if a1>a2 and a1>a3:
p[i, 0, j] = 1.
elif a2>=a1 and a2>a3:
p[i, 1, j] = 1.
elif a3>=a2 and a3>=a1:
p[i, 2, j] = 1.
p = p.transpose(0, 2, 1)
return p
aa=rng.uniform(0.0, 1.0, 9000000)
aa=aa.reshape(1000,3,3000)
start = time.time()
get_prob_nb(aa, 1000, 3000)
print("elapse", time.time()-start)
def get_prob_keepdims(aa):
max_values = aa.max(axis=1, keepdims=True)
p = np.equal(aa, max_values).astype(float)
return p.transpose(0, 2, 1)
函数
get_prob_keepdims
在计算中使用 keepdims=True
参数,该参数在跨特定轴执行 max
操作后保持原始数组的维度。根据我对所提供代码的理解,我相信这个函数的操作应该与原始 get_prob
函数相同。
在我的 100 次迭代测试中,在 Mac M1 上使用 keepdims 比使用 Numba(没有并行)产生的结果稍快。
有一种令人惊讶的简单方法来并行化您的 numba 调用:
@numba.jit(nopython=True, parallel=True)
def get_prob_nb_parallel(aa, num_params, num_action):
p = np.zeros_like(aa)
for i in numba.prange(num_params):
for j in range(num_action):
a1 = aa[i, 0, j]
a2 = aa[i, 1, j]
a3 = aa[i, 2, j]
if a1 > a2 and a1 > a3:
p[i, 0, j] = 1.
elif a2 >= a1 and a2 > a3:
p[i, 1, j] = 1.
elif a3 >= a2 and a3 >= a1:
p[i, 2, j] = 1.
p = p.transpose(0, 2, 1)
return p
运行测试,我发现与
get_prob_nb
相比,时间缩短了约 30%。