KeyError when implementing cross-validation with GroupKFold


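I have a df with 3 main columns: 'label', 'embeddings' (the features) and 'chr'. I'm trying to run 10-fold cross-validation grouped by chromosome, so that all chr1 rows end up entirely in either training or testing (never split across train/test). The df looks like this: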

I believe my code is correct, but I keep getting this KeyError:

Here is my code:

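import numpy as np
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import auc, roc_curve
from sklearn.model_selection import GroupKFold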

X = np.array([np.array(x) for x in mini_df['embeddings']])
y = mini_df['label']
groups = mini_df['chromosome']
group_kfold = GroupKFold(n_splits=10)

# Initialize figure for plotting
plt.figure(figsize=(10, 6))

# Perform cross-validation and plot ROC curves for each fold
for i, (train_idx, val_idx) in enumerate(group_kfold.split(X, y, groups)):
    X_train_fold, X_val_fold = X[train_idx], X[val_idx]
    y_train_fold, y_val_fold = y[train_idx], y[val_idx]
    
    # Initialize classifier
    rf_classifier = RandomForestClassifier(n_estimators=n_trees, random_state=42, max_depth=max_depth, n_jobs=-1)
    
    # Train the classifier on this fold
    rf_classifier.fit(X_train_fold, y_train_fold)
    
    # Make predictions on the validation set
    y_pred_proba = rf_classifier.predict_proba(X_val_fold)[:, 1]
    
    # Calculate ROC curve
    fpr, tpr, _ = roc_curve(y_val_fold, y_pred_proba)
    
    # Calculate AUC
    roc_auc = auc(fpr, tpr)
    
    # Plot ROC curve for this fold
    plt.plot(fpr, tpr, lw=1, alpha=0.7, label=f'ROC Fold {i+1} (AUC = {roc_auc:.2f})')

# Plot ROC for random classifier
plt.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r', label='Random', alpha=0.8)

# Add labels and legend
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curves for Random Forest Classifier')
plt.legend(loc='lower right')
plt.show()
Answer

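A KeyError means that the key you looked up does not exist. Note that y here is not a dictionary but a pandas Series (mini_df['label']). For a Series with an integer index, y[...] with an array of integers looks up index labels, not positions, and that is what happens on the line y_train_fold, y_val_fold = y[train_idx], y[val_idx]. GroupKFold.split yields positional indices (0 to n - 1), so if mini_df does not have the default 0..n-1 index (for example, because it was filtered or sampled from a larger DataFrame), some of the values in train_idx or val_idx (or both; there is no way to tell because the error message in your image is cut off) are not labels in y's index.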
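A minimal reproduction, assuming the cause is a non-default integer index. The Series below and its labels 10 to 13 are hypothetical stand-ins for mini_df['label']. It shows that [] on such a Series raises KeyError for positional indices, while .iloc and a plain NumPy array look up by position:

```python
import numpy as np
import pandas as pd

# Hypothetical Series whose integer index does not start at 0,
# as happens after filtering or sampling a larger DataFrame
y = pd.Series([1, 0, 0, 1], index=[10, 11, 12, 13])

# Positional indices, like the arrays GroupKFold.split yields
idx = np.array([0, 2])

try:
    y[idx]  # label-based on an integer index: labels 0 and 2 do not exist
    label_lookup_failed = False
except KeyError as err:
    label_lookup_failed = True
    print("KeyError:", err)

by_position = y.iloc[idx].tolist()      # .iloc is always positional
via_numpy = y.to_numpy()[idx].tolist()  # NumPy arrays are always positional
print(by_position, via_numpy)
```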

To find out which indices are missing, you can add a check like this inside the loop:

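...
    # Collect the positional indices from GroupKFold that are not labels in y's index
    missing_train = [i for i in train_idx if i not in y.index]
    missing_val = [i for i in val_idx if i not in y.index]
    assert not missing_train, f"y has no index labels for train indices: {missing_train[:10]}"
    assert not missing_val, f"y has no index labels for val indices: {missing_val[:10]}"
    X_train_fold, X_val_fold = X[train_idx], X[val_idx]
    y_train_fold, y_val_fold = y[train_idx], y[val_idx]
...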
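If the check confirms that the labels are missing, one way to fix the loop is to index by position: use y.iloc[train_idx] and y.iloc[val_idx], convert the columns with .to_numpy() before the loop, or call mini_df.reset_index(drop=True) first. The sketch below takes the .to_numpy() route on a hypothetical synthetic mini_df (10 chromosomes, 20 rows each, 8-dimensional embeddings, index starting at 1000); n_estimators=50 is an assumed stand-in for the question's n_trees, and max_depth is left at its default. It also asserts that no chromosome appears in both the training and validation indices of a fold:

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import auc, roc_curve
from sklearn.model_selection import GroupKFold

rng = np.random.default_rng(0)

# Hypothetical stand-in for mini_df, with a non-default index to mimic
# a DataFrame taken out of a bigger one; every chromosome has both labels
n_chr, rows_per_chr, dim = 10, 20, 8
labels = np.tile([0, 1], n_chr * rows_per_chr // 2)
mini_df = pd.DataFrame(
    {
        "embeddings": [rng.normal(loc=lab, size=dim) for lab in labels],
        "label": labels,
        "chromosome": np.repeat([f"chr{i + 1}" for i in range(n_chr)], rows_per_chr),
    },
    index=np.arange(1000, 1000 + n_chr * rows_per_chr),
)

X = np.stack(mini_df["embeddings"].tolist())
y = mini_df["label"].to_numpy()  # NumPy array: [] is positional, so no KeyError
groups = mini_df["chromosome"].to_numpy()

fold_aucs = []
for i, (train_idx, val_idx) in enumerate(GroupKFold(n_splits=10).split(X, y, groups)):
    # Each chromosome lands entirely in train or entirely in validation
    assert set(groups[train_idx]).isdisjoint(groups[val_idx])

    clf = RandomForestClassifier(n_estimators=50, random_state=42, n_jobs=-1)
    clf.fit(X[train_idx], y[train_idx])
    y_pred_proba = clf.predict_proba(X[val_idx])[:, 1]

    fpr, tpr, _ = roc_curve(y[val_idx], y_pred_proba)
    fold_aucs.append(auc(fpr, tpr))
    print(f"Fold {i + 1}: AUC = {fold_aucs[-1]:.2f}")
```

The plotting calls from the question can stay as they are; they are left out here so the sketch only computes the per-fold AUCs.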
    
© www.soinside.com 2019 - 2024. All rights reserved.