"The number of classes has to be greater than one; got 1 class" on sklearn's MultiOutputClassifier

I'm getting an error saying there are fewer than two classes. Here is the relevant part of the code:

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.multioutput import MultiOutputClassifier
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.svm import SVC

texts = [element['text'] for element in train_data]
# Samples without labels get the placeholder label 'Free'
labels = [element['labels'] if element['labels'] else ['Free'] for element in train_data]

mlb = MultiLabelBinarizer(classes=G.nodes)
y_bin = mlb.fit_transform(labels)

# texts was transformed into X_reduced by preprocessing steps not shown here
X_train, X_test, y_train, y_test = train_test_split(X_reduced, y_bin, test_size=0.2, random_state=42)

multi_label_classifier = MultiOutputClassifier(SVC(kernel='linear', probability=True))
y_train = y_train.astype(np.uint8)  # suggested in an earlier post, but it did not help
multi_label_classifier.fit(X_train, y_train)

y_train is a 400x31 matrix and, as you can see, I even added a new class 'Free' to make sure that no row of the matrix contains only zeros.

To be sure, I ran these checks.

len(np.unique(y_train))

Result --> 2

def atleast_one(matrix):
    for row in matrix:
        if 1 in row:
            continue
        else:
            print(row)
            return False
    # every row contains a 1
    return True
atleast_one(y_train)

Result --> True

np.any(np.all(y_train == 0, axis=1))

Result --> False

After all this, the call to fit still raises the error, and I don't understand why.

This is the error:

ValueError                                Traceback (most recent call last)
<ipython-input-14-cb7366c3ac07> in <cell line: 55>()
     53 multi_label_classifier = MultiOutputClassifier(SVC(kernel='linear', probability=True))
     54 y_train=y_train.astype(np.uint8)
---> 55 multi_label_classifier.fit(X_train, y_train)
     56 
     57 def concatenate_row_elements(matrix):

/usr/local/lib/python3.10/dist-packages/sklearn/svm/_base.py in _validate_targets(self, y)
    747         self.class_weight_ = compute_class_weight(self.class_weight, classes=cls, y=y_)
    748         if len(cls) < 2:
--> 749             raise ValueError(
    750                 "The number of classes has to be greater than one; got %d class"
    751                 % len(cls)
ValueError: The number of classes has to be greater than one; got 1 class

P.S. This is my first question, so I apologize if I made any mistakes.

python scikit-learn multilabel-classification
1 Answer

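It looks like some of your columns (after multi-label binarization and the train/test split) are all zeros or all ones. Everything you have checked so far concerns rows, but the SVM is complaining about one of the target columns. MultiOutputClassifier fits a separate SVC on each column of y_train, so every column has to contain both a 0 and a 1 in the training set. Note also that len(np.unique(y_train)) counts distinct values over the whole matrix, so it returns 2 even when an individual column is constant.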
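The column check described above could be sketched roughly as follows. This is only an illustration on a toy matrix: the names `constant_columns` and `y_demo` are hypothetical and do not come from the question, and `y_demo` stands in for the real `y_train`.

```python
import numpy as np


def constant_columns(y):
    """Return the indices of columns that contain only one distinct value."""
    y = np.asarray(y)
    # A column is constant exactly when its minimum equals its maximum.
    return np.flatnonzero(y.min(axis=0) == y.max(axis=0))


# Toy stand-in for y_train (hypothetical data): every row contains a 1 and
# the matrix as a whole contains both 0 and 1, so the row-wise checks from
# the question all pass. Still, column 1 is all zeros and column 2 is all ones.
y_demo = np.array([
    [1, 0, 1],
    [0, 0, 1],
    [1, 0, 1],
    [0, 0, 1],
])

bad = constant_columns(y_demo)
print(bad)  # [1 2]

# One possible workaround: keep only the columns that have both classes.
keep = np.setdiff1d(np.arange(y_demo.shape[1]), bad)
y_filtered = y_demo[:, keep]
```

In the question's setting, one would call `constant_columns(y_train)` after the split. If it returns any indices, those labels either never occur or always occur in the training portion. Dropping them (applying the same column selection to `y_test`) is one option; gathering more examples for the rare labels or changing how the data is split are others. Which is appropriate depends on the data, which is not shown in the question.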
