加快scipy.sparse.csr_matrix的左matmul

问题描述 投票:0回答:1

我需要执行以下矩阵乘法:x * A[idx]其中Ascipy.sparse.csr_matrix,而idxnp.array索引。由于建立索引,我无法将其更改为csc_matrix。即使A[idx] * u,它似乎也比右矩阵乘法len(idx) << A.shape[1]慢至少50倍。如何加快速度?

python scipy sparse-matrix
1个回答
0
投票

由于我没有在其他地方找到解决方案,所以这就是我最终要做的。

使用numba,我写道:

from numba import njit
@njit
def csr_left_mul(x, data, indptr, indices, d, idx):
    """
    :param x: numpy.ndarray
    :param A: scipy.sparse.csr_matrix
    :return: x.dot(A)
    """
    res = np.zeros(d)
    assert x.shape[0] == len(idx)
    for k, i in np.ndenumerate(idx):
        for j in range(indptr[i], indptr[i+1]):
            j_idx = indices[j]
            res[j_idx] += x[k] * data[j]
    return res
© www.soinside.com 2019 - 2024. All rights reserved.