Python 中嵌套循环的替代方案

问题描述 投票:0回答:1

我编写了一个函数 U_p_law,它采用 2 个玩家的 2 个概率密度函数(L_P 和 L_Q)和 2 个整数值,这些整数值定义每个玩家互相比赛时的频率。

目前的计算是通过 2 个嵌套循环完成的,每个循环都通过 omega 进行索引(一个从 0 到 3500 的数组,步骤 10)。该函数返回第一个玩家的标准化效用函数:

def U_p_law(W,L,L_P, L_Q):
    omega = np.arange(0, 3501, 10)

    U_p = np.zeros_like(omega, dtype=float)

    for p_idx, p in enumerate(omega):
        for q_idx, q in enumerate(omega):
            U_p[p_idx] += (probability_of_loss(q - p)**W * probability_of_loss(p - q)**L * L_Q[q_idx] * L_P[p_idx])

    normalization_factor = np.sum(U_p)
    U_p /= normalization_factor

    return omega, U_p

omega, U_p = U_p_law(W, L, L_P, L_Q)

由于这些嵌套循环计算起来很耗时,我想知道是否有更好的方法来编写这个函数?

嵌套循环方法给了我想要的结果,但需要计算时间。 我将这些值传递给函数:

P_mean = 1500
P_std = 100
Q_mean = 1500
Q_std = 100
W = 1  # Number of matches won by P
L = 0  # Number of matches lost by P
L_P = np.exp(-0.5 * ((np.arange(0, 3501, 10) - P_mean) / P_std)**2) / (P_std * np.sqrt(2 * np.pi))
L_Q = np.exp(-0.5 * ((np.arange(0, 3501, 10) - Q_mean) / Q_std)**2) / (Q_std * np.sqrt(2 * np.pi))

def probability_of_loss(x):
    return 1 / (1 + np.exp(x / 67))

提前致谢

python loops nested distribution
1个回答
0
投票

您可以使用 加速该功能。这是一个示例+简单的基准测试:

from timeit import timeit

from numba import njit

P_mean = 1500
P_std = 100
Q_mean = 1500
Q_std = 100
W = 1  # Number of matches won by P
L = 0  # Number of matches lost by P
L_P = np.exp(-0.5 * ((np.arange(0, 3501, 10) - P_mean) / P_std) ** 2) / (
    P_std * np.sqrt(2 * np.pi)
)
L_Q = np.exp(-0.5 * ((np.arange(0, 3501, 10) - Q_mean) / Q_std) ** 2) / (
    Q_std * np.sqrt(2 * np.pi)
)


def probability_of_loss(x):
    return 1 / (1 + np.exp(x / 67))


def U_p_law(W, L, L_P, L_Q):
    omega = np.arange(0, 3501, 10)

    U_p = np.zeros_like(omega, dtype=float)

    for p_idx, p in enumerate(omega):
        for q_idx, q in enumerate(omega):
            U_p[p_idx] += (
                probability_of_loss(q - p) ** W
                * probability_of_loss(p - q) ** L
                * L_Q[q_idx]
                * L_P[p_idx]
            )

    normalization_factor = np.sum(U_p)
    U_p /= normalization_factor

    return omega, U_p


@njit
def probability_of_loss_numba(x):
    return 1 / (1 + np.exp(x / 67))


@njit
def U_p_law_numba(W, L, L_P, L_Q):
    omega = np.arange(0, 3501, 10, dtype=np.float64)

    U_p = np.zeros_like(omega)

    for p_idx, p in enumerate(omega):
        for q_idx, q in enumerate(omega):
            U_p[p_idx] += (
                probability_of_loss_numba(q - p) ** W
                * probability_of_loss_numba(p - q) ** L
                * L_Q[q_idx]
                * L_P[p_idx]
            )

    normalization_factor = np.sum(U_p)
    U_p /= normalization_factor

    return omega, U_p


omega_1, U_p_1 = U_p_law(W, L, L_P, L_Q)
omega_2, U_p_2 = U_p_law_numba(W, L, L_P, L_Q)

assert np.allclose(omega_1, omega_2)
assert np.allclose(U_p_1, U_p_2)

t1 = timeit("U_p_law(W, L, L_P, L_Q)", number=10, globals=globals())
t2 = timeit("U_p_law_numba(W, L, L_P, L_Q)", number=10, globals=globals())

print("10 calls using vanilla Python :", t1)
print("10 calls using Numba          :", t2)

在我的机器上打印(AMD 5700x):

10 calls using vanilla Python : 2.46206750581041
10 calls using Numba          : 0.014143474865704775

加速~170x

© www.soinside.com 2019 - 2024. All rights reserved.