I wrote a function U_p_law that takes the two players' probability density functions (L_P and L_Q) and two integer values that record how often each player has beaten the other.
The computation is currently done with two nested loops, each indexed over omega (an array from 0 to 3500 in steps of 10). The function returns the normalized utility function of the first player:
import numpy as np

def U_p_law(W, L, L_P, L_Q):
    omega = np.arange(0, 3501, 10)
    U_p = np.zeros_like(omega, dtype=float)
    for p_idx, p in enumerate(omega):
        for q_idx, q in enumerate(omega):
            U_p[p_idx] += (probability_of_loss(q - p)**W
                           * probability_of_loss(p - q)**L
                           * L_Q[q_idx] * L_P[p_idx])
    normalization_factor = np.sum(U_p)
    U_p /= normalization_factor
    return omega, U_p
omega, U_p = U_p_law(W, L, L_P, L_Q)
Since these nested loops are expensive to compute, I am wondering whether there is a better way to write this function.
The nested-loop approach gives me the result I want, but it takes a long time to run. I pass these values to the function:
P_mean = 1500
P_std = 100
Q_mean = 1500
Q_std = 100
W = 1 # Number of matches won by P
L = 0 # Number of matches lost by P
L_P = np.exp(-0.5 * ((np.arange(0, 3501, 10) - P_mean) / P_std)**2) / (P_std * np.sqrt(2 * np.pi))
L_Q = np.exp(-0.5 * ((np.arange(0, 3501, 10) - Q_mean) / Q_std)**2) / (Q_std * np.sqrt(2 * np.pi))
def probability_of_loss(x):
    return 1 / (1 + np.exp(x / 67))
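For context, probability_of_loss is a logistic curve in the rating gap, and the identity probability_of_loss(x) + probability_of_loss(-x) = 1 is what lets probability_of_loss(q - p) and probability_of_loss(p - q) act as complementary win/loss probabilities for P. A quick sanity check (my own addition, not part of the original code):

```python
import numpy as np

def probability_of_loss(x):
    # Logistic curve in the rating gap x, with scale 67
    return 1 / (1 + np.exp(x / 67))

# An even matchup (zero rating gap) gives a 50% loss probability.
print(probability_of_loss(0))  # → 0.5

# Opposite gaps are complementary: pol(x) + pol(-x) == 1, so
# probability_of_loss(q - p) plays the role of P's win probability
# and probability_of_loss(p - q) the role of P's loss probability.
print(probability_of_loss(100) + probability_of_loss(-100))  # ≈ 1.0
```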
Thanks in advance
You can speed up the function with numba. Here is an example plus a simple benchmark:
import numpy as np
from timeit import timeit
from numba import njit
P_mean = 1500
P_std = 100
Q_mean = 1500
Q_std = 100
W = 1 # Number of matches won by P
L = 0 # Number of matches lost by P
L_P = np.exp(-0.5 * ((np.arange(0, 3501, 10) - P_mean) / P_std) ** 2) / (
    P_std * np.sqrt(2 * np.pi)
)
L_Q = np.exp(-0.5 * ((np.arange(0, 3501, 10) - Q_mean) / Q_std) ** 2) / (
    Q_std * np.sqrt(2 * np.pi)
)
def probability_of_loss(x):
    return 1 / (1 + np.exp(x / 67))
def U_p_law(W, L, L_P, L_Q):
    omega = np.arange(0, 3501, 10)
    U_p = np.zeros_like(omega, dtype=float)
    for p_idx, p in enumerate(omega):
        for q_idx, q in enumerate(omega):
            U_p[p_idx] += (
                probability_of_loss(q - p) ** W
                * probability_of_loss(p - q) ** L
                * L_Q[q_idx]
                * L_P[p_idx]
            )
    normalization_factor = np.sum(U_p)
    U_p /= normalization_factor
    return omega, U_p
@njit
def probability_of_loss_numba(x):
    return 1 / (1 + np.exp(x / 67))

@njit
def U_p_law_numba(W, L, L_P, L_Q):
    omega = np.arange(0, 3501, 10, dtype=np.float64)
    U_p = np.zeros_like(omega)
    for p_idx, p in enumerate(omega):
        for q_idx, q in enumerate(omega):
            U_p[p_idx] += (
                probability_of_loss_numba(q - p) ** W
                * probability_of_loss_numba(p - q) ** L
                * L_Q[q_idx]
                * L_P[p_idx]
            )
    normalization_factor = np.sum(U_p)
    U_p /= normalization_factor
    return omega, U_p
omega_1, U_p_1 = U_p_law(W, L, L_P, L_Q)
omega_2, U_p_2 = U_p_law_numba(W, L, L_P, L_Q)
assert np.allclose(omega_1, omega_2)
assert np.allclose(U_p_1, U_p_2)
t1 = timeit("U_p_law(W, L, L_P, L_Q)", number=10, globals=globals())
t2 = timeit("U_p_law_numba(W, L, L_P, L_Q)", number=10, globals=globals())
print("10 calls using vanilla Python :", t1)
print("10 calls using Numba :", t2)
On my machine (AMD 5700X) this prints:
10 calls using vanilla Python : 2.46206750581041
10 calls using Numba : 0.014143474865704775
A speedup of roughly 170x.
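As an alternative to Numba, the double sum can also be rewritten as pure NumPy array operations: broadcasting builds the full (p, q) rating-gap matrix once, and the sum over q becomes a matrix-vector product against L_Q. A sketch (the name U_p_law_vectorized is mine; it should agree with the loop version to floating-point tolerance):

```python
import numpy as np

P_mean, P_std = 1500, 100
Q_mean, Q_std = 1500, 100
W, L = 1, 0  # Matches won / lost by P

omega = np.arange(0, 3501, 10)
L_P = np.exp(-0.5 * ((omega - P_mean) / P_std) ** 2) / (P_std * np.sqrt(2 * np.pi))
L_Q = np.exp(-0.5 * ((omega - Q_mean) / Q_std) ** 2) / (Q_std * np.sqrt(2 * np.pi))

def probability_of_loss(x):
    return 1 / (1 + np.exp(x / 67))

def U_p_law_vectorized(W, L, L_P, L_Q):
    omega = np.arange(0, 3501, 10)
    # D[p_idx, q_idx] = q - p, built in one shot via broadcasting
    D = omega[None, :] - omega[:, None]
    # Per-pair match weight; contracting over q is a matrix-vector product
    M = probability_of_loss(D) ** W * probability_of_loss(-D) ** L
    U_p = L_P * (M @ L_Q)
    U_p /= U_p.sum()
    return omega, U_p

omega_v, U_p_v = U_p_law_vectorized(W, L, L_P, L_Q)
```

This trades memory for speed: the intermediate matrix here is 351 × 351, which is tiny, but for much finer omega grids the O(n²) allocation could matter, whereas the Numba loop stays O(n) in memory.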