使用sklearn的Python MNIST数据集,选择特定的数字

问题描述 投票:1回答:1

我正在使用Sklearn在MNIST数据集上训练很少的模型,如何只使用MNIST数据集中的两位数字4和9(两个类)来训练线性模型?

  • 如何选择我的X_test,X_train, y_test,y_train
python scikit-learn mnist
1个回答
2
投票

因此,您只想使用数字4和9的图像。

您需要类似X[np.logical_or(y == 4, y == 9)]的索引:

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_digits

digits = load_digits()

X = digits.data
y = digits.target

#Select only the digit 4 and 9 images
X = X[np.logical_or(y == 4, y == 9)]
y = y[np.logical_or(y == 4, y == 9)]

# verify selection
np.unique(y)
#array([4, 9])

# Now split them
X_train, X_test, y_train, y_test = train_test_split(
    X, y, train_size=200, test_size=100)
© www.soinside.com 2019 - 2024. All rights reserved.