How can I speed up a convolution function?


I wrote a convolution function like this:

import numpy as np 
import numba as nb

# Generate sample input data
num_chans = 111
num_bins = 47998 
num_rad = 8
num_col = 1000

rng = np.random.default_rng()

wvl_sensor = rng.uniform(low=1000, high=11000, size=(num_chans, num_col))
fwhm_sensor = rng.uniform(low=0.01, high=2.0, size=num_chans)

wvl_lut = rng.uniform(low=1000, high=11000, size=num_bins) 
rad_lut = rng.uniform(low=0, high=1, size=(num_rad, num_bins))

# Original convolution implementation
def original_convolve(wvl_sensor, fwhm_sensor, wvl_lut, rad_lut):

    sigma = fwhm_sensor / (2.0 * np.sqrt(2.0 * np.log(2.0)))  
    var = sigma ** 2
    denom = (2 * np.pi * var) ** 0.5
    
    numer = np.exp(-(wvl_lut[:, None] - wvl_sensor[None, :])**2 / (2*var)) 
    response = numer / denom
    
    response /= response.sum(axis=0)
    resampled = np.dot(rad_lut, response)
    
    return resampled

The NumPy version takes about 45 seconds to run:

# numpy version
num_chans, num_col = wvl_sensor.shape
num_bins = wvl_lut.shape[0]
num_rad = rad_lut.shape[0]

original_res = np.empty((num_col, num_rad, num_chans), dtype=np.float64)

for x in range(wvl_sensor.shape[1]):
    original_res[x, :, :] = original_convolve(wvl_sensor[:, x], fwhm_sensor, wvl_lut, rad_lut)

I tried to speed it up with numba:

@nb.jit(nopython=True)
def numba_convolve(wvl_sensor, fwhm_sensor, wvl_lut, rad_lut):
    num_chans, num_col = wvl_sensor.shape
    num_bins = wvl_lut.shape[0]
    num_rad = rad_lut.shape[0]

    output = np.empty((num_col, num_rad, num_chans), dtype=np.float64)

    sigma = fwhm_sensor / (2.0 * np.sqrt(2.0 * np.log(2.0)))  
    var = sigma ** 2
    denom = (2 * np.pi * var) ** 0.5

    for x in nb.prange(num_col):
        numer = np.exp(-(wvl_lut[:, None] - wvl_sensor[None, :, x])**2 / (2*var))
        response = numer / denom

        response /= response.sum(axis=0)
        resampled = np.dot(rad_lut, response)
        output[x, :, :] = resampled

    return output

It still takes about 32 seconds. Note that if I use @nb.jit(nopython=True, parallel=True) instead, the output is all zeros.

Any ideas on how to apply numba correctly here, or how to improve the convolution function itself?

python arrays numpy matrix numba
1 answer

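This isn't really an answer, but I can't comment yet. I tried to reproduce your results, but running the numba-compiled function failed for me (I think because of the slicing with None).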
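If the None indexing is indeed what numba rejects (an assumption, not verified here), one possible workaround is to write the broadcast out as explicit loops, so that no None indexing is needed. The following is a minimal sketch of that idea in plain Python and NumPy on small, made-up sample data. The name loop_convolve is hypothetical, and it is not a tuned replacement for the question's function. The constant 1/sqrt(2*pi*var) is left out because it cancels when each response column is normalised by its sum.

```python
import numpy as np


def loop_convolve(wvl_sensor, fwhm_sensor, wvl_lut, rad_lut):
    """Explicit-loop form of the question's convolution, without None indexing."""
    num_chans, num_col = wvl_sensor.shape
    num_bins = wvl_lut.shape[0]
    num_rad = rad_lut.shape[0]

    sigma = fwhm_sensor / (2.0 * np.sqrt(2.0 * np.log(2.0)))
    two_var = 2.0 * sigma ** 2

    output = np.zeros((num_col, num_rad, num_chans), dtype=np.float64)
    for x in range(num_col):
        for c in range(num_chans):
            total = 0.0
            for b in range(num_bins):
                # Unnormalised Gaussian weight of LUT bin b for channel c.
                w = np.exp(-(wvl_lut[b] - wvl_sensor[c, x]) ** 2 / two_var[c])
                total += w
                for r in range(num_rad):
                    output[x, r, c] += w * rad_lut[r, b]
            # Normalise by the sum of weights (response.sum(axis=0) in the question).
            for r in range(num_rad):
                output[x, r, c] /= total
    return output


# Small deterministic sample data (sizes chosen only for illustration).
rng = np.random.default_rng(0)
wvl_sensor = rng.uniform(1000, 1010, size=(3, 4))
fwhm_sensor = rng.uniform(1.0, 2.0, size=3)
wvl_lut = np.linspace(1000, 1010, 50)
rad_lut = rng.uniform(0, 1, size=(2, 50))

result = loop_convolve(wvl_sensor, fwhm_sensor, wvl_lut, rad_lut)
print(result.shape)
```

Loops of this shape are the style numba is generally designed to compile, but whether this particular function compiles cleanly under nb.njit, and how fast it runs at the question's array sizes, has not been checked here. Note also that if every weight for a channel underflows to zero, total becomes zero and the final division is undefined, just as in the original formulation.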

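However, did you run the two functions only once, or several times? On the first run, the time spent on compilation can skew the result. If you haven't tried multiple iterations yet, try running the functions several times in a loop, or use something like timeit.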
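The effect of a one-off cost on the first call can be sketched with the standard-library timeit module. The example below does not use numba: make_lazy_function is a hypothetical stand-in whose first call does some expensive setup (playing the role of JIT compilation) and whose later calls are cheap. It only illustrates why the first call should be timed separately from the rest.

```python
import timeit


def make_lazy_function():
    """Hypothetical stand-in for a JIT-compiled function: slow first call, fast afterwards."""
    cache = {}

    def func(n):
        if "table" not in cache:
            # One-off setup cost, playing the role of compilation.
            cache["table"] = [i * i for i in range(200_000)]
        return cache["table"][n]

    return func


func = make_lazy_function()

# Time the very first call on its own: it includes the one-off cost.
first_call = timeit.timeit(lambda: func(10), number=1)

# Then time many repeated calls; the minimum over several repeats is less noisy.
steady_runs = timeit.repeat(lambda: func(10), number=1000, repeat=5)
steady_per_call = min(steady_runs) / 1000

print(f"first call: {first_call:.6f} s")
print(f"steady state: {steady_per_call:.9f} s per call")
```

Applied to the question, the same pattern would mean calling the numba function once as a warm-up and only then timing further calls, so that compilation is not counted in the reported runtime.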
