如何在Python中高效计算lstsq 10k次?

问题描述 投票:0回答:1

我目前正在编写代码,将大量粒子(总共 10k)投影到 MLS(移动最小二乘)表面上。此任务需要调用 lstsq 函数 10k 次,每个矩阵的形状类似于

(N,6)
(N 在 200 到 1k 之间,具体取决于粒子的位置)。

目前,我使用

torch.linalg.lstsq
进行计算,每次调用大约需要半秒。因此,我正在寻找一种更有效的方法来完成这项任务。如果有任何代码示例或推荐的库,那将非常有帮助。

我尝试过一些方法:

  1. numpy.linalg.lstsq
    带 for 循环。我花了大约1.3秒。

  2. scipy.linalg.lstsq
    带有
    gelsy
    LAPACK 驱动程序和 for 循环。我也花了大约1.3秒。

  3. SVD 方法。我花了大约 1.5 秒,看起来像:

    u, s, v = np.linalg.svd(A, full_matrices=False)
    uTb = np.einsum('ijk,ij->ik', u, b)
    c = np.einsum('ijk,ij->ik', v, uTb / s)
    return c
    
  4. 使用

    np.linalg.solve
    来解决
    A.T@Ax=A.T@b
    。我花了大约 1.5 秒,看起来像:

    ATA = np.einsum('ijk,ijl->ikl', A, A)
    ATb = np.einsum('ijk,ij->ik', A, b)
    c = np.linalg.solve(ATA, ATb)
    return c
    
  5. 多线程。由于我的代码的另一部分使用了 Taichi 库,当尝试使用多线程时,会初始化多个 Taichi 后端,导致我的计算机死机。

    from multiprocessing import Pool
    with Pool() as pool:
          c = pool.starmap(calc_lstsq, [(A[i], b[i]) for i in range(b.shape[0])])
    return np.asarray(c)
    
python numpy pytorch linear-regression linear-algebra
1个回答
0
投票

我推荐 JAX,使用你的正规方程示例,我获得了约 10 倍的加速:

import numpy as np
from jax import config
config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jax

rng = np.random.default_rng()
A = rng.random((10000, 1000, 6))
b = rng.random((10000, 1000))

def normal_np(A, b):
    ATA = np.einsum('ijk,ijl->ikl', A, A)
    ATb = np.einsum('ijk,ij->ik', A, b)
    return np.linalg.solve(ATA, ATb)

@jax.jit
def normal_jax(A, b):
    ATA = jnp.einsum('ijk,ijl->ikl', A, A)
    ATb = jnp.einsum('ijk,ij->ik', A, b)
    return jnp.linalg.solve(ATA, ATb)

时间:

assert np.allclose(normal_np(A, b), normal_jax(A, b))
%timeit normal_np(A, b)
%timeit normal_jax(A, b).block_until_ready()

输出:

882 ms ± 6.71 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
81.5 ms ± 294 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
© www.soinside.com 2019 - 2024. All rights reserved.