我正在学习使用 MNIST 数据集进行分类。我遇到了一个我无法弄清楚的错误,我已经做了很多谷歌搜索,但我无能为力,也许你是专家并且可以帮助我。这是代码--
>>> from sklearn.datasets import fetch_openml
>>> mnist = fetch_openml('mnist_784', version=1)
>>> mnist.keys()
输出: dict_keys(['数据', '目标', '框架', '类别', 'feature_names', 'target_names', 'DESCR', '详细信息', 'url'])
>>> X, y = mnist["data"], mnist["target"]
>>> X.shape
输出:(70000, 784)
>>> y.shape
产量:(70000)
>>> X[0]
output:KeyError Traceback (most recent call last)
c:\users\khush\appdata\local\programs\python\python39\lib\site-packages\pandas\core\indexes\base.py in get_loc(self, key, method, tolerance)
2897 try:
-> 2898 return self._engine.get_loc(casted_key)
2899 except KeyError as err:
pandas\_libs\index.pyx in pandas._libs.index.IndexEngine.get_loc()
pandas\_libs\index.pyx in pandas._libs.index.IndexEngine.get_loc()
pandas\_libs\hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()
pandas\_libs\hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()
KeyError: 0
The above exception was the direct cause of the following exception:
KeyError Traceback (most recent call last)
<ipython-input-10-19c40ecbd036> in <module>
----> 1 X[0]
c:\users\khush\appdata\local\programs\python\python39\lib\site-packages\pandas\core\frame.py in __getitem__(self, key)
2904 if self.columns.nlevels > 1:
2905 return self._getitem_multilevel(key)
-> 2906 indexer = self.columns.get_loc(key)
2907 if is_integer(indexer):
2908 indexer = [indexer]
c:\users\khush\appdata\local\programs\python\python39\lib\site-packages\pandas\core\indexes\base.py in get_loc(self, key, method, tolerance)
2898 return self._engine.get_loc(casted_key)
2899 except KeyError as err:
-> 2900 raise KeyError(key) from err
2901
2902 if tolerance is not None:
KeyError: 0
fetch_openml
的API在版本之间发生了变化。在早期版本中,它返回一个 numpy.ndarray
数组。自 0.24.0
(2020 年 12 月)以来,as_frame
的 fetch_openml
参数设置为 auto
(而不是之前作为默认选项的 False
),这为您提供了 MNIST 数据的 pandas.DataFrame
。您可以通过设置 numpy.ndarray
强制将数据读取为 as_frame = False
。请参阅 fetch_openml 参考 .
我也面临同样的问题。
我曾经使用下面的代码来解决该问题。
import matplotlib as mpl
import matplotlib.pyplot as plt
# instead of some_digit = X[0]
some_digit = X.to_numpy()[0]
some_digit_image = some_digit.reshape(28,28)
plt.imshow(some_digit_image,cmap="binary")
plt.axis("off")
plt.show()
如果您按照以下代码操作,则无需降级 scikit-learn 库:
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version= 1, as_frame= False)
mnist.keys()
您将数据集作为数据框加载,以便能够访问图像,有两种方法可以做到这一点,
将数据帧转换为数组
# Transform the dataframe into an array. Check the first value
some_digit = X.to_numpy()[0]
# Reshape it to (28,28). Note: 28 x 28 = 7064, if the reshaping doesn't meet
# this you are not able to show the image
some_digit_image = some_digit.reshape(28,28)
plt.imshow(some_digit_image,cmap="binary")
plt.axis("off")
plt.show()
变换行
# Transform the row of your choosing into an array
some_digit = X.iloc[0,:].values
# Reshape it to (28,28). Note: 28 x 28 = 7064, if the reshaping doesn't
# meet this you are not able to show the image
some_digit_image = some_digit.reshape(28,28)
plt.imshow(some_digit_image,cmap="binary")
plt.axis("off")
plt.show()
您可以在加载 mnist 数据集时添加
parser = 'auto'
作为额外参数。
我是这样导入的:
mnist = fetch_openml('mnist_784', version= 1, as_frame= False, parser='auto')