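I have two variables that I am trying to fit with curve_fit from scipy.optimize. The result looks reasonable, but on the left-hand side the red line does not follow the data (green dots) very well. How can I apply weights in curve_fit() to move the red line toward the blue line on the left?
Here is the code: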
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from pandas import DataFrame
x = [ 57, 83, 124, 141, 196, 223, 275, 302, 341, 714, 895,
1034, 1117, 1207, 1248, 1416, 1494, 1563, 1708, 1785, 1863, 2015,
2139, 2238, 2312, 2412, 2442, 2520, 2596, 2658, 2706, 2777, 2846,
2966, 3106, 3241, 3276, 3424, 3568, 3647, 3831, 3961, 4091, 4248,
4430, 4478, 4644, 4833, 5052, 6041 ]
y = [ 70, 81, 87, 91, 96, 106, 109, 114, 120, 129, 144, 162, 168,
175, 181, 184, 190, 195, 205, 213, 216, 219, 224, 226, 231, 236,
239, 247, 255, 260, 264, 269, 282, 292, 297, 304, 308, 313, 319,
322, 327, 333, 338, 341, 345, 354, 362, 364, 374, 391 ]
plt.scatter(x,y,color='green')
def func(x, a, b):
    return a * np.power(x,b)
popt, pcov = curve_fit(func, x, y)
plt.plot(x, func(x, *popt), 'b-', label='fit: a=%5.3f, b=%5.3f' % tuple(popt))
popt2 = [12.6, 0.386]
plt.plot(x, func(x, *popt2), 'r-', label='fit: a=%5.3f, b=%5.3f' % tuple(popt2))
plt.semilogx()
plt.legend()  # show the fit labels defined above
plt.show()
You can use the sigma parameter of curve_fit. From the docs:
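sigma : None or M-length sequence or MxM array, optional
Determines the uncertainty in ydata. If we define residuals as r = ydata - f(xdata, *popt), then the interpretation of sigma depends on its number of dimensions:
A 1-D sigma should contain values of standard deviations of errors in ydata. In this case, the optimized function is chisq = sum((r / sigma) ** 2).
A 2-D sigma should contain the covariance matrix of errors in ydata. In this case, the optimized function is chisq = r.T @ inv(sigma) @ r.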
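To see the two cases from the quoted documentation side by side, here is a small sketch (not part of the original answer) that fits the question's power-law model once with a 1-D sigma and once with the equivalent 2-D covariance matrix np.diag(sigma**2), then evaluates both chisq formulas. The sigma values and the starting guess p0 (borrowed from the question's popt2) are assumptions for illustration.

```python
import numpy as np
from scipy.optimize import curve_fit

# Data from the question
x = np.array([57, 83, 124, 141, 196, 223, 275, 302, 341, 714, 895,
              1034, 1117, 1207, 1248, 1416, 1494, 1563, 1708, 1785, 1863, 2015,
              2139, 2238, 2312, 2412, 2442, 2520, 2596, 2658, 2706, 2777, 2846,
              2966, 3106, 3241, 3276, 3424, 3568, 3647, 3831, 3961, 4091, 4248,
              4430, 4478, 4644, 4833, 5052, 6041], dtype=float)
y = np.array([70, 81, 87, 91, 96, 106, 109, 114, 120, 129, 144, 162, 168,
              175, 181, 184, 190, 195, 205, 213, 216, 219, 224, 226, 231, 236,
              239, 247, 255, 260, 264, 269, 282, 292, 297, 304, 308, 313, 319,
              322, 327, 333, 338, 341, 345, 354, 362, 364, 374, 391], dtype=float)

def func(x, a, b):
    return a * np.power(x, b)

# Illustrative standard deviations (an assumption): small for the first 10 points
sigma_1d = np.ones(len(x))
sigma_1d[10:] *= 10

# Equivalent 2-D form: a diagonal covariance matrix holding the variances sigma**2
sigma_2d = np.diag(sigma_1d ** 2)

p0 = (12.6, 0.386)  # assumed starting guess, taken from the question's popt2
popt_1d, _ = curve_fit(func, x, y, p0=p0, sigma=sigma_1d)
popt_2d, _ = curve_fit(func, x, y, p0=p0, sigma=sigma_2d)

def chisq_1d(p):
    """chisq = sum((r / sigma) ** 2) for a 1-D sigma."""
    r = y - func(x, *p)
    return np.sum((r / sigma_1d) ** 2)

def chisq_2d(p):
    """chisq = r.T @ inv(sigma) @ r for a 2-D (covariance) sigma."""
    r = y - func(x, *p)
    return r.T @ np.linalg.inv(sigma_2d) @ r

print("1-D sigma:", popt_1d, chisq_1d(popt_1d))
print("2-D sigma:", popt_2d, chisq_2d(popt_2d))
```

Because the covariance matrix here is diagonal, both forms describe the same uncertainties and the fitted parameters agree up to floating-point noise; a 2-D sigma only adds information when it has off-diagonal (correlated) terms.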
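So you can treat a 1-D sigma as inverse weights: each residual is divided by its sigma, so a point with a smaller sigma contributes more to chisq (its effective weight is 1/sigma**2). To fit a particular part of the curve better, assign lower sigma values to the points in that part: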
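If you prefer to think in terms of weights w rather than standard deviations, the chisq formula above gives sigma = 1/sqrt(w). The sketch below assumes hypothetical weights that make the first 10 points 100 times more important (which reproduces sigma = 1 for those points and sigma = 10 elsewhere) and checks that curve_fit with that sigma solves the same problem as an explicit weighted least-squares fit with scipy.optimize.least_squares.

```python
import numpy as np
from scipy.optimize import curve_fit, least_squares

# Data from the question
x = np.array([57, 83, 124, 141, 196, 223, 275, 302, 341, 714, 895,
              1034, 1117, 1207, 1248, 1416, 1494, 1563, 1708, 1785, 1863, 2015,
              2139, 2238, 2312, 2412, 2442, 2520, 2596, 2658, 2706, 2777, 2846,
              2966, 3106, 3241, 3276, 3424, 3568, 3647, 3831, 3961, 4091, 4248,
              4430, 4478, 4644, 4833, 5052, 6041], dtype=float)
y = np.array([70, 81, 87, 91, 96, 106, 109, 114, 120, 129, 144, 162, 168,
              175, 181, 184, 190, 195, 205, 213, 216, 219, 224, 226, 231, 236,
              239, 247, 255, 260, 264, 269, 282, 292, 297, 304, 308, 313, 319,
              322, 327, 333, 338, 341, 345, 354, 362, 364, 374, 391], dtype=float)

def func(x, a, b):
    return a * np.power(x, b)

# Hypothetical weights: first 10 points weigh 1, the rest 1/100
w = np.ones(len(x))
w[10:] = 1 / 100

# Weight w in sum(w * r**2) corresponds to sigma = 1 / sqrt(w)
sigma = 1 / np.sqrt(w)

p0 = (12.6, 0.386)  # assumed starting guess, taken from the question's popt2
popt_cf, _ = curve_fit(func, x, y, p0=p0, sigma=sigma)

# The same problem written out: minimise sum((sqrt(w) * r)**2)
res = least_squares(lambda p: np.sqrt(w) * (y - func(x, *p)), p0)

def weighted_cost(p):
    """Half the weighted sum of squared residuals (least_squares' cost convention)."""
    return 0.5 * np.sum(w * (y - func(x, *p)) ** 2)

print("curve_fit with sigma:", popt_cf, weighted_cost(popt_cf))
print("least_squares       :", res.x, weighted_cost(res.x))
```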
plt.scatter(x,y,color='green')
def func(x, a, b):
    return a * np.power(x,b)
sigma = np.ones(len(x))
sigma[10:] *= 10 # set higher sigma for all data points other than the first 10
popt, pcov = curve_fit(func, x, y, sigma=sigma)
plt.plot(x, func(x, *popt), 'b-', label='fit: a=%5.3f, b=%5.3f' % tuple(popt))
popt2 = [12.6, 0.386]
plt.plot(x, func(x, *popt2), 'r-', label='fit: a=%5.3f, b=%5.3f' % tuple(popt2))
plt.semilogx()
plt.legend()  # show the fit labels defined above
plt.show()
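To check numerically that the weighting pulls the fit toward the left-hand points, a sketch like the following compares the unweighted and weighted fits by their sum of squared residuals (SSR) on the first 10 points and on the rest. The common starting guess p0 is an assumption (taken from the question's popt2).

```python
import numpy as np
from scipy.optimize import curve_fit

# Data from the question
x = np.array([57, 83, 124, 141, 196, 223, 275, 302, 341, 714, 895,
              1034, 1117, 1207, 1248, 1416, 1494, 1563, 1708, 1785, 1863, 2015,
              2139, 2238, 2312, 2412, 2442, 2520, 2596, 2658, 2706, 2777, 2846,
              2966, 3106, 3241, 3276, 3424, 3568, 3647, 3831, 3961, 4091, 4248,
              4430, 4478, 4644, 4833, 5052, 6041], dtype=float)
y = np.array([70, 81, 87, 91, 96, 106, 109, 114, 120, 129, 144, 162, 168,
              175, 181, 184, 190, 195, 205, 213, 216, 219, 224, 226, 231, 236,
              239, 247, 255, 260, 264, 269, 282, 292, 297, 304, 308, 313, 319,
              322, 327, 333, 338, 341, 345, 354, 362, 364, 374, 391], dtype=float)

def func(x, a, b):
    return a * np.power(x, b)

p0 = (12.6, 0.386)  # assumed starting guess, taken from the question's popt2

popt_plain, _ = curve_fit(func, x, y, p0=p0)

sigma = np.ones(len(x))
sigma[10:] *= 10  # same weighting as in the answer
popt_weighted, _ = curve_fit(func, x, y, p0=p0, sigma=sigma)

def ssr(p, idx):
    """Sum of squared residuals over the points selected by idx."""
    r = y[idx] - func(x[idx], *p)
    return np.sum(r ** 2)

left, rest = slice(None, 10), slice(10, None)
for name, p in [("unweighted", popt_plain), ("weighted", popt_weighted)]:
    print(f"{name}: a={p[0]:.3f}, b={p[1]:.3f}, "
          f"SSR first 10={ssr(p, left):.1f}, SSR rest={ssr(p, rest):.1f}")
```

Increasing the relative weight of the first 10 points can only lower (or keep) their error at the optimum, at the cost of a larger error on the remaining points; how far to push that trade-off is a judgment call.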
You can experiment with the values in sigma to get better results than the above.
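One option to experiment with (an assumption, not something the question specifies) is a smoothly varying sigma instead of a step, for example sigma proportional to y. That turns chisq into a sum of squared relative residuals, which tends to give the small-y points on the left more influence. A sketch:

```python
import numpy as np
from scipy.optimize import curve_fit

# Data from the question
x = np.array([57, 83, 124, 141, 196, 223, 275, 302, 341, 714, 895,
              1034, 1117, 1207, 1248, 1416, 1494, 1563, 1708, 1785, 1863, 2015,
              2139, 2238, 2312, 2412, 2442, 2520, 2596, 2658, 2706, 2777, 2846,
              2966, 3106, 3241, 3276, 3424, 3568, 3647, 3831, 3961, 4091, 4248,
              4430, 4478, 4644, 4833, 5052, 6041], dtype=float)
y = np.array([70, 81, 87, 91, 96, 106, 109, 114, 120, 129, 144, 162, 168,
              175, 181, 184, 190, 195, 205, 213, 216, 219, 224, 226, 231, 236,
              239, 247, 255, 260, 264, 269, 282, 292, 297, 304, 308, 313, 319,
              322, 327, 333, 338, 341, 345, 354, 362, 364, 374, 391], dtype=float)

def func(x, a, b):
    return a * np.power(x, b)

p0 = (12.6, 0.386)  # assumed starting guess, taken from the question's popt2

# Unweighted: minimises sum(r**2) in absolute terms
popt_plain, _ = curve_fit(func, x, y, p0=p0)
# sigma = y: minimises sum((r / y)**2), i.e. squared relative errors
popt_rel, _ = curve_fit(func, x, y, p0=p0, sigma=y)

def rel_chisq(p):
    """Sum of squared relative residuals."""
    return np.sum(((y - func(x, *p)) / y) ** 2)

def left_ssr(p, n=10):
    """Sum of squared residuals over the first n (leftmost) points."""
    r = y[:n] - func(x[:n], *p)
    return np.sum(r ** 2)

for name, p in [("unweighted", popt_plain), ("sigma = y", popt_rel)]:
    print(f"{name}: a={p[0]:.3f}, b={p[1]:.3f}, "
          f"relative chisq={rel_chisq(p):.5f}, SSR first 10={left_ssr(p):.1f}")
```

Since curve_fit uses absolute_sigma=False by default, only the relative sizes of sigma matter for popt, so sigma=y and sigma=0.05*y give the same fit.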
It doesn't look like weights are necessary; instead, the model can be improved by giving it more degrees of freedom:
import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import curve_fit, Bounds
x = np.array([
57, 83, 124, 141, 196, 223, 275, 302, 341, 714, 895,
1034, 1117, 1207, 1248, 1416, 1494, 1563, 1708, 1785, 1863, 2015,
2139, 2238, 2312, 2412, 2442, 2520, 2596, 2658, 2706, 2777, 2846,
2966, 3106, 3241, 3276, 3424, 3568, 3647, 3831, 3961, 4091, 4248,
4430, 4478, 4644, 4833, 5052, 6041,
])
y = np.array([
70, 81, 87, 91, 96, 106, 109, 114, 120, 129, 144, 162, 168,
175, 181, 184, 190, 195, 205, 213, 216, 219, 224, 226, 231, 236,
239, 247, 255, 260, 264, 269, 282, 292, 297, 304, 308, 313, 319,
322, 327, 333, 338, 341, 345, 354, 362, 364, 374, 391,
])
def model(x: np.ndarray, a: float, b: float, c: float, d: float) -> np.ndarray:
    return a * (x - b)**c + d
popt, *_ = curve_fit(
    f=model, xdata=x, ydata=y,
    p0=(1, 0, 2, 0),
    bounds=Bounds(
        lb=(0, -1e4, 0, -1e3),
        ub=(1e4, 1e4, 50, 1e3),
    ),
)
print(popt)
fig, ax = plt.subplots()
ax.semilogx(x, y, label='experiment')
ax.semilogx(x, model(x, *popt), label='fit')
ax.legend()
plt.show()
Output:
[ 1.75236837e+02 -2.67170725e+03  2.30138464e-01 -9.99999670e+02]
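In the printed result, d is about -999.9997, essentially the lower bound of -1e3 passed through Bounds. A parameter that ends up on a bound usually means the box, not the data, is limiting the fit, so it may be worth widening that bound and refitting. The small check below is a sketch that only reuses the printed values and the bounds from the answer; the default tolerances of np.isclose are an assumed notion of "on the bound".

```python
import numpy as np

# Values printed by print(popt) above and the bounds passed to curve_fit
popt = np.array([1.75236837e+02, -2.67170725e+03, 2.30138464e-01, -9.99999670e+02])
lb = np.array([0, -1e4, 0, -1e3])
ub = np.array([1e4, 1e4, 50, 1e3])
names = ["a", "b", "c", "d"]

# np.isclose: |popt - bound| <= 1e-8 + 1e-5 * |bound|
at_lower = np.isclose(popt, lb)
at_upper = np.isclose(popt, ub)

for name, value, hit_lo, hit_hi in zip(names, popt, at_lower, at_upper):
    if hit_lo or hit_hi:
        side = "lower" if hit_lo else "upper"
        print(f"{name} = {value:.8g} sits on its {side} bound")
```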