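I want to fit a set of data points with a polynomial (I usually do this with numpy.polyfit), but I'd like the user to choose the degree of the polynomial interactively. Here is an example of what I'm trying to do: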
import sys
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error

fig, ax = plt.subplots()
x = np.arange(1, 10, 0.2)
y = np.sin(x)
ax.plot(x, y, 'o', color='orange', markeredgewidth=0.3, markeredgecolor='k')
ax.set_xlim(0, 10)
ax.set_ylim(-1.1, 1.1)

def press(event):
    #fig.clf()
    fig.canvas.draw_idle()
    sys.stdout.flush()
    deg = int(event.key)
    coeffs = np.polyfit(x, y, deg)
    p = np.poly1d(coeffs)
    rms = np.sqrt(mean_squared_error(y, p(x)))
    fig.text(0.8, 1.02, 'rms=' + str(round(rms, 4)), rotation=0, color='k', transform=ax.transAxes)
    with open('prova.txt', 'w') as filehandle:
        filehandle.write('#Coefficients for a n= ' + str(deg) + ' polynomial fit\n\n')
        for listitem in coeffs:
            filehandle.write('%s\n' % listitem)
    ln = plt.plot(x, p(x), '-', color='green', linewidth=0.8, zorder=0)
    fig.canvas.draw()

cid = fig.canvas.mpl_connect('key_press_event', press)
plt.show()
fig.canvas.mpl_disconnect(cid)
cfs = np.loadtxt('prova.txt', comments='#')
print(cfs)
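This fits the points correctly, but every fit after the first one is drawn on top of the previous ones. If I uncomment fig.clf(), the code does update the fit, but it also wipes out the data points.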
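Both symptoms come from how Matplotlib stores artists: each plot() call adds a new Line2D to the Axes rather than replacing the previous one, while fig.clf() deletes the whole Axes, data points included. The minimal sketch below shows this without a GUI; it assumes the non-interactive Agg backend and calls ax.plot directly in a loop instead of going through key events.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, only so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

x = np.arange(1, 10, 0.2)
y = np.sin(x)

fig, ax = plt.subplots()
ax.plot(x, y, 'o')  # the data points: one Line2D on the axes

# Two simulated key presses, using the same plotting logic as press():
# every plot() call adds a *new* Line2D instead of replacing the old one.
for deg in (2, 5):
    p = np.poly1d(np.polyfit(x, y, deg))
    ax.plot(x, p(x), '-')

n_lines = len(ax.lines)
print(n_lines)        # 3: the data plus both fits, hence the overdrawing

# fig.clf() goes too far the other way: it deletes every Axes,
# and the data points live on that Axes.
fig.clf()
print(len(fig.axes))  # 0
```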
Create global variables initialised to None at the start:
txt = None
ln = None
Inside press(), check whether the plot and the text have already been assigned and, if so, remove() them before drawing the new ones:
global txt
global ln

if txt:
    txt.remove()
txt = fig.text(0.8, 1.02, 'rms=' + str(round(rms, 4)), rotation=0, color='k', transform=ax.transAxes)

if ln:
    ln[0].remove()
ln = plt.plot(x, p(x), '-', color='green', linewidth=0.8, zorder=0)

fig.canvas.draw()
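A different technique from the remove-and-recreate approach above, offered here as an alternative rather than as the answer's method, is to create the fit line and the RMS label once and update them in place with Line2D.set_data and Text.set_text. The sketch below assumes the Agg backend and drives press() with a types.SimpleNamespace stand-in for the real key event so it can run headless; it also computes the RMS with NumPy instead of sklearn's mean_squared_error (same quantity, one dependency fewer). For interactive use, drop the Agg line and the fake-event loop and call plt.show().

```python
import types

import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

x = np.arange(1, 10, 0.2)
y = np.sin(x)

fig, ax = plt.subplots()
ax.plot(x, y, 'o', color='orange', markeredgewidth=0.3, markeredgecolor='k')
ax.set_xlim(0, 10)
ax.set_ylim(-1.1, 1.1)

# Create the fit line and the RMS label once, empty; press() only updates them.
fit_line, = ax.plot([], [], '-', color='green', linewidth=0.8, zorder=0)
rms_text = ax.text(0.8, 1.02, '', transform=ax.transAxes)

def press(event):
    # Ignore keys that are not a single digit (e.g. 'shift', 'left', or None).
    if event.key is None or not event.key.isdigit():
        return
    deg = int(event.key)
    p = np.poly1d(np.polyfit(x, y, deg))
    rms = np.sqrt(np.mean((y - p(x)) ** 2))  # same value as sqrt(mean_squared_error)
    fit_line.set_data(x, p(x))               # reuse the same Line2D
    rms_text.set_text(f'rms={rms:.4f}')      # reuse the same Text
    fig.canvas.draw_idle()

cid = fig.canvas.mpl_connect('key_press_event', press)

# Headless check (hypothetical test harness): call the handler directly
# with stand-in event objects; the 'x' key should be ignored.
for key in ('2', '5', 'x'):
    press(types.SimpleNamespace(key=key))

print(len(ax.lines))  # 2: the data points and the single fit line
```

Because nothing is ever removed or added after setup, the number of artists stays constant no matter how many keys are pressed, and the global variables are no longer needed.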
Full code:
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error

# --- functions ---

def press(event):
    global txt
    global ln

    # Only react to digit keys; int() would raise on anything else.
    if event.key is None or not event.key.isdigit():
        return

    deg = int(event.key)
    coeffs = np.polyfit(x, y, deg)
    p = np.poly1d(coeffs)
    rms = np.sqrt(mean_squared_error(y, p(x)))

    # Save the coefficients (highest power first) for later use.
    with open('prova.txt', 'w') as filehandle:
        filehandle.write('#Coefficients for a n= ' + str(deg) + ' polynomial fit\n\n')
        for listitem in coeffs:
            filehandle.write('%s\n' % listitem)

    # Remove the previous RMS label and fit line before drawing the new ones.
    if txt:
        txt.remove()
    txt = fig.text(0.8, 1.02, 'rms=' + str(round(rms, 4)), rotation=0, color='k', transform=ax.transAxes)

    if ln:
        ln[0].remove()
    ln = ax.plot(x, p(x), '-', color='green', linewidth=0.8, zorder=0)

    fig.canvas.draw()

# --- main ---

txt = None
ln = None

fig, ax = plt.subplots()
x = np.arange(1, 10, 0.2)
y = np.sin(x)
ax.plot(x, y, 'o', color='orange', markeredgewidth=0.3, markeredgecolor='k')
ax.set_xlim(0, 10)
ax.set_ylim(-1.1, 1.1)

cid = fig.canvas.mpl_connect('key_press_event', press)
plt.show()
fig.canvas.mpl_disconnect(cid)

# Read back the coefficients of the last fit (the file exists only if a key was pressed).
if os.path.exists('prova.txt'):
    cfs = np.loadtxt('prova.txt', comments='#')
    print(cfs)
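The code saves the coefficients of the last fit to prova.txt and reads them back with loadtxt. The sketch below checks that round trip outside the GUI: it writes a file in the same layout press() produces, loads it with np.loadtxt, and rebuilds the polynomial with np.poly1d. Writing to a temporary directory instead of the working directory is an assumption made so the sketch does not overwrite an existing prova.txt.

```python
import os
import tempfile

import numpy as np

x = np.arange(1, 10, 0.2)
y = np.sin(x)
deg = 3
coeffs = np.polyfit(x, y, deg)  # highest power first

# Same layout press() writes: a '#' header, a blank line,
# then one coefficient per line.
path = os.path.join(tempfile.mkdtemp(), 'prova.txt')
with open(path, 'w') as fh:
    fh.write('#Coefficients for a n= ' + str(deg) + ' polynomial fit\n\n')
    for c in coeffs:
        fh.write('%s\n' % c)

# comments='#' skips the header; loadtxt ignores blank lines.
# A single-column file loads as a 1-D array, so usecols is not needed.
cfs = np.loadtxt(path, comments='#')

# poly1d expects the same highest-power-first order that polyfit returns.
p = np.poly1d(cfs)
print(cfs.shape)                                 # (4,)
print(np.allclose(p(x), np.polyval(coeffs, x)))  # True
```

np.savetxt(path, coeffs, header='Coefficients ...') would write an equivalent file; it prefixes the header with '# ' by default, so the same loadtxt call reads it back.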