Unable to run a grid search and train the model


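I am working on a basic text classification problem. I want to use a stacking classifier and tune the parameters of the base classifiers to get high accuracy.

My dataset has 8000 rows and 2 columns (text and class). The code below seems to get stuck, and as a beginner in this field I cannot spot the problem.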

import pandas as pd
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import NuSVC
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.metrics import accuracy_score, log_loss, classification_report, confusion_matrix

# Define parameter grids for classifiers
param_grid_nusvc = {
    'nu': [0.1, 0.3, 0.5, 0.7, 0.9],
    'kernel': ['linear', 'rbf'],
}

param_grid_logreg = {
    'C': [0.1, 1, 10],
    'penalty': ['l1', 'l2'],
}

# Perform grid search for classifiers with improved clarity
nusvc_grid_search = GridSearchCV(NuSVC(probability=True), param_grid_nusvc, cv=2, scoring='accuracy')  # Use accuracy scoring
logreg_grid_search = GridSearchCV(LogisticRegression(), param_grid_logreg, cv=2, scoring='accuracy')

nusvc_grid_search.fit(X_train, y_train)
logreg_grid_search.fit(X_train, y_train)

# Get best parameters
best_params_nusvc = nusvc_grid_search.best_params_
best_params_logreg = logreg_grid_search.best_params_

# Set up base classifiers with best parameters
best_nusvc = NuSVC(probability=True, **best_params_nusvc)
best_logreg = LogisticRegression(**best_params_logreg)

# Setting up stacking classifier
sc = StackingClassifier(
    estimators=[
        ('NuSVC', best_nusvc),
        ('LDA', LinearDiscriminantAnalysis())
    ],
    final_estimator=best_logreg
)

sc.fit(X_train, y_train)

# Evaluate the combined classifiers
print('****Results****')
train_predictions = sc.predict(X_test)
acc = accuracy_score(y_test, train_predictions)
print("Accuracy: {:.4%}".format(acc))

train_predictions_proba = sc.predict_proba(X_test)
ll = log_loss(y_test, train_predictions_proba)
print("Log Loss: {}".format(ll))

# Print classification report (optional)
print('\nClassification Report:')
print(classification_report(y_test, train_predictions))

# Print confusion matrix (optional)
print('\nConfusion Matrix:')
print(confusion_matrix(y_test, train_predictions))

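Some of the changes above were made at ChatGPT's suggestion, to show me how to tune with a grid search. The code appears to hang (about 20 minutes so far). Without the grid search, it runs without trouble in about 2-3 minutes.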

python machine-learning scikit-learn nlp text-classification
1 Answer

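Your SVC grid has 5×2 = 10 points, and each point is fit on 2 folds, so the search runs about 20 fits and takes roughly 20 times as long as a single fit. You can set verbose=4 in the search to better track what is happening, and consider parallelizing (e.g. n_jobs=-1).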
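As a rough illustration of the point above, the sketch below counts the fits implied by the question's NuSVC grid and then runs the same search with verbose=4 and n_jobs=-1. The question does not show how X_train is built, so the data here is a small synthetic stand-in from make_classification, and the names X_demo and y_demo are hypothetical. probability=True is left out of this sketch because accuracy scoring does not need probability estimates, and enabling it makes each NuSVC fit slower, since scikit-learn then runs an extra internal cross-validation to calibrate the probabilities.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, ParameterGrid
from sklearn.svm import NuSVC

# Same NuSVC grid as in the question.
param_grid_nusvc = {
    'nu': [0.1, 0.3, 0.5, 0.7, 0.9],
    'kernel': ['linear', 'rbf'],
}
cv = 2

# Count the work up front: 5 * 2 = 10 candidates, each fit once per fold.
n_candidates = len(ParameterGrid(param_grid_nusvc))
n_fits = n_candidates * cv  # GridSearchCV adds one final refit on top of this
print(f"{n_candidates} candidates x {cv} folds = {n_fits} fits")

# Assumption: synthetic stand-in data, not the asker's text features.
X_demo, y_demo = make_classification(
    n_samples=200, n_features=20, random_state=0
)

search = GridSearchCV(
    NuSVC(),              # probability=True omitted; see the note above
    param_grid_nusvc,
    cv=cv,
    scoring='accuracy',
    verbose=4,            # report progress for each fit
    n_jobs=-1,            # use all available CPU cores
)
search.fit(X_demo, y_demo)
print(search.best_params_)
```

With the real 8000-row dataset each fit will take far longer than on this toy data, but the progress lines should at least show whether the search is advancing rather than hanging.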
