我想将一组 XY 数据点拟合到 Python 3.8 中的倒数指数函数。我正在尝试使用
leastsq
库的 scipy.optimize
函数。
我希望将我的数据拟合到类似于以下的函数:
其中
d
、s
和 e
是我希望适合我的数据的所有函数参数。
我有一些示例数据,我想将上面的函数拟合到其中,以及以下测试脚本:
import numpy as np
from scipy.optimize import leastsq
import matplotlib.pyplot as plt
def main():
# Declare example data
dataX = [0, 0.5, 1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5, 5, 5.5, 6, 6.5, 7, 7.5, 8, 8.5, 9, 9.5, 10, 10.5, 11, 11.5, 12, 12.5, 13, 13.5, 14.5, 15, 15.5, 16, 16.5, 17]
dataY = [0.00000000e+00, 2.47076895e-03, 9.66638150e-03, 1.97670203e-02, 3.42218835e-02, 4.97943540e-02, 6.54004261e-02, 7.93484613e-02, 8.61083796e-02, 8.49273950e-02, 6.70327841e-02, 3.45946900e-02, -2.24559129e-02, -1.17023372e-01, -2.49244700e-01, -4.27837601e-01, -6.61529080e-01, -9.62240010e-01, -1.35215046e+00, -1.84177480e+00, -2.46428273e+00, -3.25198094e+00, -4.24238794e+00, -5.53021367e+00, -7.23278230e+00, -9.57603920e+00, -1.31432215e+01, -1.99465466e+01, -1.45862198e+01, -1.01563667e+01, -7.19977087e+00, -5.03349403e+00, -3.37202972e+00, -2.07966002e+00]
# Fit to reciprocal exponential function
guessDenom = 1
guessPow = 2
guessShift = -10
optimalFx = lambda args: -(args[0] / pow(dataX + args[2], args[1])) - dataY
finalDenom, finalPow, finalShift = leastsq(optimalFx, [guessDenom, guessPow, guessShift])[0]
dataYFitted = []
for x in dataX:
dataYFitted.append(-finalDenom / pow(x + finalShift, finalPow))
# Plot original & fitted data
plt.xlim([0, 17])
plt.ylim([-30, 30])
plt.plot(dataX, dataYFitted, label = "Fitted Function", color = "dodgerblue")
plt.plot(dataX, dataY, label = "Original Data", color = "gray", alpha = 0.5)
plt.legend()
plt.show()
plt.close()
plt.clf()
if (__name__ == "__main__"):
main()
上面的脚本采用示例数据,并尝试使用 lambda 表达式来拟合上面的函数。它还以离散 X 间隔沿着拟合函数绘制原始数据。我之前曾在其他函数中使用过这种方法,并且它按预期工作,但是对于这个倒数指数函数,它似乎无法正常工作。
首先,运行上面的脚本后,我收到以下警告:
RuntimeWarning: divide by zero encountered in divide
optimalFx = lambda args: -(args[0] / pow(dataX + args[2], args[1])) - dataY
RuntimeWarning: invalid value encountered in power
optimalFx = lambda args: -(args[0] / pow(dataX + args[2], args[1])) - dataY
RuntimeWarning: Number of calls to function has reached maxfev = 800.
warnings.warn(errors[info][0], RuntimeWarning)
由于假设的拟合函数甚至不可见,因此显然出了问题。
我假设这里发生的是这样一个事实:在库中的某个点试图找到一个合适的值时,它遇到了被 0 除的情况,并且失败了。如果简单地绘制倒数指数函数,这一点就很明显:
在某个时刻,一定会遇到
ind
的答案。
如何使用 Python 中的
scipy.optimize
库将倒数指数函数拟合到数据集而不遇到这个问题?
也许有一种方法可以限制提供给 leastsq
函数的每个参数的检查域?
有很多方法可以解决这个问题。一种简单但成功的方法是扰动渐近线处的任何 x 值,并将 y 值限制为某个较高但有限的值。另请注意,您的第一个参数是分子而不是分母,并且您可能需要在指数之前添加
abs
。
from typing import Sequence
import numpy as np
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
def optimal_fx(x: np.ndarray, num: float, pow: float, shift: float) -> np.ndarray:
i_close = x.searchsorted(shift)
if abs(x[i_close] - shift) < 1e-2:
x = x.copy()
x[i_close] -= 1e-2 # perturb away from asymptote
return (
-num/np.abs(x - shift)**pow
).clip(min=-100)
def plot(
data_x: np.ndarray, data_y: np.ndarray, pguess: Sequence[float], pfinal: Sequence[float],
) -> plt.Figure:
fig, ax = plt.subplots()
ax: plt.Axes
hires_x = np.linspace(start=data_x.min(), stop=data_x.max(), num=500)
num, pow, shift = pfinal
ax.scatter(data_x, data_y, label="Original Data")
ax.plot(hires_x, optimal_fx(hires_x, *pguess), label="Guess")
ax.plot(hires_x, optimal_fx(hires_x, *pfinal), label="Fit")
ax.set_ylim(bottom=-50, top=10)
ax.set_title(f'y = -{num:.2f}/(x - {shift:.2f})^{pow:.2f}')
ax.legend()
return fig
def main():
x = np.array([
0, 0.5, 1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5, 5, 5.5, 6, 6.5, 7, 7.5, 8, 8.5, 9, 9.5, 10,
10.5, 11, 11.5, 12, 12.5, 13, 13.5, 14.5, 15, 15.5, 16, 16.5, 17])
y = np.array([
0.00000000e+00, 2.47076895e-03, 9.66638150e-03, 1.97670203e-02, 3.42218835e-02,
4.97943540e-02, 6.54004261e-02, 7.93484613e-02, 8.61083796e-02, 8.49273950e-02,
6.70327841e-02, 3.45946900e-02, -2.24559129e-02, -1.17023372e-01, -2.49244700e-01,
-4.27837601e-01, -6.61529080e-01, -9.62240010e-01, -1.35215046e+00, -1.84177480e+00,
-2.46428273e+00, -3.25198094e+00, -4.24238794e+00, -5.53021367e+00, -7.23278230e+00,
-9.57603920e+00, -1.31432215e+01, -1.99465466e+01, -1.45862198e+01, -1.01563667e+01,
-7.19977087e+00, -5.03349403e+00, -3.37202972e+00, -2.07966002e+00])
pguess = (10, 1.2, 13.5)
pfinal, *_ = curve_fit(
f=optimal_fx,
xdata=x,
ydata=y,
p0=pguess,
bounds=(
(0.01, 0.1, 5),
( 100, 10, 15),
),
)
plot(x, y, pguess, pfinal)
plt.show()
if __name__ == "__main__":
main()