使用MNIST样本将每个数字的10x10网格可视化。

问题描述 投票:0回答:1

我试图从MNIST数据集中绘制10x10网格样本。每个数字有10个。下面是代码。

导入库:

import sklearn
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
from sklearn.pipeline import Pipeline
from sklearn.datasets import fetch_openml

加载数字数据:

X, Y = fetch_openml(name='mnist_784', return_X_y=True, cache=False)

绘制网格

def P1(num_examples=10):
plt.rc('image', cmap='Greys')
plt.figure(figsize=(num_examples,len(np.unique(Y))), dpi=X.shape[1])
# For each digit (from 0 to 9)
for i in np.nditer(np.unique(Y)):
    # Create a ndarray with the features of "num_examples" examples of digit "i"
    features = X[Y == i][:num_examples]
    # For each of the "num_examples" examples
    for j in range(num_examples):
        # Create subplot (from 1 to "num_digits"*"num_examples" of each digit)
        plt.subplot(len(np.unique(Y)), num_examples, i * num_examples + j + 1)
        plt.subplots_adjust(wspace=0, hspace=0)
        # Hide tickmarks and scale
        ax = plt.gca()
        # ax.set_axis_off() # Also hide axes (frame) 
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        # Plot the corresponding digit (reshaped to square matrix/image)
        dim = int(np.sqrt(X.shape[1]))
        digit = features[j].reshape((dim,dim))            
        plt.imshow(digit)
P1(10)

然而,我在这里得到了一个错误信息,说: "迭代器操作数或请求的dtype持有引用,但没有启用REFS_OK标志"

谁能帮我解决这个问题?

python-3.x scikit-learn mnist
1个回答
0
投票

这个错误来自于 nd.iter 最有可能,你不需要--也建议使用 subplotsax 而不是MATLAB式 plt 呼叫。

digits = np.unique(Y)
M = 10
dim = int(np.sqrt(X.shape[1]))

fig, axs = plt.subplots(len(digits), M, figsize=(20,20))

for i,d in enumerate(digits):
    for j in range(M):
        axs[i,j].imshow(X[Y==d][j].reshape((dim,dim)))
        axs[i,j].axis('off')
© www.soinside.com 2019 - 2024. All rights reserved.