如何将每个迭代的Statsmodel保存为一个文件以备后用?

问题描述 投票:0回答:2

我生成了下表:

import pandas as pd
import numpy as np
from matplotlib import pyplot as plt

# Generate 'random' data
np.random.seed(0)
X = 2.5 * np.random.randn(10) + 1.5
res = 0.5 * np.random.randn(10)
y = 2 + 0.3 * X + res
Name = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J']

# Create pandas dataframe to store our X and y values
df = pd.DataFrame(
    {'Name': Name,
     'X': X,
     'y': y})

# Show the dataframe
df

得到下表:

姓名 X y
A 5.910131 3.845061
2.500393 3.477255
C 3.946845 3.564572
D 7.102233 4.191507
E 6.168895 4.072600
F -0.943195 1.883879
G 3.875221 3.909606
H 1.121607 2.233903
1.241953 2.529120
J 2.526496 2.330901

以下代码迭代一次排除一行,并构建一组回归图:

import statsmodels.formula.api as smf
import warnings
warnings.filterwarnings('ignore')
# Initialise and fit linear regression model using `statsmodels`

for row_index, row in df.iterrows():
    # dataframe with all rows except for one
    df_reduced = df[~(df.index == row_index)]
    model = smf.ols('X ~ y', data=df_reduced)
    model = model.fit()
    intercept, slope = model.params
    print(model.summary())

    y1 = intercept + slope * df_reduced.y.min()
    y2 = intercept + slope * df_reduced.y.max()
    plt.plot([df_reduced.y.min(), df_reduced.y.max()], [y1, y2], label=row.Name, color='red')
    plt.scatter(df_reduced.y, df_reduced.X)
    plt.legend()
    plt.savefig(f"All except {row.Name} analogue.pdf")
    plt.show()

问题是,如何将正在生成的每个模型保存为一个文件,以便以后使用?在本示例中,应该至少生成 9 个回归模型。我想让它们每个都作为一个文件,也可以用一个名字来识别。

第二个问题是,如何在 matplotlib 的视觉生成中的每个模型摘要和绘图之间添加一个空格。

python pandas matplotlib regression statsmodels
2个回答
0
投票

你只需要添加这个:

model.save(f"model_{row_index}.pkl")
在你的循环中


0
投票

存储训练好的模型: 假设你有一些命名程序可用于每个模型文件 mf,你可以使用 pickle 存储模型。

import statsmodels.api as sm
import pickle

# Train your model
model = sm.OLS(y, X).fit()

# Save the model to a file
with open('model.pickle', 'wb') as f:
    pickle.dump(model, f)

# Load the model from the file
with open('model.pickle', 'rb') as f:
    loaded_model = pickle.load(f)

print(loaded_model.summary())

给出以下输出。

                                     OLS Regression Results                                
=======================================================================================
Dep. Variable:                      y   R-squared (uncentered):                   0.525
Model:                            OLS   Adj. R-squared (uncentered):              0.472
Method:                 Least Squares   F-statistic:                              9.931
Date:                Mon, 03 Apr 2023   Prob (F-statistic):                      0.0117
Time:                        12:42:57   Log-Likelihood:                         -20.560
No. Observations:                  10   AIC:                                      43.12
Df Residuals:                       9   BIC:                                      43.42
Df Model:                           1                                                  
Covariance Type:            nonrobust                                                  
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
x1             0.8743      0.277      3.151      0.012       0.247       1.502
==============================================================================
Omnibus:                        1.291   Durbin-Watson:                   0.989
Prob(Omnibus):                  0.524   Jarque-Bera (JB):                0.937
Skew:                           0.637   Prob(JB):                        0.626
Kurtosis:                       2.209   Cond. No.                         1.00
==============================================================================

Notes:
[1] R² is computed without centering (uncentered) since the model does not contain a constant.
[2] Standard Errors assume that the covariance matrix of the errors is correctly specified.

请注意,出于简化目的,模型导入与您的略有不同。但是,您应该能够以相同的方式存储和加载模型。

我不完全确定我是否正确理解了您关于输出和绘图间距的问题。

间隔摘要: 也许只需添加 emtpy print() 语句?

间隔地块: 您每次都在生成全新的图,因此我不明白这个问题。请随时提供更多信息,我会尽快回复您。

© www.soinside.com 2019 - 2024. All rights reserved.