I have this simple Python function:
```python
import numpy as np

def fast_transform(img, offset, factor):
    rep = (img.shape[0] // 2, img.shape[1] // 2)
    out = (img.astype(np.float32) - np.tile(offset, rep)) * np.tile(factor, rep)
    return out
```
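The function takes an image (an N×M numpy ndarray) and two 2×2 arrays, offset and factor. It then applies a basic linear transform to every pixel, choosing the coefficients according to the parity of the pixel's row and column indices:
out[i, j] = (img[i, j] - offset[i % 2, j % 2]) * factor[i % 2, j % 2]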
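The parity rule splits the image into four interleaved sub-grids, one per (i % 2, j % 2) pair, and each of them is a strided slice such as img[0::2, 1::2]. A minimal NumPy-only sketch under that reading (fast_transform_slices is a hypothetical name, not from the question) applies each 2×2 coefficient to its slice in place, so no tiled arrays are built. Unlike the np.tile version, which needs even N and M for the shapes to match, it also accepts odd dimensions.

```python
import numpy as np


def fast_transform_slices(img, offset, factor):
    """Hypothetical helper: out[i, j] = (img[i, j] - offset[i%2, j%2]) * factor[i%2, j%2]."""
    off = np.asarray(offset, dtype=np.float32)
    fac = np.asarray(factor, dtype=np.float32)
    out = img.astype(np.float32)  # the only full-size allocation
    for di in (0, 1):
        for dj in (0, 1):
            # Strided view of every pixel with row parity di and column parity dj.
            block = out[di::2, dj::2]
            # In-place operations write through the view into `out`.
            block -= off[di, dj]
            block *= fac[di, dj]
    return out


img = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 16, 17]])
offset = np.array([[2, 1], [2, 1]])
factor = np.array([[2, 2], [2, 2]])
print(fast_transform_slices(img, offset, factor))
```

Only four slice operations run per call, each vectorized over a quarter of the image.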
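As you can see, I use np.tile to try to speed up the function, but it is still not fast enough for my needs (and I suspect that creating the intermediate np.tile arrays makes it suboptimal). I tried numba, but it doesn't support np.tile yet.
Can you help me optimize this function as much as possible? I'm sure I'm missing a simple way to do it.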
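np.tile allocates two full-size temporary arrays. Broadcasting can avoid them: for a C-contiguous image with even N and M, reshaping to (N // 2, 2, M // 2, 2) puts the row parity on axis 1 and the column parity on axis 3, so a (1, 2, 1, 2) copy of offset and factor lines up with every pixel. A minimal sketch under those assumptions (fast_transform_broadcast is a hypothetical name), doing the arithmetic in place on a single float32 copy:

```python
import numpy as np


def fast_transform_broadcast(img, offset, factor):
    """Hypothetical broadcasting variant; assumes both image dimensions are even."""
    n, m = img.shape
    if n % 2 or m % 2:
        raise ValueError("image dimensions must be even")
    # order="C" guarantees a contiguous copy, so the reshape below is a view.
    out = img.astype(np.float32, order="C")
    # Element [a, b, c, d] of this view is pixel (2*a + b, 2*c + d),
    # i.e. b == i % 2 and d == j % 2.
    view = out.reshape(n // 2, 2, m // 2, 2)
    view -= np.asarray(offset, dtype=np.float32).reshape(1, 2, 1, 2)
    view *= np.asarray(factor, dtype=np.float32).reshape(1, 2, 1, 2)
    return out
```

Both in-place operations write through the view into out, so the float32 copy is the only full-size allocation. Whether this beats the np.tile version at your image sizes is worth measuring rather than assuming.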
A naive approach with numba, as you attempted:
```python
import numpy as np
from numba import njit, prange

@njit
def fast_transform_numba(img, offset, factor):
    out = np.empty(img.shape, dtype=np.float32)
    for i in range(img.shape[0]):
        for j in range(img.shape[1]):
            out[i, j] = (img[i, j] - offset[i % 2, j % 2]) * factor[i % 2, j % 2]
    return out
```
Looking at the for loops, they can easily be parallelized:
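```python
@njit(parallel=True)
def fast_transform_numba_parallel(img, offset, factor):
    out = np.empty(img.shape, dtype=np.float32)
    # Rows are independent, so prange distributes them across threads.
    # Numba runs only the outermost prange in parallel, so the inner loop is a plain range.
    for i in prange(img.shape[0]):
        for j in range(img.shape[1]):
            out[i, j] = (img[i, j] - offset[i % 2, j % 2]) * factor[i % 2, j % 2]
    return out
```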
Benchmark:
```python
import perfplot
from matplotlib import pyplot as plt

plt.rcParams["figure.autolayout"] = True
np.random.seed(0)

def fast_transform(img, offset, factor):
    rep = (img.shape[0] // 2, img.shape[1] // 2)
    out = (img.astype(np.float32) - np.tile(offset, rep)) * np.tile(factor, rep)
    return out

img = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 16, 17]])
offset = np.array([[2, 1], [2, 1]])
factor = np.array([[2, 2], [2, 2]])

# Arguments follow the signature order (img, offset, factor).
x = fast_transform(img, offset, factor)
print(x)

# The first call also triggers numba's JIT compilation:
x = fast_transform_numba(img, offset, factor)
print(x)

# This call compiles the parallel version as well:
x = fast_transform_numba_parallel(img, offset, factor)
print(x)
```
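Printing the small example only shows the numbers; an automated comparison against a direct per-pixel loop can catch mistakes such as swapped offset and factor arguments. A minimal NumPy-only sketch using np.testing.assert_allclose (reference_transform is a hypothetical helper; fast_transform is the tile version from the question, repeated so the block runs on its own). The same check can be pointed at the numba kernels. With integer offset and factor arrays like these, subtracting them from a float32 array promotes the tile version's result to float64, which is why the values are compared rather than the dtypes.

```python
import numpy as np


def fast_transform(img, offset, factor):
    # Tile-based version from the question, unchanged.
    rep = (img.shape[0] // 2, img.shape[1] // 2)
    return (img.astype(np.float32) - np.tile(offset, rep)) * np.tile(factor, rep)


def reference_transform(img, offset, factor):
    # Hypothetical helper: literal per-pixel translation of the formula.
    out = np.empty(img.shape, dtype=np.float64)
    for i in range(img.shape[0]):
        for j in range(img.shape[1]):
            out[i, j] = (float(img[i, j]) - offset[i % 2, j % 2]) * factor[i % 2, j % 2]
    return out


img = np.arange(48, dtype=np.uint8).reshape(6, 8)
offset = np.array([[5, 1], [3, 0]])
factor = np.array([[2, 3], [1, 4]])

# Raises AssertionError with a mismatch report if the two disagree.
np.testing.assert_allclose(fast_transform(img, offset, factor),
                           reference_transform(img, offset, factor))
print("fast_transform matches the per-pixel formula")
```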
```python
perfplot.show(
    setup=lambda n: np.random.randint(0, 255, (n, n), dtype=np.uint8),
    kernels=[
        lambda img: fast_transform(img, offset, factor),
        lambda img: fast_transform_numba(img, offset, factor),
        lambda img: fast_transform_numba_parallel(img, offset, factor),
    ],
    labels=["fast_transform", "fast_transform_numba", "fast_transform_numba_parallel"],
    n_range=[2**k for k in range(2, 15)],
    xlabel="img(N * N)",
    logx=True,
    logy=True,
)
```
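This produces the following plot:
[Plot: runtime of fast_transform, fast_transform_numba and fast_transform_numba_parallel vs. img(N * N), log-log axes]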
You can see that beyond a certain image size, the parallel version appears to be the fastest.
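Where that crossover lies depends on the machine, the thread count and the image sizes, so it is worth re-measuring locally. If perfplot is not available, a rough comparison with the standard library's timeit is enough. This sketch (time_kernels is a hypothetical helper) reports the best time per call for each kernel and size; numba kernels should be called once beforehand so that JIT compilation is not included in the timing.

```python
import timeit

import numpy as np


def time_kernels(kernels, sizes, repeat=5, number=10):
    """Hypothetical helper: best time per call (seconds) for each kernel and n x n size."""
    rng = np.random.default_rng(0)
    results = {}
    for n in sizes:
        img = rng.integers(0, 255, (n, n), dtype=np.uint8)
        for name, kernel in kernels.items():
            # The lambda is timed immediately, so capturing the loop variables is safe.
            timer = timeit.Timer(lambda: kernel(img))
            results[(name, n)] = min(timer.repeat(repeat=repeat, number=number)) / number
    return results


offset = np.array([[2, 1], [2, 1]])
factor = np.array([[2, 2], [2, 2]])


def fast_transform(img, offset, factor):
    # Tile-based version from the question.
    rep = (img.shape[0] // 2, img.shape[1] // 2)
    return (img.astype(np.float32) - np.tile(offset, rep)) * np.tile(factor, rep)


# Add the (already compiled) numba kernels to this dict to compare them.
kernels = {"fast_transform": lambda img: fast_transform(img, offset, factor)}
for (name, n), t in sorted(time_kernels(kernels, [64, 256, 1024]).items()):
    print(f"{name:>15} n={n:5d}: {t * 1e6:9.1f} us")
```

Taking the minimum over repeats reduces noise from other processes, which is the approach the timeit documentation recommends.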