torch.rfft-基于fft的卷积创建的输出与空间卷积不同

问题描述 投票:1回答:1

我在Pytorch中实现了基于FFT的卷积,并通过conv2d()函数将结果与空间卷积进行了比较。使用的卷积滤波器是平均滤波器。由于期望的平均滤波,conv2d()函数产生了平滑的输出,但是基于fft的卷积返回了更加模糊的输出。我已经在这里附加了代码并输出-

空间卷积-

from PIL import Image, ImageOps
import torch
from matplotlib import pyplot as plt
from torchvision.transforms import ToTensor
import torch.nn.functional as F
import numpy as np

im = Image.open("/kaggle/input/tiger.jpg")
im = im.resize((256,256))
gray_im = im.convert('L') 
gray_im = ToTensor()(gray_im)
gray_im = gray_im.squeeze()

fil = torch.tensor([[1/9,1/9,1/9],[1/9,1/9,1/9],[1/9,1/9,1/9]])

conv_gray_im = gray_im.unsqueeze(0).unsqueeze(0)
conv_fil = fil.unsqueeze(0).unsqueeze(0)

conv_op = F.conv2d(conv_gray_im,conv_fil)

conv_op = conv_op.squeeze()

plt.figure()
plt.imshow(conv_op, cmap='gray')

基于FFT的卷积-

def fftshift(image):
    sh = image.shape
    x = np.arange(0, sh[2], 1)
    y = np.arange(0, sh[3], 1)
    xm, ym  = np.meshgrid(x,y)
    shifter = (-1)**(xm + ym)
    shifter = torch.from_numpy(shifter)
    return image*shifter

shift_im = fftshift(conv_gray_im)
padded_fil = F.pad(conv_fil, (0, gray_im.shape[0]-fil.shape[0], 0, gray_im.shape[1]-fil.shape[1]))
shift_fil = fftshift(padded_fil)
fft_shift_im = torch.rfft(shift_im, 2, onesided=False)
fft_shift_fil = torch.rfft(shift_fil, 2, onesided=False)
shift_prod = fft_shift_im*fft_shift_fil
shift_fft_conv = fftshift(torch.irfft(shift_prod, 2, onesided=False))

fft_op = shift_fft_conv.squeeze()
plt.figure('shifted fft')
plt.imshow(fft_op, cmap='gray')

原始图像-

original image

空间卷积输出-

convolution output

基于fft的卷积输出-

fft-conv output

有人可以解释这个问题吗?

python image-processing pytorch fft convolution
1个回答
0
投票

您的代码的主要问题是Torch不做复数,其FFT的输出是3D数组,第3维具有两个值,一个用于实数,一个用于虚数。因此,乘法不会执行复杂的乘法。

目前在Torch中没有定义复杂的乘法(请参阅this issue,我们必须定义自己的乘法。


一个小问题,如果要比较两个卷积运算,也很重要,如下:

FFT在第一个元素(图像的左上像素)中获取其输入的原点。为了避免输出偏移,您需要生成一个填充的内核,其中内核的原点是左上角的像素。这很棘手,实际上...

您当前的代码:

fil = torch.tensor([[1/9,1/9,1/9],[1/9,1/9,1/9],[1/9,1/9,1/9]])
conv_fil = fil.unsqueeze(0).unsqueeze(0)
padded_fil = F.pad(conv_fil, (0, gray_im.shape[0]-fil.shape[0], 0, gray_im.shape[1]-fil.shape[1]))

生成填充的内核,其原点位于像素(1,1)中,而不是像素(0,0)中。它需要在每个方向上移动一个像素。 NumPy具有对此有用的函数roll,我不知道Torch等效项(我对Torch一点都不熟悉)。这应该工作:

fil = torch.tensor([[1/9,1/9,1/9],[1/9,1/9,1/9],[1/9,1/9,1/9]])
padded_fil = fil.unsqueeze(0).unsqueeze(0).numpy()
padded_fil = np.pad(padded_fil, ((0, gray_im.shape[0]-fil.shape[0]), (0, gray_im.shape[1]-fil.shape[1])))
padded_fil = np.roll(padded_fil, -1, axis=(0, 1))
padded_fil = torch.from_numpy(padded_fil)

最后,将您的fftshift函数应用于空间域图像,会导致频域图像(应用于图像的FFT的结果)发生偏移,使得原点位于图像的中间,而不是左上方。在查看FFT的输出时,此偏移很有用,但在计算卷积时毫无意义。


将这些东西放在一起,卷积现在是:

def complex_multiplication(t1, t2):
  real1, imag1 = t1[:,:,0], t1[:,:,1]
  real2, imag2 = t2[:,:,0], t2[:,:,1]
  return torch.stack([real1 * real2 - imag1 * imag2, real1 * imag2 + imag1 * real2], dim = -1)

fft_im = torch.rfft(gray_im, 2, onesided=False)
fft_fil = torch.rfft(padded_fil, 2, onesided=False)
fft_conv = torch.irfft(complex_multiplication(fft_im, fft_fil), 2, onesided=False)

请注意,您可以执行单侧FFT节省一些计算时间:

fft_im = torch.rfft(gray_im, 2, onesided=True)
fft_fil = torch.rfft(padded_fil, 2, onesided=True)
fft_conv = torch.irfft(complex_multiplication(fft_im, fft_fil), 2, onesided=True, signal_sizes=gray_im.shape)

此处,频域的大小约为整个FFT的一半,但仅剩下多余的部分。卷积的结果不变。

© www.soinside.com 2019 - 2024. All rights reserved.