Optimizing this Python function: applying a linear transform based on index parity


I have this simple Python function:

import numpy as np

def fast_transform(img, offset, factor):
    rep = (img.shape[0]//2, img.shape[1]//2)
    out = (img.astype(np.float32) - np.tile(offset, rep)) * np.tile(factor, rep)
    return out

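The function takes an image (as an N×M numpy ndarray) and two 2x2 arrays (offset and factor). It then applies a basic linear transform to each pixel of the image, chosen by the parity of the pixel's index along each dimension: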

out[i,j] = (img[i,j] - offset[i%2,j%2]) * factor[i%2,j%2]
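As a quick sanity check of that rule, here is a minimal sketch that compares one pixel computed by hand against the tiled function from the question. The image, offset and factor values are made up for illustration and are not part of the original question.

```python
import numpy as np

def fast_transform(img, offset, factor):
    # Same tiling approach as in the question (needs even N and M).
    rep = (img.shape[0] // 2, img.shape[1] // 2)
    return (img.astype(np.float32) - np.tile(offset, rep)) * np.tile(factor, rep)

# Illustrative values, chosen for this sketch (not from the question).
img = np.arange(16, dtype=np.uint8).reshape(4, 4)
offset = np.array([[1, 2], [3, 4]], dtype=np.float32)
factor = np.array([[1, 0.5], [2, 4]], dtype=np.float32)

out = fast_transform(img, offset, factor)

# Pixel (3, 2): odd row, even column, so offset[1, 0] and factor[1, 0] apply.
i, j = 3, 2
by_hand = (float(img[i, j]) - offset[i % 2, j % 2]) * factor[i % 2, j % 2]
print(out[i, j], by_hand)  # both values are 22.0: (14 - 3) * 2
```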

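As you can see, I use np.tile to try to speed the function up, but it is still not fast enough for my needs (and I suspect that creating the temporary np.tile arrays makes it suboptimal). I tried using numba, but it does not support np.tile yet.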

Can you help me optimize this function as much as possible? I am sure I am missing some simple way to do this.

python numpy optimization numba
1 Answer

A naive approach using numba, as you tried:

import numpy as np
from numba import njit, prange


@njit
def fast_transform_numba(img, offset, factor):
    out = np.empty(img.shape, dtype=np.float32)
    for i in range(img.shape[0]):
        for j in range(img.shape[1]):
            out[i, j] = (img[i, j] - offset[i % 2, j % 2]) * factor[i % 2, j % 2]
    return out

Looking at the for loops, they can easily be parallelized:

@njit(parallel=True)
def fast_transform_numba_parallel(img, offset, factor):
    out = np.empty(img.shape, dtype=np.float32)
    for i in prange(img.shape[0]):
        for j in range(img.shape[1]):  # numba parallelizes only the outermost prange
            out[i, j] = (img[i, j] - offset[i % 2, j % 2]) * factor[i % 2, j % 2]
    return out

Benchmark:

import perfplot
from matplotlib import pyplot as plt

plt.rcParams["figure.autolayout"] = True
np.random.seed(0)

def fast_transform(img, offset, factor):
    rep = (img.shape[0] // 2, img.shape[1] // 2)
    out = (img.astype(np.float32) - np.tile(offset, rep)) * np.tile(factor, rep)
    return out

img = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 16, 17]])
offset = np.array([[2, 1], [2, 1]])
factor = np.array([[2, 2], [2, 2]])

x = fast_transform(img, offset, factor)
print(x)

# this call also compiles the function (for these argument types):
x = fast_transform_numba(img, offset, factor)
print(x)

# this call also compiles the function (for these argument types):
x = fast_transform_numba_parallel(img, offset, factor)
print(x)

perfplot.show(
    setup=lambda n: np.random.randint(0, 255, (n, n), dtype=np.uint8),
    kernels=[
        lambda img: fast_transform(img, offset, factor),
        lambda img: fast_transform_numba(img, offset, factor),
        lambda img: fast_transform_numba_parallel(img, offset, factor),
    ],
    labels=["fast_transform", "fast_transform_numba", "fast_transform_numba_parallel"],
    n_range=[2**k for k in range(2, 15)],
    xlabel="img(N * N)",
    logx=True,
    logy=True,
)

This produces a log-log plot of runtime against image size N for the three functions.

The plot shows that, beyond a certain image size, the parallel version appears to be the fastest.
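If numba is not an option, the temporary np.tile arrays mentioned in the question can also be avoided in plain NumPy by handling the four parity classes separately with strided slices. This is a different technique from the numba versions above and was not part of the benchmark, so whether it is faster on your data would need to be measured. The function name transform_by_parity is hypothetical, used only for this sketch.

```python
import numpy as np

def transform_by_parity(img, offset, factor):
    # Hypothetical helper for this sketch: process each (row parity, column
    # parity) class through a strided view, so no tiled arrays are built.
    out = np.empty(img.shape, dtype=np.float32)
    for di in range(2):
        for dj in range(2):
            block = out[di::2, dj::2]          # view into out, not a copy
            block[...] = img[di::2, dj::2]     # cast to float32 on assignment
            block -= np.float32(offset[di, dj])
            block *= np.float32(factor[di, dj])
    return out

# Illustrative inputs (assumed for this sketch, not from the benchmark).
rng = np.random.default_rng(0)
img = rng.integers(0, 255, (6, 8), dtype=np.uint8)
offset = np.array([[1, 2], [3, 4]], dtype=np.float32)
factor = np.array([[1, 0.5], [2, 4]], dtype=np.float32)

# Compare against the tiled version from the question.
rep = (img.shape[0] // 2, img.shape[1] // 2)
tiled = (img.astype(np.float32) - np.tile(offset, rep)) * np.tile(factor, rep)
print(np.allclose(transform_by_parity(img, offset, factor), tiled))
```

Because the slices simply stop at the array boundary, this variant also accepts images with an odd number of rows or columns, which the tiling with shape // 2 does not.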
