有什么方法可以加速我的稀疏矩阵乘法的 NUMBA 实现吗?

问题描述 投票:0回答:1

NUMBA 不支持稀疏矩阵,我想找到一种方法以 COO 格式编写自己的 spM 乘法函数。

import numpy as np
from numba import njit, prange
from numba.core.types import ListType, int32, DictType, float64
from numba.typed import Dict, List
from scipy.sparse import coo_matrix

LIST_TYPE_TRIP = ListType(int32)
DICT_TYPE_TRIP = DictType(keyty=int32, valty=float64)
TOLERANCE = 1e-6

@njit
def triplet2lookup(bc: np.ndarray, m: int) -> tuple[np.ndarray, np.ndarray]:
    """
    :param bc: column index of sparse matrix b
    :param m: num of cols in sparse matrix b
    :return:
    """
    table = np.zeros((m, m), dtype=np.int32)
    count = np.zeros(m, dtype=np.int32)
    for i in range(len(bc)):
        val = bc[i]
        table[val, count[val]] = i
        count[val] += 1
    return table, count


@njit(fastmath=True)
def _mul(ar, ac, av, br, bc, bv, an, bm):
    na = len(av)
    rr = np.empty(len(ar) * 2, dtype=np.int32)
    rc = np.empty(len(rr), dtype=np.int32)
    rv = np.empty(len(rr))
    table, count = triplet2lookup(bc, bm)
    cnt = 0
    hash_mat = np.zeros((an, bm))
    for i in range(na):
        target = ac[i]
        for j in range(count[target]):
            row_idx = ar[i]
            col_idx = br[table[target, j]]
            rr[cnt] = row_idx
            rc[cnt] = col_idx
            rv[cnt] = av[i] * bv[table[target, j]]
            cnt += 1
            if hash_mat[row_idx, col_idx] < TOLERANCE:
                rr[cnt] = row_idx
                rc[cnt] = col_idx
                rv[cnt] = av[i] * bv[table[target, j]]
                cnt += 1
            else:
                rv[cnt] += av[i] * bv[table[target, j]]
            if cnt >= len(rr):
                rr, rc, rv = extend_arr(rr, rc, rv)

    rr = rr[:cnt]
    rc = rc[:cnt]
    rv = rv[:cnt]
    return rr, rc, rv


@njit
def _extend_arr(rr, dtype):
    tmp_rr = np.empty(len(rr) * 2, dtype=dtype)
    tmp_rr[:len(rr)] = rr
    return tmp_rr


@njit
def extend_arr(rr, rc, rv):
    rr = _extend_arr(rr, np.int32)
    rc = _extend_arr(rc, np.int32)
    rv = _extend_arr(rv, np.float64)
    return rr, rc, rv

与 scipy 进行速度比较后,我发现 scipy 对于大矩阵要快得多。

n = 50000
m = 1000
row_a = np.random.randint(0, m, n)
col_a = np.random.randint(0, m, n)
val_a = np.random.random(n)

row_b = np.random.randint(0, m, n)
col_b = np.random.randint(0, m, n)
val_b = np.random.random(n)

a_sci = coo_matrix((val_a, (row_a, col_a)), shape=(m, m))
b_sci = coo_matrix((val_b, (row_b, col_b)), shape=(m, m))
%timeit scipy_res = a_sci @ b_sci
# 21.4 ms ± 722 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit rr, rc, rv = _mul(row_a, col_a, val_a, col_b, row_b, val_b, m, m)
202 ms ± 47 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

在我的机器上,scipy 几乎快了 10 倍。所以我想知道是否有像SIMD或并行计算之类的东西可以用来加速_mul函数?

我尝试通过将 @njit 更改为 @njit(parallel=True) 并将范围 -> prange 更改为内部或外部循环来并行化 _mul 函数,但这会减慢算法速度。

python matrix scipy sparse-matrix numba
1个回答
0
投票

我想出了一个相当简单的方法来做到这一点,它只比 SciPy 慢 15%,并且仍然允许您使用 nopython 模式。基本思想是调用 SciPy 的矩阵乘法。

如果您以最直接的方式(使用

@jit(forceobj=True)
)执行此操作,则会出现问题:任何调用
_mul()
的函数都无法确定其返回类型,因此无法在 nopython 模式下运行。任何调用该函数的函数都不能在 nopython 模式下运行,等等

我尝试通过用签名指定返回类型来解决这个问题,但我不知道如何获得

njit()
函数来使用它。最终,我找到了
nb.objmode()
,它可以指定允许在更广泛的 Numba 函数中运行 Python 函数的代码块。调用此函数的函数仍然可以用
njit()
标记。

import numba as nb

@nb.njit()
def mul2(ar, ac, av, br, bc, bv, n):
    with nb.objmode(row='i4[:]', col='i4[:]', data='f8[:]'):
        a_sci = coo_matrix((av, (ar, ac)), shape=(n, n))
        b_sci = coo_matrix((bv, (br, bc)), shape=(n, n))
        result = (a_sci @ b_sci).tocoo()
        row = result.row
        col = result.col
        data = result.data
    return row, col, data

警告:objmode 的文档表示,通常有比使用 objmode 更好的方法来实现您的目标。可能有更好的方法来做到这一点 - 我在 Numba 方面没有太多经验。

为什么这比直接调用 SciPy 慢?问题是您想要 COO 格式的结果,但矩阵乘法产生 CSR 格式的结果。转换需要额外的时间。

我还尝试了一个在 CSR 中输出的选项。

@nb.njit()
def mul3(ar, ac, av, br, bc, bv, n):
    with nb.objmode(data='f8[:]', indptr='i4[:]', indices='i4[:]'):
        a_sci = coo_matrix((av, (ar, ac)), shape=(n, n))
        b_sci = coo_matrix((bv, (br, bc)), shape=(n, n))
        result = a_sci @ b_sci
        data = result.data
        indptr = result.indptr
        indices = result.indices
    return data, indices, indptr

此版本仅比直接 SciPy 选项慢 2%。

这些函数受益于 SciPy 中的优化算法,同时能够从使用 nopython 模式执行所有其他操作的 Numba 函数调用。

时间:

Pure SciPy
27.8 ms ± 471 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Original Numba
184 ms ± 594 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

Numba SciPy COO output
32 ms ± 228 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Numba SciPy CSR output
28.3 ms ± 685 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
© www.soinside.com 2019 - 2024. All rights reserved.