如何使用MNIST数据集来实现超参数

问题描述 投票:0回答:1

我目前Jupyter笔记本运行的程序MNIST数据集进行分类。我试图用KNN分类要做到这一点,并正在采取一个多小时运行。我是新来的分类和超参数和似乎没有成为一个如何正确地实现其中之一任何像样的教程。任何人都可以给我如何使用超参数这种分类有什么秘诀?我已经搜索并看到GridSearchCv和RandomizedSearchCV。查看其例子看来,他们选择不同的属性名称和更改必要为他们的代码的人。我不明白这是如何能为MNIST数据集来完成,如果数据只是手写的数字。眼看只是有数字会没有必要在这种情况下,超参数?这是我的代码,我目前仍在运行。感谢您提供任何帮助。

# To support both python 2 and python 3
from __future__ import division, print_function, unicode_literals

# Common imports
import numpy as np
import os

# to make this notebook's output stable across runs
np.random.seed(42)

# To plot pretty figures
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
plt.rcParams['axes.labelsize'] = 14
plt.rcParams['xtick.labelsize'] = 12
plt.rcParams['ytick.labelsize'] = 12

# Where to save the figures
PROJECT_ROOT_DIR = "."
CHAPTER_ID = "classification"

def save_fig(fig_id, tight_layout=True):
    image_dir = os.path.join(PROJECT_ROOT_DIR, "images", CHAPTER_ID)
    if not os.path.exists(image_dir):
        os.makedirs(image_dir)

    path = os.path.join(image_dir, fig_id + ".png")
    print("Saving figure", fig_id)
    if tight_layout:
        plt.tight_layout()
    plt.savefig(path, format='png', dpi=300)
def sort_by_target(mnist):
    reorder_train = np.array(sorted([(target, i) for i, target in enumerate(mnist.target[:60000])]))[:, 1]
    reorder_test = np.array(sorted([(target, i) for i, target in enumerate(mnist.target[60000:])]))[:, 1]
    mnist.data[:60000] = mnist.data[reorder_train]
    mnist.target[:60000] = mnist.target[reorder_train]
    mnist.data[60000:] = mnist.data[reorder_test + 60000]
    mnist.target[60000:] = mnist.target[reorder_test + 60000]
try:
    from sklearn.datasets import fetch_openml
    mnist = fetch_openml('mnist_784', version=1, cache=True)
    mnist.target = mnist.target.astype(np.int8) # fetch_openml() returns targets as strings
    sort_by_target(mnist) # fetch_openml() returns an unsorted dataset
except ImportError:
    from sklearn.datasets import fetch_mldata
    mnist = fetch_mldata('MNIST original')
    mnist["data"], mnist["target"]
mnist.data.shape
X, y = mnist["data"], mnist["target"]
X.shape
y.shape

#select and display some digit from the dataset
import matplotlib
import matplotlib.pyplot as plt

some_digit_index = 7201
some_digit = X[some_digit_index]
some_digit_image = some_digit.reshape(28, 28)
plt.imshow(some_digit_image, cmap = matplotlib.cm.binary,
           interpolation="nearest")
plt.axis("off")

save_fig("some_digit_plot")
plt.show()

#print some digit's label
print('The ground truth label for the digit above is: ',y[some_digit_index])
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
#random shuffle
import numpy as np

shuffle_index = np.random.permutation(60000)
X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]
from sklearn.model_selection import cross_val_predict
from sklearn.neighbors import KNeighborsClassifier

y_train_large = (y_train >= 7)
y_train_odd = (y_train % 2 == 1)
y_multilabel = np.c_[y_train_large, y_train_odd]

knn_clf = KNeighborsClassifier()
knn_clf.fit(X_train, y_multilabel)
knn_clf.predict([some_digit])

y_train_knn_pred = cross_val_predict(knn_clf, X_train, y_multilabel, cv=3, n_jobs=-1)
f1_score(y_multilabel, y_train_knn_pred, average="macro")
python scikit-learn jupyter knn
1个回答
1
投票

对于KNN最受欢迎的超参数是n_neighbors,那就是,你认为多少的近邻的标签分配给一个新的起点。默认情况下,它被设置为5,但它未必是最好的选择。因此它往往是更好地找到最好的选择是什么,您的具体问题。

这是你将如何找到你的榜样最优超参数:

from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV

param_grid = {"n_neighbors" : [3,5,7]}     

KNN=KNeighborsClassifier()

grid=GridSearchCV(KNN, param_grid = param_grid , cv = 5, scoring = 'accuracy', return_train_score = False)
grid.fit(X_train,y_train)

这样做是比较您设置n_neighbors的不同值的KNN模型的性能。然后,当你这样做:

print(grid.best_score_)
print(grid.best_params_)

它会告诉你什么是最好的性能得分,并为其中的参数选择,它达到了。

这一切都无关的事实,你正在使用的数据MNIST。您可以使用此方法用于任何其他类别的任务,只要你想KNN可能是您的任务明智的选择(这可能是值得商榷的图像分类)。会从一个任务切换到另一个的唯一的事情是超参数的最佳值。

PS:我会建议不要使用y_multilabel术语,因为这可能是指一个特定的分类任务,每个数据点可以有多个标签,这是不是在MNIST(每个图像代表在时刻只有一个数字)的情况。

© www.soinside.com 2019 - 2024. All rights reserved.