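I have been trying to classify autism with a CNN model. The best accuracy reported in papers so far is around 70-73%, and my model reaches roughly 65-70% depending on the parameters. I finally found a hyperparameter combination that gives over 70% accuracy when evaluated on the test set (about 10% of the dataset, with another 10% for validation and 80% for training). I then decided to run 10-fold cross-validation, using verbose=1 to inspect every epoch. In the first run (25 epochs in total) the validation accuracy per epoch was around 68-76%, and the final score was 72%. From the second batch of 25 epochs onwards, however, val accuracy was around 98-100% and accuracy stayed at 1.000. The third batch was similar, with 100% showing up. Is this normal? I have not used cross-validation before; the code I am using is a template for CNN k-fold cross-validation.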
from sklearn.model_selection import KFold
import numpy as np

# data should be of shape (838, 392, 392, num_channels)
data = conn_matrices
# labels should be of shape (838,)
labels = y

# Initialize 10-fold cross-validation
kf = KFold(n_splits=10, shuffle=True, random_state=42)

# Create lists to store the results of each fold
fold_accuracies = []

# Perform cross-validation and store the results
for train_index, test_index in kf.split(data):
    X_train, X_test = data[train_index], data[test_index]
    y_train, y_test = labels[train_index], labels[test_index]

    # Define and compile your Keras-based CNN model
    # Replace 'your_cnn_model' with your actual model
    your_cnn_model = model

    # Train the model on the training data
    your_cnn_model.fit(X_train, y_train, epochs=25,
                       batch_size=32, validation_data=(X_test, y_test), verbose=1)

    # Evaluate the model on the test data
    accuracy = your_cnn_model.evaluate(X_test, y_test)[1]
    fold_accuracies.append(accuracy)

# Print the accuracy of each fold
for i, accuracy in enumerate(fold_accuracies):
    print(f"Fold {i+1} Accuracy: {accuracy:.4f}")

# Calculate and print the mean accuracy and standard deviation of the results
mean_accuracy = np.mean(fold_accuracies)
std_deviation = np.std(fold_accuracies)
print(f"Mean Accuracy: {mean_accuracy:.4f}")
print(f"Standard Deviation: {std_deviation:.4f}")
I expected a similar accuracy on every run, from about 70% up to 76-77% at most.
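You are giving the model the test data during training. That may let the test data influence some model parameters or hyperparameters, so of course the model overfits and returns an over-optimistic score when it is then tested on the same data it has already seen: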
# Train the model on the training data
your_cnn_model.fit(X_train, y_train, epochs=25, batch_size=32, validation_data=(X_test, y_test), verbose=1)
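One way to keep the test fold out of training can be sketched as follows. This is only an illustration under stated assumptions: `build_model()` is a hypothetical factory that is not in the question, and a scikit-learn `LogisticRegression` on synthetic data stands in for the Keras CNN so that the example runs without TensorFlow. With Keras, the split carved out of the training portion would be passed as `validation_data` in place of the test fold. Building a new model inside the loop is also worth doing: `your_cnn_model = model` in the question binds the same model object on every fold, so the weights trained in earlier folds carry over, and every later test fold contains samples the model has already been trained on. That is a likely reason for the near-100% scores from the second fold onwards.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold, train_test_split

# Synthetic stand-in for conn_matrices / y (assumption, for illustration only)
X, y = make_classification(n_samples=200, n_features=20, random_state=0)


def build_model():
    """Hypothetical factory: returns a new, untrained model on every call."""
    return LogisticRegression(max_iter=1000)


kf = KFold(n_splits=10, shuffle=True, random_state=42)
fold_accuracies = []
val_accuracies = []

for train_index, test_index in kf.split(X):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]

    # Carve the validation set out of the training portion only,
    # so the test fold is never seen while fitting or monitoring.
    X_fit, X_val, y_fit, y_val = train_test_split(
        X_train, y_train, test_size=0.1, random_state=42, stratify=y_train
    )

    model = build_model()  # fresh, untrained model for every fold
    model.fit(X_fit, y_fit)

    val_accuracies.append(model.score(X_val, y_val))     # for monitoring or tuning
    fold_accuracies.append(model.score(X_test, y_test))  # test fold used once

print(f"Mean Accuracy: {np.mean(fold_accuracies):.4f}")
print(f"Standard Deviation: {np.std(fold_accuracies):.4f}")
```

For a Keras model, the equivalent of `build_model()` would construct and compile the network from scratch, so that no weights survive from one fold to the next.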
You need to use nested cross-validation to find the hyperparameters: https://scikit-learn.org/stable/auto_examples/model_selection/plot_nested_cross_validation_iris.html
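A minimal sketch of nested cross-validation with scikit-learn, in the spirit of the linked example. The synthetic data, the `SVC` estimator, and the parameter grid are all assumptions chosen to keep the example small; they are not the CNN or the data from the question. The inner loop (`GridSearchCV`) picks hyperparameters using only the training part of each outer fold, and the outer loop (`cross_val_score`) scores the tuned model on data that played no role in that choice.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, KFold, cross_val_score
from sklearn.svm import SVC

# Synthetic stand-in data (assumption, for illustration only)
X, y = make_classification(n_samples=200, n_features=20, random_state=0)

# Illustrative hyperparameter grid (assumption)
param_grid = {"C": [0.1, 1, 10], "gamma": [0.01, 0.1]}

inner_cv = KFold(n_splits=4, shuffle=True, random_state=0)  # hyperparameter search
outer_cv = KFold(n_splits=5, shuffle=True, random_state=0)  # performance estimate

# Inner loop: grid search refit on each outer training split
search = GridSearchCV(SVC(kernel="rbf"), param_grid=param_grid, cv=inner_cv)

# Outer loop: each outer test fold is only used to score the tuned model
nested_scores = cross_val_score(search, X, y, cv=outer_cv)

print(f"Nested CV accuracy: {nested_scores.mean():.4f} +/- {nested_scores.std():.4f}")
```

The mean of `nested_scores` estimates how well the whole procedure (tuning plus fitting) generalises, which is the number to compare against accuracies reported in papers.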