我目前正在编写代码,将大量粒子(总共 10k)投影到 MLS(移动最小二乘)表面上。此任务需要调用 lstsq 函数 10k 次,每个矩阵的形状类似于
(N,6)
(N 在 200 到 1k 之间,具体取决于粒子的位置)。
目前,我使用
torch.linalg.lstsq
进行计算,每次调用大约需要半秒。因此,我正在寻找一种更有效的方法来完成这项任务。如果有任何代码示例或推荐的库,那将非常有帮助。
我尝试过一些方法:
numpy.linalg.lstsq
带 for 循环。我花了大约1.3秒。
scipy.linalg.lstsq
带有 gelsy
LAPACK 驱动程序和 for 循环。我也花了大约1.3秒。
SVD 方法。我花了大约 1.5 秒,看起来像:
u, s, v = np.linalg.svd(A, full_matrices=False)
uTb = np.einsum('ijk,ij->ik', u, b)
c = np.einsum('ijk,ij->ik', v, uTb / s)
return c
使用
np.linalg.solve
来解决A.T@Ax=A.T@b
。我花了大约 1.5 秒,看起来像:
ATA = np.einsum('ijk,ijl->ikl', A, A)
ATb = np.einsum('ijk,ij->ik', A, b)
c = np.linalg.solve(ATA, ATb)
return c
多线程。由于我的代码的另一部分使用了 Taichi 库,当尝试使用多线程时,会初始化多个 Taichi 后端,导致我的计算机死机。
from multiprocessing import Pool
with Pool() as pool:
c = pool.starmap(calc_lstsq, [(A[i], b[i]) for i in range(b.shape[0])])
return np.asarray(c)
我推荐 JAX,使用你的正规方程示例,我获得了约 10 倍的加速:
import numpy as np
from jax import config
config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jax
rng = np.random.default_rng()
A = rng.random((10000, 1000, 6))
b = rng.random((10000, 1000))
def normal_np(A, b):
ATA = np.einsum('ijk,ijl->ikl', A, A)
ATb = np.einsum('ijk,ij->ik', A, b)
return np.linalg.solve(ATA, ATb)
@jax.jit
def normal_jax(A, b):
ATA = jnp.einsum('ijk,ijl->ikl', A, A)
ATb = jnp.einsum('ijk,ij->ik', A, b)
return jnp.linalg.solve(ATA, ATb)
时间:
assert np.allclose(normal_np(A, b), normal_jax(A, b))
%timeit normal_np(A, b)
%timeit normal_jax(A, b).block_until_ready()
输出:
882 ms ± 6.71 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
81.5 ms ± 294 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)