我正在尝试为多项式分类绘制ROC曲线,我总共有46个唯一类。我的代码是-
n_classes = 46
best_C =1000
#best_Kernel =rbf
best_gamma =0.0001
svc_model_grid_param = SVC(C=best_C, kernel="rbf", gamma= best_gamma, )
print(svc_model)
model_OVR_svc = OneVsRestClassifier(svc_model_grid_param)
y_score = model_OVR_svc.fit(X_train, y_train).decision_function(X_valid)
# Compute ROC curve and ROC area for each class
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
fpr[i], tpr[i], _ = roc_curve(y_valid[:, i], y_score[:, i])
roc_auc[i] = auc(fpr[i], tpr[i])
# Plot of a ROC curve for a specific class
for i in range(n_classes):
plt.figure()
plt.plot(fpr[i], tpr[i], label='ROC curve (area = %0.2f)' % roc_auc[i])
plt.plot([0, 1], [0, 1], 'k--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic example')
plt.legend(loc="lower right")
plt.show()
但是它显示了一个错误,错误如下,为什么显示此错误以及如何处理此错误?
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-84-b8b16a042ddc> in <module>
26 roc_auc = dict()
27 for i in range(n_classes):
---> 28 fpr[i], tpr[i], _ = roc_curve(y_valid[:, i], y_score[:, i])
29 roc_auc[i] = auc(fpr[i], tpr[i])
30
C:\ProgramData\Anaconda3\lib\site-packages\pandas\core\series.py in __getitem__(self, key)
909 key = check_bool_indexer(self.index, key)
910
--> 911 return self._get_with(key)
912
913 def _get_with(self, key):
C:\ProgramData\Anaconda3\lib\site-packages\pandas\core\series.py in _get_with(self, key)
921 elif isinstance(key, tuple):
922 try:
--> 923 return self._get_values_tuple(key)
924 except Exception:
925 if len(key) == 1:
C:\ProgramData\Anaconda3\lib\site-packages\pandas\core\series.py in _get_values_tuple(self, key)
966
967 if not isinstance(self.index, MultiIndex):
--> 968 raise ValueError('Can only tuple-index with a MultiIndex')
969
970 # If key is contained, would have returned by now
ValueError: Can only tuple-index with a MultiIndex
[请先帮助我解决此错误,谢谢。
总是把完整的错误消息(Traceback)放在问题中,其中还有其他有用的信息,例如哪一行导致了错误。
根据您的错误描述,我猜测您的y_valid是一维数组,并且您试图像二维数组一样使用y_valid[:, i]
。
如果您的y_valid是目标类的一维数组。您可以尝试对y_valid进行一次热编码,使其像您使用的方式一样使用。
我希望它会有所帮助。