我有一段包含两层嵌套 for 循环的 Cython 代码,想用多线程运行它,看看能否运行得更快。这是我的代码:
import cython
import numpy as np
from cython.parallel import prange
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef int test_speed(unsigned char[:, :] Rint, unsigned char[:, :] Gint,
                     unsigned char[:, :] Bint, unsigned char[:, :] table) except -1:
    """Map every R, G and B value through row 0 of ``table`` (serial loop).

    Each output pixel is ``table[0, value]`` for the matching input pixel.
    The three output arrays are local and discarded, so the function acts
    purely as a timing kernel and always returns 1.

    Rint, Gint and Bint must share the same 2-D shape; ``table`` needs at
    least one row and 256 columns so every uint8 value is a valid column.
    Because bounds checking is disabled below, these requirements are checked
    up front and ValueError is raised instead of reading out of bounds.
    """
    cdef Py_ssize_t n_rows = Rint.shape[0]
    cdef Py_ssize_t n_cols = Rint.shape[1]
    cdef Py_ssize_t i, j
    cdef unsigned char[:, :] R_ch_new
    cdef unsigned char[:, :] G_ch_new
    cdef unsigned char[:, :] B_ch_new

    # boundscheck/wraparound are off, so a shape mismatch would silently
    # touch memory outside the buffers; reject it before the hot loop.
    if (Gint.shape[0] != n_rows or Gint.shape[1] != n_cols
            or Bint.shape[0] != n_rows or Bint.shape[1] != n_cols):
        raise ValueError("Rint, Gint and Bint must have the same shape")
    if table.shape[0] < 1 or table.shape[1] < 256:
        raise ValueError("table must have at least 1 row and 256 columns")

    R_ch_new = np.zeros((n_rows, n_cols), dtype="uint8")
    G_ch_new = np.zeros((n_rows, n_cols), dtype="uint8")
    B_ch_new = np.zeros((n_rows, n_cols), dtype="uint8")

    for i in range(n_rows):
        for j in range(n_cols):
            R_ch_new[i, j] = table[0, Rint[i, j]]
            G_ch_new[i, j] = table[0, Gint[i, j]]
            B_ch_new[i, j] = table[0, Bint[i, j]]
    return 1
我把外层循环改成了 prange(见下面的代码),但没有观察到任何加速。我还查看了进程的线程:只有 1 个线程在忙,CPU 占用约 12–14%。这是正常现象,还是我哪里做错了?
import cython
import numpy as np
from cython.parallel import prange
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef int test_speed(unsigned char[:, :] Rint, unsigned char[:, :] Gint,
                     unsigned char[:, :] Bint, unsigned char[:, :] table) except -1:
    """Map every R, G and B value through row 0 of ``table`` (prange rows).

    Each output pixel is ``table[0, value]`` for the matching input pixel.
    The three output arrays are local and discarded, so the function acts
    purely as a timing kernel and always returns 1.

    Rint, Gint and Bint must share the same 2-D shape; ``table`` needs at
    least one row and 256 columns so every uint8 value is a valid column.
    Because bounds checking is disabled below, these requirements are checked
    up front and ValueError is raised instead of reading out of bounds.

    NOTE: prange only runs multi-threaded when this extension is compiled
    and linked with the compiler's OpenMP flag; otherwise the loop executes
    serially on one thread.
    """
    cdef Py_ssize_t n_rows = Rint.shape[0]
    cdef Py_ssize_t n_cols = Rint.shape[1]
    cdef Py_ssize_t i, j
    cdef unsigned char[:, :] R_ch_new
    cdef unsigned char[:, :] G_ch_new
    cdef unsigned char[:, :] B_ch_new

    # boundscheck/wraparound are off, so a shape mismatch would silently
    # touch memory outside the buffers; reject it before the parallel loop
    # (no Python exceptions can be raised inside the nogil region).
    if (Gint.shape[0] != n_rows or Gint.shape[1] != n_cols
            or Bint.shape[0] != n_rows or Bint.shape[1] != n_cols):
        raise ValueError("Rint, Gint and Bint must have the same shape")
    if table.shape[0] < 1 or table.shape[1] < 256:
        raise ValueError("table must have at least 1 row and 256 columns")

    R_ch_new = np.zeros((n_rows, n_cols), dtype='uint8')
    G_ch_new = np.zeros((n_rows, n_cols), dtype='uint8')
    B_ch_new = np.zeros((n_rows, n_cols), dtype='uint8')

    # Rows are split across threads; each row is written by exactly one
    # thread, and j (assigned in the body) is thread-private under prange.
    for i in prange(n_rows, nogil=True):
        for j in range(n_cols):
            R_ch_new[i, j] = table[0, Rint[i, j]]
            G_ch_new[i, j] = table[0, Gint[i, j]]
            B_ch_new[i, j] = table[0, Bint[i, j]]
    return 1
我是这样测试速度的:
import time

import numpy as np


def make_inputs(shape=(1024, 1024), seed=None):
    """Return random uint8 ``(Rint, Gint, Bint, table)`` arrays for the benchmark.

    Rint, Gint and Bint have ``shape``; table has shape ``(1, 256)``.
    Values span the full uint8 range 0..255 (``integers`` excludes ``high``,
    hence 256). Generation is vectorised instead of one RNG call per element.
    """
    rng = np.random.default_rng(seed)
    Rint = rng.integers(0, 256, size=shape, dtype=np.uint8)
    Gint = rng.integers(0, 256, size=shape, dtype=np.uint8)
    Bint = rng.integers(0, 256, size=shape, dtype=np.uint8)
    table = rng.integers(0, 256, size=(1, 256), dtype=np.uint8)
    return Rint, Gint, Bint, table


def main(repeats=5):
    """Time ``testspeed1.test_speed`` and print the best of ``repeats`` runs."""
    # Imported here so the compiled extension is only required when the
    # benchmark actually runs.
    from testspeed1 import test_speed

    Rint, Gint, Bint, table = make_inputs()
    # Warm-up call so one-off start-up costs are not counted in the timing.
    test_speed(Rint, Gint, Bint, table)

    best = float("inf")
    for _ in range(repeats):
        # perf_counter is a monotonic, high-resolution clock meant for timing.
        start = time.perf_counter()
        test_speed(Rint, Gint, Bint, table)
        end = time.perf_counter()
        best = min(best, end - start)
    print('duration : ', str(best))


if __name__ == "__main__":
    main()
**更新:** 我使用的是下面这个 setup.py:
from distutils.core import setup
import numpy
from Cython.Build import cythonize
from distutils.extension import Extension
from Cython.Distutils import build_ext
import sys

# prange in testspeed1.pyx only runs in parallel when the extension is built
# with OpenMP. GCC/Clang need -fopenmp at compile AND link time (the original
# passed it only to the compiler); MSVC uses /openmp and no linker flag.
if sys.platform.startswith("win"):
    compile_args = ["/openmp", "/O2"]
    link_args = []
else:
    compile_args = ["-fopenmp", "-O3"]
    link_args = ["-fopenmp"]

ext_modules = [
    Extension(
        "testspeed1",
        ["testspeed1.pyx"],
        extra_compile_args=compile_args,
        extra_link_args=link_args,
        include_dirs=[numpy.get_include()],
    )
]
setup(
    name="testspeed1",
    ext_modules=cythonize(ext_modules),
    include_dirs=[numpy.get_include()],
)
并使用以下命令编译:
# Compile the Cython extension declared in setup.py; --inplace puts the built
# module next to the sources so `from testspeed1 import test_speed` can find it.
python setup.py build_ext --inplace
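一般来说,除了编译参数之外,还需要通过 extra_link_args 把 OpenMP 标志传给链接器(上面的 setup.py 只设置了 extra_compile_args),并在 Windows 上使用 MSVC 对应的参数。我通常会在 setup.py 中加入如下内容: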
import sys

# Choose OpenMP / optimisation flags per compiler family.
if sys.platform.startswith("win"):
    # MSVC: OpenMP is enabled at compile time; no linker flag is passed.
    openmp_arg = '-openmp'
    opt_compiler = '/O2'
    openmp_link_args = []
else:
    # GCC/Clang: -fopenmp must also reach the linker, otherwise the OpenMP
    # runtime is not linked in and prange falls back to serial execution.
    openmp_arg = '-fopenmp'
    opt_compiler = '-O3'
    openmp_link_args = [openmp_arg]

ext_modules = [Extension("name",
                         ["source to pyx"],
                         extra_compile_args=[openmp_arg, opt_compiler],
                         extra_link_args=openmp_link_args)]
干杯