如何提高测试准确度?

问题描述 投票:0回答:1

我是机器学习新手!我目前正在使用决策树分类器来解决我的文本分类问题,作为我的模型,我的训练准确度为 98%,但测试准确度仅为 48%,如何提高测试准确度?

型号:

model = DecisionTreeClassifier()
model.fit(X_train , Y_train)
python machine-learning cross-validation decision-tree training-data
1个回答
0
投票

这是一个很好的问题。就您而言,您的模型过度拟合。这意味着你的训练数据集不好或者你没有微调模型的超参数。要微调超参数,请尝试使用网格搜索算法。这是一个例子

n_estimators_range = list(range(10, 201, 10))  # from 10 to 200, step 10 can be modified
max_depth_range = list(range(1, 12, 1))  # from 1 to 12, step 1
min_samples_split_range = list(range(2, 21, 2))  # from 2 to 20, step 2
min_samples_leaf_range = list(range(1, 11))  # from 1 to 10, step 1

# Define hyperparameter for grid search
param_grid = {
    'n_estimators': n_estimators_range,
    'max_depth': max_depth_range, # affect you model the most
    'min_samples_split': min_samples_split_range,
    'min_samples_leaf': min_samples_leaf_range,
} # you can modify it and increase number of hyper parameters

# Create the GridSearchCV object
grid_search = GridSearchCV(estimator=rf, param_grid=param_grid,
                           cv=3, n_jobs=-1, verbose=2, scoring='accuracy') # scoring can be changed same as cv, n_jobs and verbose

# Grid search your model
grid_search.fit(X_train, y_train)

# Get the best hyperparameters and print them
best_params = grid_search.best_params_
print(f"Best hyperparameters: {best_params}")

您还可以将所有网格搜索结果保存在 csv 中以便分析结果。从该 csv 您可以计算每一行的 F1 分数。请记住,如果您的初始训练确实花费了 2 分钟以上,那么您需要 2-5 小时来查找这些超参数。因此,请注意步长和使用的超参数数量。我希望它有帮助

© www.soinside.com 2019 - 2024. All rights reserved.