为什么dgemm(Cython编译)比numpy.dot慢

问题描述 投票:1回答:1

长话短说,我在Cython中构建了一个简单的乘法函数,调用scipy.linalg.cython_blas.dgemm,对其进行编译并针对基准Numpy.dot运行它。我听说过有关使用静态定义,数组维预分配,内存视图,关闭检查等技巧的性能提高50%至100倍的神话,但是后来我写了自己的my_dot函数(编译后),比默认Numpy.dot 慢4倍。我真的不知道是什么原因,所以我只能猜测:

1)BLAS库未链接

2)可能有一些内存开销没有被我捕获

3)dot使用了一些隐藏的魔法

4)编写错误的setup.py,并且c代码未得到最佳编译

5)我的my_dot功能没有被有效地编写

下面是我的代码段,我能想到的所有相关信息都可以帮助解决这个难题。如果有人能提供关于我做错了什么或如何将性能提高到至少与默认设置Numpy.dot相提并论的信息,我将不胜感激。

文件1:model_cython/multi.pyx。您还将在文件夹中也需要model_cython/init.py

#cython: language_level=3 
#cython: boundscheck=False
#cython: nonecheck=False
#cython: wraparound=False
#cython: infertypes=True
#cython: initializedcheck=False
#cython: cdivision=True
#distutils: extra_compile_args = -Wno-unused-function -Wno-unneeded-internal-declaration


from scipy.linalg.cython_blas cimport dgemm
import numpy as np
from numpy cimport ndarray, float64_t
from numpy cimport PyArray_ZEROS
cimport numpy as np
cimport cython

np.import_array()
ctypedef float64_t DOUBLE

def my_dot(double [::1, :] a, double [::1, :] b, int ashape0, int ashape1, 
        int bshape0, int bshape1):
    cdef np.npy_intp cshape[2]
    cshape[0] = <np.npy_intp> ashape0
    cshape[1] = <np.npy_intp> bshape1

    cdef:
        int FORTRAN = 1
        ndarray[DOUBLE, ndim=2] c = PyArray_ZEROS(2, cshape, np.NPY_DOUBLE, FORTRAN)

    cdef double alpha = 1.0
    cdef double beta = 0.0
    dgemm("N", "N", &ashape0, &bshape1, &ashape1, &alpha, &a[0,0], &ashape0, &b[0,0], &bshape0, &beta, &c[0,0], &ashape0)
    return c

文件2:model_cython/example.py。执行基准测试的脚本

setup_str = """
import numpy as np
from numpy import float64
from multi import my_dot

a = np.ones((2,3), dtype=float64, order='F')
b = np.ones((3,2), dtype=float64, order='F')
print(a.flags)
ashape0, ashape1 = a.shape
bshape0, bshape1 = b.shape
"""
import timeit
print(timeit.timeit(stmt='c=my_dot(a,b, ashape0, ashape1, bshape0, bshape1)', setup=setup_str, number=100000))
print(timeit.timeit(stmt='c=a.dot(b)', setup=setup_str, number=100000))

文件3:setup.py。编译.so文件

from distutils.core import setup, Extension
from Cython.Build import cythonize
from Cython.Distutils import build_ext
import numpy 
import os
basepath = os.path.dirname(os.path.realpath(__file__))
numpy_path = numpy.get_include()
package_name = 'multi'
setup(
        name='multi',
        cmdclass={'build_ext': build_ext},
        ext_modules=[Extension(package_name, 
            [os.path.join(basepath, 'model_cython', 'multi.pyx')], 
            include_dirs=[numpy_path],
            )],
        )

文件4:run.sh。执行setup.py并移动内容的Shell脚本

python3 setup.py build_ext --inplace
path=$(pwd)
rm -r build
mv $path/multi.cpython-37m-darwin.so $path/model_cython/
rm $path/model_cython/multi.c

下面是编译消息的屏幕截图:

clang compilation

关于BLAS,我的Numpy已正确链接到/usr/local/lib,并且clang -bundle似乎也在编译中添加了-L/usr/local/lib。但这还不够吗?

长话短说,我在Cython中构建了一个简单的乘法函数,调用scipy.linalg.cython_blas.dgemm,对其进行编译并针对基准Numpy.dot运行它。我听过关于...

python c numpy cython matrix-multiplication
1个回答
0
投票

Cython擅长优化循环(在Python中通常很慢),它也是调用C的便捷方法(这是您要执行的操作)。但是,从Python调用Cython函数可能相对较慢-尤其是因为需要检查您指定的所有类型的一致性。因此,您通常会尝试在一个Cython调用之后隐藏大量工作,以减少开销。

© www.soinside.com 2019 - 2024. All rights reserved.