How to visualize an XGBoost tree from GridSearchCV output?


I am using XGBRegressor to fit a model with GridSearchCV, and I want to visualize the trees.

Here is the link I followed (in case this is a duplicate): How to plot a decision tree from GridSearchCV?

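from sklearn.model_selection import GridSearchCV
from xgboost import XGBRegressor

# params, X_train and y_train are defined elsewhere
xgb = XGBRegressor(learning_rate=0.02, n_estimators=600, silent=True, nthread=1)
folds = 5
grid = GridSearchCV(estimator=xgb, param_grid=params, scoring='neg_mean_squared_error', cv=folds, n_jobs=4, verbose=3)
model = grid.fit(X_train, y_train)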

Approach 1:

from sklearn import tree

dot_data = tree.export_graphviz(model.best_estimator_, out_file=None,
                                filled=True, rounded=True, feature_names=X_train.columns)
dot_data

 Error: NotFittedError: This XGBRegressor instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator.

Approach 2:

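import subprocess
from sklearn import tree
tree.export_graphviz(best_clf, out_file='tree.dot', feature_names=X_train.columns, leaves_parallel=True)  # best_clf: the tuned model, e.g. model.best_estimator_
subprocess.call(['dot', '-Tpdf', 'tree.dot', '-o', 'tree.pdf'])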

Same error.

python-3.x plot scikit-learn decision-tree xgboost
2 Answers

scikit-learn's tree.export_graphviz will not work here, because your best_estimator_ is not a single tree but a whole ensemble of trees.

Here is how you can do it using XGBoost's own plot_tree, with the Boston housing data:

from xgboost import XGBRegressor, plot_tree
from sklearn.model_selection import GridSearchCV
from sklearn.datasets import load_boston  # removed in scikit-learn 1.2; on newer versions substitute another dataset (e.g. load_diabetes)
import matplotlib.pyplot as plt

X, y = load_boston(return_X_y=True)

params = {'learning_rate':[0.1, 0.5], 'n_estimators':[5, 10]} # dummy, for demonstration only

xgb = XGBRegressor(learning_rate=0.02, n_estimators=600, silent=True, nthread=1)
grid = GridSearchCV(estimator=xgb, param_grid=params, scoring='neg_mean_squared_error', n_jobs=4)

grid.fit(X, y)

Our best estimator is:

grid.best_estimator_
# result (details may be different due to randomness):
XGBRegressor(base_score=0.5, booster='gbtree', colsample_bylevel=1,
             colsample_bynode=1, colsample_bytree=1, gamma=0,
             importance_type='gain', learning_rate=0.5, max_delta_step=0,
             max_depth=3, min_child_weight=1, missing=None, n_estimators=10,
             n_jobs=1, nthread=1, objective='reg:linear', random_state=0,
             reg_alpha=0, reg_lambda=1, scale_pos_weight=1, seed=None,
             silent=True, subsample=1, verbosity=1)

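Having done that, we can use the approach from this thread to plot, say, tree #4 (num_trees is a zero-based index into the ensemble):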

fig, ax = plt.subplots(figsize=(30, 30))
plot_tree(grid.best_estimator_, num_trees=4, ax=ax)
plt.show()

[Image: plot of tree #4 of the best estimator]

Similarly, for tree #1:

fig, ax = plt.subplots(figsize=(30, 30))
plot_tree(grid.best_estimator_, num_trees=1, ax=ax)
plt.show()

[Image: plot of tree #1 of the best estimator]

© www.soinside.com 2019 - 2024. All rights reserved.