我正在研究用于快速多项式乘法的 FFT 算法。我们检查了算法,我决定尝试用 Python 实现它。
from typing import List
import numpy as np
def fft(p: List[int]) -> List[int]:
n = len(p)
if n == 1:
return p
unity_root = np.exp(2j * np.pi / n)
p_even = p[::2]
p_odd = p[1::2]
y_even = fft(p_even)
y_odd = fft(p_odd)
y = [0] * n
for j in range(n // 2):
omega = np.power(unity_root, j)
y[j] = y_even[j] + omega * y_odd[j]
y[n // 2 + j] = y_even[j] - omega * y_odd[j]
return y
def ifft(p: List[int]) -> List[int]:
n = len(p)
if n == 1:
return p
unity_root = (1 / n) * np.exp(-2j * np.pi / n)
p_even = p[::2]
p_odd = p[1::2]
y_even = ifft(p_even)
y_odd = ifft(p_odd)
y = [0] * n
for j in range(n // 2):
omega = np.power(unity_root, j)
y[j] = y_even[j] + omega * y_odd[j]
y[n // 2 + j] = y_even[j] - omega * y_odd[j]
return y
我尝试运行以下代码以确保它有效
print(ifft(fft([1, 2, 3, 4])))
我希望输出是我开始使用的原始列表,因为该列表代表系数,但我得到(忽略浮点运算的精度问题):
[(4+0j), (11-0j), (12+0j), (13+0j)]
我的问题是:
我不应该得到原始列表吗?如果我应该得到原始列表,代码中的错误在哪里?当我多次检查代码时,我在查找它时遇到了问题。如果我不应该取回原始列表并且我的代码是正确的,那么我实际上得到了什么?