是否可以使用Numpy实现此版本的矩阵乘法?

问题描述 投票:3回答:2

我希望快速评估下面的函数,该函数在很大程度上类似于矩阵乘法。对于大型矩阵,以下实现比矩阵的numpy乘法慢几个数量级,这使我相信有更好的方法使用numpy来实现。有什么方法可以使用numpy函数而不是for循环来实现吗?我正在使用的矩阵在每个维度上的范围都在10K-100K范围内,因此此优化非常必要。

一种方法是创建3D数组,但事实证明该数组太大,无法存储我正在使用的矩阵。我也调查了似乎不合适的np.vectorize。

非常感谢您的指导。

编辑:谢谢大家的出色见解和解答。非常感谢您的投入。将日志移出循环可以大大改善运行时间,并且有趣的是,保存k查找非常重要。如果可以的话,我将进行跟进:如果内部循环表达式变为C[i,j] += A[i,k] * np.log(A[i,k] + B[k,j]),您能看到加速的方法吗?可以像以前一样将日志移到外部,但是只有对A[i,k]取幂时,这是昂贵的,并且消除了从移出日志中获得的收益。

import numpy as np
import numba
from numba import njit, prange

@numba.jit(fastmath=True, parallel=True)
def f(A, B):
    
    C = np.zeros((A.shape[0], B.shape[1]))

    for i in prange(A.shape[0]):
        for j in prange(B.shape[1]):
            for k in prange(A.shape[1]):
                
                C[i,j] += np.log(A[i,k] + B[k,j])
                #matrix mult. would be: C[i,j] += A[i,k] * B[k,j]

    return C
python arrays numpy matrix-multiplication numba
2个回答
2
投票

log(a) + log(b) == log(a * b)起,您可以通过用乘法替换加法并仅在最后进行对数来节省大量对数计算,这将节省大量时间。

import numpy as np
import numba as nb

@nb.njit(fastmath=True, parallel=True)
def f(A, B):
    C = np.ones((A.shape[0], B.shape[1]), A.dtype)
    for i in nb.prange(A.shape[0]):
        for j in nb.prange(B.shape[1]):
            # Accumulate product
            for k in nb.prange(A.shape[1]):
                C[i,j] *= (A[i,k] + B[k,j])
    # Apply logarithm at the end
    return np.log(C)

# For comparison
@nb.njit(fastmath=True, parallel=True)
def f_orig(A, B):
    C = np.zeros((A.shape[0], B.shape[1]), A.dtype)
    for i in nb.prange(A.shape[0]):
        for j in nb.prange(B.shape[1]):
            for k in nb.prange(A.shape[1]):
                C[i,j] += np.log(A[i,k] + B[k,j])
    return C

# Test
np.random.seed(0)
a, b = np.random.random((1000, 100)), np.random.random((100, 2000))
print(np.allclose(f(a, b), f_orig(a, b)))
# True
%timeit f(a, b)
# 36.2 ms ± 2.91 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit f_orig(a, b)
# 296 ms ± 3.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

2
投票

正如@jdehesa所指出的,您可以使用以下简化:log(a) + log(b) == log(a * b)但是请注意,结果可能会有很大的不同。另外,有许多方法可以优化此功能。最佳解决方案取决于输入矩阵的大小和所需的数值稳定性。

使用标量并在转置数组上工作(可能会进行自动SIMD矢量化)

import numpy as np

#from version 0.43 until 0.47 this has to be set before importing numba
#Bug: https://github.com/numba/numba/issues/4689
from llvmlite import binding
binding.set_option('SVML', '-vector-library=SVML')
import numba as nb

@nb.njit(fastmath=True,parallel=True)
def f_orig(A, B):
    C = np.zeros((A.shape[0], B.shape[1]))

    for i in nb.prange(A.shape[0]):
        for j in range(B.shape[1]):
            for k in range(A.shape[1]):
                C[i,j] += np.log(A[i,k] + B[k,j])
                #matrix mult. would be: C[i,j] += A[i,k] * B[k,j]

    return C

@nb.njit(fastmath=True,parallel=True)
def f_pre_opt(A, B):
    C = np.empty((A.shape[0], B.shape[1]))
    B_T=np.ascontiguousarray(B.T)

    for i in nb.prange(A.shape[0]):
        for j in range(B_T.shape[0]):
            acc=1.
            for k in range(A.shape[1]):
                acc*=(A[i,k] + B_T[j,k])
            C[i,j] = np.log(acc)

    return C

@nb.njit(fastmath=True, parallel=True)
def f_jdehesa(A, B):
    C = np.ones((A.shape[0], B.shape[1]), A.dtype)
    for i in nb.prange(A.shape[0]):
        for j in nb.prange(B.shape[1]):
            # Accumulate product
            for k in nb.prange(A.shape[1]):
                C[i,j] *= (A[i,k] + B[k,j])
    # Apply logarithm at the end
    return np.log(C)

Timings

# Test
np.random.seed(0)
a, b = np.random.random((1000, 100)), np.random.random((100, 2000))

res_1=f_orig(a, b)
res_2=f_pre_opt(a, b)
res_3=f_jdehesa(a, b)

# True
%timeit f_orig(a, b)
#262 ms ± 1.76 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit f_pre_opt(a, b)
#12.4 ms ± 405 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit f_jdehesa(a, b)
#41 ms ± 2.24 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

对于较大的矩阵,此解决方案远非最佳。为了更好地使用缓存,还需要进行其他优化,例如逐块计算结果。

Real world implementation of a matrix-matrix multiplication

© www.soinside.com 2019 - 2024. All rights reserved.