Python ValueError:n_splits = 3不能大于每个类中的成员数

问题描述 投票:0回答:2

我正在人脸识别项目中,每个人有两个人,每个人有2张脸

1. personA
    image1.jpg
    image2.jpg


2. personB
    image1.jpg
    image2.jpg

我正在尝试在上述数据集的面部嵌入上训练模型,如下所示:

params = {"C": [0.001, 0.01, 0.1, 1.0, 10.0, 100.0, 1000.0], "gamma": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]}
model = GridSearchCV(SVC(kernel="rbf", gamma="auto", probability=True), params, cv=3, n_jobs=-1)
model.fit(data["embeddings"], labels)

[data["embeddings"]labels的长度为4data["embeddings']包含人员A,人员B的面部嵌入的ndarray

data['embeddings'] = [
                         [0.02331057, -0.01995077, ..], 
                         [-0.00034041,  0.02753334, ..], 
                         [0.02454563, -0.03797123, ...], 
                         [0.10561685, -0.08444008, ...]
                     ]

labels = [0 0 1 1]

但是我在model.fit(data["embeddings"], labels)处遇到错误:

ValueError: n_splits=3 cannot be greater than the number of members in each class.

我无法理解此错误。谁能解释这个问题,我该如何解决?

python machine-learning scikit-learn cross-validation
2个回答
1
投票

仔细阅读后,错误消息清晰易懂;它只是告诉您,由于每个类别总共只有两(2)个样本,因此无法进行3折交叉验证。每个班级需要至少

3个样本。

0
投票

您将交叉验证分割(cv)设置为3。由于只有两个输入数据,因此无法分为3个分割。您可以添加第三个训练示例(第三人称),也可以将拆分更改为cv = 2或cv = None。

© www.soinside.com 2019 - 2024. All rights reserved.