How to cross-validate a neural network with multiple binary classification outputs?


I am trying to use StratifiedKFold to cross-validate a CNN that outputs multiple binary classifications. However, StratifiedKFold cannot handle multilabel-indicator targets.

from sklearn.metrics import multilabel_confusion_matrix
from sklearn.model_selection import StratifiedKFold

skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)

fold_scores = []
confusion_matrices = []

for fold, (train_index, test_index) in enumerate(skf.split(X, Y)):
    print(f'Fold {fold+1}:')
    X_train, X_test = X[train_index], X[test_index]
    Y_train, Y_test = Y[train_index], Y[test_index]

    # Build a fresh model for each fold
    model = create_model()

    # Train the model
    model.fit(X_train, Y_train, epochs=10, batch_size=32)

    # Predict per-label probabilities; assumes one sigmoid output layer with one unit per label
    # (Sequential.predict_classes() was removed in TensorFlow 2.6)
    y_prob = model.predict(X_test)
    y_pred_binary = (y_prob > 0.5).astype(int)  # convert probabilities to binary predictions
    # One 2x2 confusion matrix per label column, shape (n_labels, 2, 2)
    confusion_matrices.append(multilabel_confusion_matrix(Y_test, y_pred_binary))

    # Evaluate the compiled loss and metrics on the held-out fold
    scores = model.evaluate(X_test, Y_test, verbose=0)
    fold_scores.append(scores)

Here is the error:

ValueError: Supported target types are: ('binary', 'multiclass'). Got 'multilabel-indicator' instead. 
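The error comes from the target-type check StratifiedKFold performs on `Y`: scikit-learn's `type_of_target` classifies a 2-D 0/1 matrix with more than one column as `'multilabel-indicator'`, and only `'binary'` and `'multiclass'` targets are accepted. A minimal reproduction, assuming made-up arrays (`y_single`, `y_multi` and `X_dummy` are hypothetical stand-ins for the real data):

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.utils.multiclass import type_of_target

# Hypothetical targets: one binary label per sample vs. three binary labels per sample.
y_single = np.array([0, 1, 1, 0, 1, 0])
y_multi = np.array([[1, 0, 1],
                    [0, 1, 0],
                    [1, 1, 0],
                    [0, 0, 1],
                    [1, 0, 0],
                    [0, 1, 1]])

print(type_of_target(y_single))  # binary
print(type_of_target(y_multi))   # multilabel-indicator

# X is irrelevant to the check, so a placeholder of the right length is enough.
X_dummy = np.zeros((len(y_multi), 1))
skf = StratifiedKFold(n_splits=2)
list(skf.split(X_dummy, y_single))  # accepted: 'binary' is a supported target type

try:
    # Folds are generated lazily, so list() is what triggers the target-type check.
    list(skf.split(X_dummy, y_multi))
    error_message = None
except ValueError as err:
    error_message = str(err)
print(error_message)
```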

Besides StratifiedKFold, is there another way to cross-validate a model with multiple binary outputs on an imbalanced dataset? My CNN is built with Keras and TensorFlow.
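One workaround that needs only scikit-learn is to stratify on the label *combination* (a "label powerset"): encode each row of `Y` as a single integer and pass those codes to StratifiedKFold, while still indexing the original multilabel `Y` with the resulting indices. The sketch below assumes the number of distinct combinations is small relative to the dataset; combinations with fewer members than `n_splits` trigger a scikit-learn warning and cannot appear in every test fold. For many labels or rare combinations, iterative stratification (for example `MultilabelStratifiedKFold` from the third-party `iterative-stratification` package) is a commonly used alternative. The helper name `label_powerset_split` and the toy data are hypothetical.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold


def label_powerset_split(Y, n_splits=5, random_state=42):
    """Hypothetical helper: yield (train_index, test_index) stratified on label combinations.

    Each row of the 0/1 indicator matrix Y is read as the bits of a binary
    number, so every distinct label combination becomes one "class".
    """
    Y = np.asarray(Y).astype(int)
    combo_ids = Y @ (2 ** np.arange(Y.shape[1]))  # e.g. [1, 1, 0] -> 1 + 2 = 3
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    # Only the targets drive stratification, so a zero placeholder stands in for X.
    return skf.split(np.zeros(len(Y)), combo_ids)


# Hypothetical data: 4 label combinations, 10 samples each (40 rows, 3 labels).
Y_demo = np.tile([[1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1]], (10, 1))
folds = list(label_powerset_split(Y_demo, n_splits=5))
for fold, (train_index, test_index) in enumerate(folds):
    # Per-label positive rate in each test fold; stratification keeps it near the overall rate.
    print(fold + 1, Y_demo[test_index].mean(axis=0))
```

In the loop from the question, `skf.split(X, Y)` would become `label_powerset_split(Y, n_splits=10)`; `Y[train_index]` and `Y[test_index]` stay in multilabel-indicator form, so the training and evaluation code is unchanged. The bit encoding is just one convenient way to give each combination a unique integer.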

python keras scikit-learn deep-learning cross-validation