如何加速 Numpy

问题描述 投票:0回答:2

我正在尝试使用

numpy
执行以下操作。由于
aa
的尺寸较大,因此使用
numpy
的速度很慢。我正在尝试使用
numba
来加速它,有一些改进,但我想进一步加速它,因为它是另一个循环的一部分。非常感谢任何建议!

使用

numpy

def get_prob(aa):
    allmax = aa.max(axis=1)[:, None]
    findmax = aa - allmax
    mask = ((findmax[:,1,:]==0)&(findmax[:,2,:]==0))
    findmax[:, 1, :][mask] = -1

    mask = ((findmax[:, 0, :] == 0) & (findmax[:, 1, :] == 0))
    findmax[:, 0, :][mask] = -1

    mask = ((findmax[:, 0, :] == 0) & (findmax[:, 1, :] == 0) & (findmax[:, 2, :] == 0))
    findmax[:, 0, :][mask] = -1
    findmax[:, 1, :][mask] = -1

    p = np.where(findmax < 0, 0.0, 1.0).transpose(0,2,1)
    return p

使用

numba

@numba.jit(nopython=True)
def get_prob_nb(aa,num_params,num_action):
    p=np.zeros_like(aa)

    for i in range(num_params):
        for j in range(num_action):
            a1 = aa[i, 0, j]
            a2 = aa[i, 1, j]
            a3 = aa[i, 2, j]
            if a1>a2 and a1>a3:
                p[i, 0, j] = 1.
            elif a2>=a1 and a2>a3:
                p[i, 1, j] = 1.
            elif a3>=a2 and a3>=a1:
                p[i, 2, j] = 1.

    p = p.transpose(0, 2, 1)
    return p

aa=rng.uniform(0.0, 1.0, 9000000)
aa=aa.reshape(1000,3,3000)
start = time.time()
get_prob_nb(aa, 1000, 3000)
print("elapse", time.time()-start)
python numpy numba
2个回答
0
投票
def get_prob_keepdims(aa):
    max_values = aa.max(axis=1, keepdims=True)
    p = np.equal(aa, max_values).astype(float)
    return p.transpose(0, 2, 1)

函数

get_prob_keepdims
在计算中使用
keepdims=True
参数,该参数在跨特定轴执行
max
操作后保持原始数组的维度。根据我对所提供代码的理解,我相信这个函数的操作应该与原始
get_prob
函数相同。

在我的 100 次迭代测试中,在 Mac M1 上使用 keepdims 比使用 Numba(没有并行)产生的结果稍快。


-1
投票

有一种令人惊讶的简单方法来并行化您的 numba 调用:

@numba.jit(nopython=True, parallel=True)
def get_prob_nb_parallel(aa, num_params, num_action):
    p = np.zeros_like(aa)

    for i in numba.prange(num_params):
        for j in range(num_action):
            a1 = aa[i, 0, j]
            a2 = aa[i, 1, j]
            a3 = aa[i, 2, j]
            if a1 > a2 and a1 > a3:
                p[i, 0, j] = 1.
            elif a2 >= a1 and a2 > a3:
                p[i, 1, j] = 1.
            elif a3 >= a2 and a3 >= a1:
                p[i, 2, j] = 1.

    p = p.transpose(0, 2, 1)
    return p

运行测试,我发现与

get_prob_nb
相比,时间缩短了约 30%。

© www.soinside.com 2019 - 2024. All rights reserved.