MNIST 中 0 的误报太多:使用逻辑回归进行多类分类

问题描述 投票:0回答:1

我编写了一个基本上使用 MNIST 数据集的代码,并使用逻辑回归对 0 到 9 的数字进行分类。我使用了多个逻辑回归单元并单独训练它们,保存它们的权重,然后找到每个逻辑单元的预测标签。然后我简单地取给出最大值的逻辑单元,并将与该逻辑单元对应的数字视为预测数字

但是当我看到混淆矩阵时,它在第一列中显示了许多误报,即大多数逻辑单元进行了许多错误分类,认为它们为零。这是我获得的混淆矩阵(附为照片):

confusion matrix

我不确定我在这里做错了什么。我将在这里留下笔记本的链接:

https://colab.research.google.com/drive/1rDCW4Tpf4AMLzjxYjTrqv4c1_e0oOY75?usp=sharing

我尝试使用 MNIST 数据集实现逻辑回归并执行多类分类。

我在 Kaggle 中浏览了一下,发现了类似的东西,但他们的混淆矩阵结果是这样的:

链接到在 MNIST 分类器上使用逻辑回归实现的多类分类:

https://www.kaggle.com/code/hamzaboulahia/logistic-regression-mnist-classification

python numpy machine-learning logistic-regression multiclass-classification
1个回答
0
投票

更新:我发现了问题。我只是对大小为 (m x 10) 的预测标签矩阵进行四舍五入,然后使用它来查找要分配的数字。然而,我应该使用原始假设值并制作预测矩阵。这基本上告诉我们给定数据是特定数字的概率。然后我所要做的就是找到最大概率,然后正确预测数字。我的训练阶段非常好,只是在做出预测时遇到了问题。

© www.soinside.com 2019 - 2024. All rights reserved.