一个简单的pandas图可以在图例上使用圆形标记生成预期输出:
import io
import pandas
import matplotlib
import statsmodels
import matplotlib.pyplot
import statsmodels.tsa.api
cause = "Malignant neoplasms"
csv_data = """Year,CrudeRate
1999,197.0
2000,196.5
2001,194.3
2002,193.7
2003,192.0
2004,189.2
2005,189.3
2006,187.6
2007,186.9
2008,186.0
2009,185.0
2010,186.2
2011,185.1
2012,185.6
2013,185.0
2014,185.6
2015,185.4
2016,185.1
2017,183.9
"""
df = pandas.read_csv(io.StringIO(csv_data), index_col="Year", parse_dates=True)
df.plot(color="black", marker="o", legend=True)
matplotlib.pyplot.show()
请注意,“CrudeRate”图例项是带有圆圈标记的直线,该标记是正确的。
但是,如果我为Holt线性指数平滑函数添加一些额外的图,则图例会丢失圆圈标记:
import io
import pandas
import matplotlib
import statsmodels
import matplotlib.pyplot
import statsmodels.tsa.api
cause = "Malignant neoplasms"
csv_data = """Year,CrudeRate
1999,197.0
2000,196.5
2001,194.3
2002,193.7
2003,192.0
2004,189.2
2005,189.3
2006,187.6
2007,186.9
2008,186.0
2009,185.0
2010,186.2
2011,185.1
2012,185.6
2013,185.0
2014,185.6
2015,185.4
2016,185.1
2017,183.9
"""
def ets_non_seasonal(df, color, predict, exponential=False, damped=False, damping_slope=0.98):
fit = statsmodels.tsa.api.Holt(df, exponential=exponential, damped=damped).fit(damping_slope=damping_slope if damped else None)
fit.fittedvalues.plot(color=color, style="--")
title = "ETS(A,{}{},N)".format("M" if exponential else "A", "_d" if damped else "")
forecast = fit.forecast(predict).rename("${}$".format(title))
forecast.plot(color=color, legend=True, style="--")
df = pandas.read_csv(io.StringIO(csv_data), index_col="Year", parse_dates=True)
df.plot(color="black", marker="o", legend=True)
ets_non_seasonal(df, "red", 5, exponential=False, damped=False, damping_slope=0.98)
matplotlib.pyplot.show()
请注意,“CrudeRate”图例项目只是一条没有圆圈标记的直线。
是什么导致第二种情况下的图例丢失主要情节的圆圈标记?
在matplotlib.pyplot.legend()
之前使用matplotlib.pyplot.show()
将解决您的问题。
由于您正在绘制3个图形,并且根据我的理解,您只需要图例中的2个标签,我们将label='_nolegend_'
传递给fit.fittedvalues.plot()
。如果我们不这样做,我们将在图例中使用第三个标签,其值为None
。
import io
import pandas
import matplotlib
import statsmodels
import matplotlib.pyplot
import statsmodels.tsa.api
cause = "Malignant neoplasms"
csv_data = """Year,CrudeRate
1999,197.0
2000,196.5
2001,194.3
2002,193.7
2003,192.0
2004,189.2
2005,189.3
2006,187.6
2007,186.9
2008,186.0
2009,185.0
2010,186.2
2011,185.1
2012,185.6
2013,185.0
2014,185.6
2015,185.4
2016,185.1
2017,183.9
"""
def ets_non_seasonal(df, color, predict, exponential=False, damped=False, damping_slope=0.98):
fit = statsmodels.tsa.api.Holt(df, exponential=exponential, damped=damped).fit(damping_slope=damping_slope if damped else None)
fit.fittedvalues.plot(color=color, style="--", label='_nolegend_')
title = "ETS(A,{}{},N)".format("M" if exponential else "A", "_d" if damped else "")
forecast = fit.forecast(predict).rename("${}$".format(title))
forecast.plot(color=color, legend=True, style="--")
df = pandas.read_csv(io.StringIO(csv_data), index_col="Year", parse_dates=True)
df.plot(color="black", marker="o", legend=True)
ets_non_seasonal(df, "red", 5, exponential=False, damped=False, damping_slope=0.98)
matplotlib.pyplot.legend()
matplotlib.pyplot.show()
另外,为了让您更容易编写代码,最好按照matplotlib.pyplot
导入import matplotlib.pyplot as plt
。