为什么具有特定参数的经过训练的
RandomForestClassifier
无法与使用 GridSearchCV
改变这些参数的性能相匹配?
def random_forest(X_train, y_train):
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import roc_auc_score, make_scorer
from sklearn.model_selection import train_test_split
X_train, X_validate, y_train, y_validate = train_test_split(X_train, y_train, random_state=0)
# various combinations of max depth and max features
max_depth_vals = [1,2,3]
max_features_vals = [2,3,4]
grid_values = {'max_depth': max_depth_vals, 'max_features': max_features_vals}
# build GridSearch
clf = RandomForestClassifier(n_estimators=10)
grid = GridSearchCV(clf, param_grid=grid_values, cv=3, scoring='roc_auc')
grid.fit(X_train, y_train)
y_hat_proba = grid.predict_proba(X_validate)
print('Train Grid best parameter (max. AUC): ', grid.best_params_)
print('Train Grid best score (AUC): ', grid.best_score_)
print('Validation set AUC: ', roc_auc_score(y_validate, y_hat_proba[:,1]))
# build RandomForest with hard coded values. AUC should be ballpark to grid search
clf = RandomForestClassifier(max_depth=3, max_features=4, n_estimators=10)
clf.fit(X_train, y_train)
y_hat = clf.predict(X_validate)
y_hat_prob = clf.predict_proba(X_validate)[:, 1]
auc = roc_auc_score(y_hat, y_hat_prob)
print("\nMax Depth: 3 Max Features: 4\n---------------------------------------------")
print("auc: {}".format(auc))
return
结果 - 网格搜索确定
max_depth=3
和 max_features=4
的最佳参数,并计算 auc score
为 0.85;当我将其通过带有保留验证集的代码时,我得到 auc score
为 0.84。然而,当我直接使用这些参数对分类器进行编码时,它计算出的 auc score
为 1.0。我的理解是,它应该在同一个范围内~0.85,但这感觉很遥远。
Validation set AUC: 0.8490471073563559
Grid best parameter (max. AUC): {'max_depth': 3, 'max_features': 4}
Grid best score (AUC): 0.8599727094965482
Max Depth: 3 Max Features: 4
---------------------------------------------
auc: 1.0
我可能会误解概念,无法正确应用技术,甚至存在编码问题。谢谢。
你必须传递预测的概率,而不是预测的标签:
y_hat_prob = clf.predict_proba(X_validate)[:, 1]
auc = roc_auc_score(y_validate, y_hat_prob)
请参阅示例:https://scikit-learn.org/stable/modules/ generated/sklearn.metrics.roc_auc_score.html