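I'm trying to interpolate a function at arbitrary points, and I have the function values at the Chebyshev extreme points. I compute the Chebyshev coefficients from the real part of the fast Fourier transform (FFT), scale them by 2/N, and then use NumPy's polynomial package to evaluate the Chebyshev series at a set of points. The result is a wrong approximation of the function. Where am I going wrong?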
```python
import numpy as np
import matplotlib.pyplot as plt

# Define the number of Chebyshev extreme points
N = 10

# Define the function to be approximated
def f(x):
    return x**2

# Evaluate the function at the Chebyshev extreme points
x = np.cos(np.arange(N) * np.pi / N)
y = f(x)

# Compute the discrete Fourier transform (DFT) of the function
# values using the FFT algorithm
DFT = np.fft.fft(y).real

# Compute the correct scaling factor
scaling_factor = 2/N

# Scale the DFT coefficients by the correct scaling factor
chebyshev_coefficients = scaling_factor * DFT

# Use chebval to evaluate the approximated polynomial at a set of points
x_eval = np.linspace(-1, 1, 100)
y_approx = np.polynomial.chebyshev.chebval(x_eval, chebyshev_coefficients[::-1])

# Plot the original function and the approximated function
plt.plot(x, y, 'o', label='Original function')
plt.plot(x_eval, y_approx, '-', label='Approximated function')
plt.legend()
plt.show()
```
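One way to see what the correct result should look like is to compare against coefficients known in advance. The sketch below uses NumPy only, and its variable names are illustrative. It uses `numpy.polynomial.chebyshev.poly2cheb` to get the exact expansion x² = 0.5·T_0(x) + 0.5·T_2(x), shows that `chebval` reads coefficients lowest degree first (so reversing a zero-padded coefficient array gives a different polynomial), and checks that the grid cos(kπ/N), k = 0, ..., N - 1, stops short of x = -1.

```python
import numpy as np
from numpy.polynomial import chebyshev as C

# Exact Chebyshev coefficients of x**2 (power-series coefficients [0, 0, 1]),
# in increasing degree order: exact[0]*T0 + exact[1]*T1 + exact[2]*T2
exact = C.poly2cheb([0, 0, 1])
print(exact)  # approximately [0.5, 0.0, 0.5]

# chebval reads the array lowest degree first; reversing a padded array
# turns 0.5*T0 + 0.5*T2 into 0.5*T1 + 0.5*T3 = 2x**3 - x
padded = np.array([0.5, 0.0, 0.5, 0.0])
x_test = np.linspace(-1, 1, 5)
print(np.max(np.abs(C.chebval(x_test, padded) - x_test**2)))        # close to zero
print(np.max(np.abs(C.chebval(x_test, padded[::-1]) - x_test**2)))  # clearly nonzero

# The question's grid cos(k*pi/N), k = 0..N-1, never reaches x = -1 ...
N = 10
x_question = np.cos(np.arange(N) * np.pi / N)
print(x_question.min())  # cos(9*pi/10), about -0.951
# ... while the N Chebyshev extreme points cos(k*pi/(N-1)) span [-1, 1]
x_extrema = np.cos(np.arange(N) * np.pi / (N - 1))
print(x_extrema.min())   # -1.0 up to rounding
```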
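You made several mistakes. First, `np.cos(np.arange(N) * np.pi / N)` does not produce the Chebyshev extreme points: for N points they are x_k = cos(kπ/(N - 1)), k = 0, ..., N - 1, which include both endpoints x = 1 and x = -1. Second, the FFT has to be applied to the even extension of the data (y_0, ..., y_{N-1}, y_{N-2}, ..., y_1), of length 2N - 2, not to y itself; the real part of that FFT is the cosine sum that yields the Chebyshev coefficients. Third, the scaling is 2/(2N - 2) = 1/(N - 1), and the first and last coefficients must additionally be halved. Finally, `chebval` expects the coefficients in increasing degree order (c[0] multiplies T_0), so they must not be reversed with `[::-1]`.

Here is an updated version of the code that works correctly: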
```python
import numpy as np
import matplotlib.pyplot as plt

# Define the number of Chebyshev extreme points
N = 11

# Define the function to be approximated
def f(x):
    return x**2

# Evaluate the function at the Chebyshev extreme points
# x_k = cos(k*pi/(N-1)), k = 0..N-1 (these include x = 1 and x = -1)
x = np.cos(np.arange(N) * np.pi / (N - 1))
y = f(x)

# Compute the DFT of the even extension y0..y(N-1), y(N-2)..y1
# (length 2N-2) with the FFT; its real part is a cosine sum
DFT = np.fft.fft(np.hstack((y, y[N - 2: 0: -1]))).real / (2 * N - 2)

# Scale the DFT coefficients to get the Chebyshev coefficients;
# the first and last ones are halved (for any N, even or odd)
chebyshev_coefficients = DFT[:N] * 2
chebyshev_coefficients[0] /= 2
chebyshev_coefficients[-1] /= 2

# Use chebval to evaluate the approximating polynomial at a set of
# points (coefficients in increasing degree order, not reversed)
x_eval = np.linspace(-1, 1, 100)
y_approx = np.polynomial.chebyshev.chebval(x_eval, chebyshev_coefficients)

# Plot the original function and the approximated function
plt.plot(x, y, 'o', label='Original function')
plt.plot(x_eval, y_approx, '-', label='Approximated function')
plt.legend()
plt.show()
```
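The same computation can be packaged as a function that works for any number of points N ≥ 2, even or odd. The sketch below is one way to do it with NumPy only; the name `cheb_coeffs_from_extrema` is hypothetical and not part of NumPy. With n = N - 1 and nodes x_j = cos(jπ/n), entry k of the FFT of the even extension (length 2n) is y_0 + (-1)^k y_n + 2 Σ_{j=1}^{n-1} y_j cos(jkπ/n), a type-I discrete cosine transform. Dividing by n and halving the first and last entries gives the coefficients c_0, ..., c_n of the interpolant Σ c_k T_k(x). As a cross-check, the sketch compares the result against `numpy.polynomial.chebyshev.chebfit` with degree N - 1, which, with N nodes, fits the same interpolating polynomial.

```python
import numpy as np
from numpy.polynomial import chebyshev as C


def cheb_coeffs_from_extrema(y):
    """Coefficients c[0..n] (lowest degree first) of the degree-n Chebyshev
    interpolant through y[j] = f(cos(j*pi/n)), j = 0..n.

    Hypothetical helper for illustration; not a NumPy function.
    """
    y = np.asarray(y, dtype=float)
    n = len(y) - 1
    if n < 1:
        raise ValueError("need at least two points")
    # Even extension y0, ..., yn, y(n-1), ..., y1 has length 2n; its FFT is
    # real (up to rounding) and equals a DCT-I of y
    v = np.concatenate((y, y[n - 1:0:-1]))
    c = np.fft.fft(v).real[: n + 1] / n
    # The first and last coefficients are always halved, whatever the parity
    c[0] /= 2
    c[n] /= 2
    return c


# Cross-check on a non-polynomial function for an even and an odd N
for N in (10, 11):
    x = np.cos(np.arange(N) * np.pi / (N - 1))
    y = np.exp(x)
    c = cheb_coeffs_from_extrema(y)
    # N nodes, degree N - 1: the least-squares fit is the interpolant
    print(N, np.allclose(c, C.chebfit(x, y, N - 1)))
    # The series reproduces the data at the nodes
    print(N, np.allclose(C.chebval(x, c), y))
    # Maximum error between the nodes
    x_eval = np.linspace(-1, 1, 100)
    print(N, np.max(np.abs(C.chebval(x_eval, c) - np.exp(x_eval))))
```

Because the check uses exp(x), which is not a polynomial, agreement with `chebfit` tests the coefficient formula itself rather than only the x² case, where any correct method returns [0.5, 0, 0.5, 0, ...].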