考虑以下代码块
N = 28 * 28
X = rng.randn(10000, N)
n_groups = group_size = 28
Q = X[:10]
Z = X[:, None] * Q[None] # line 4: multiply every row of Q by every row of X
Z = Z.reshape((len(X), len(Q), n_groups, group_size)).mean(axis=3)
问题。 如何重新实现上述代码片段以对 Z 输出相同的结果,但无需在第 4 行执行昂贵的(内存方面等)操作?
我希望这可以通过某种本机张量或多维点积实现。
提前致谢。
在 Numpy 中执行此操作的快速标准方法是使用
einsum
:
X2 = X.reshape(len(X), n_groups, group_size)
Q2 = Q.reshape(len(Q), n_groups, group_size)
Z = np.einsum('ikl,jkl->ijk', X2, Q2, optimize=True) / group_size
这不仅明显更快,而且内存效率更高,因为无需创建“临时数组”。请注意,在这种情况下 einsum
并不是最佳选择,因为最后一个维度相当小并且是顺序执行的。如果速度不够快,可以编写优化的并行 Numba/Cython 代码以获得更好的性能。
/ group_size
可以应用于
Q2
而不是 np.einsum
的结果,以获得更好的性能(因为 Q2
更小,这在数学上是等效的)。
基准Initial implementation: 167 ms
Naive einsum: 53 ms
Optimized einsum: 49 ms
einsum
是一个高度广义的矩阵乘积函数,而您所拥有的只是最后两个维度上的标准矩阵乘积。
尝试:import time
import numpy as np
from numpy.random import default_rng
n_groups = group_size = 28
N = n_groups * group_size
rng = default_rng(seed=0)
X = rng.random((10_000, N))
Q = X[:10, :]
a = time.perf_counter()
xx, qq = np.broadcast_arrays(
X.reshape((len(X), 1, -1)),
Q.reshape((1, len(Q), -1)),
)
Z = (
xx.reshape((len(X), len(Q), n_groups, 1, group_size)) @
qq.reshape((len(X), len(Q), n_groups, group_size, 1))
)[...,0,0] / n_groups
b = time.perf_counter()
print(b - a)
与
einsum
的 167->49 ~ 3.4 相比,这在我的机器上产生了 296ms->80ms ~ 3.7 的加速。