密集softmax层中的输出数量

问题描述 投票:1回答:1

我一直在完成Coursera课程的额外练习,遇到了我不明白的问题。Link to Collab

就我从事ML神经网络问题而言,我一直被教导多类分类问题的输出层将是Dense,其节点数等于类数。例如。狗,猫,马-3个类别= 3个节点。

但是,在笔记本中,标签中有5个类,使用len(label_tokenizer.word_index)检查,但是使用5个节点,我的结果糟透了,而使用6个节点,模型可以正常工作。

谁能解释为什么会这样吗?我找不到任何在线示例对此进行解释。干杯!

machine-learning text-classification multiclass-classification
1个回答
0
投票

我知道了。具有分类交叉熵损失的密集层的输出期望标签/目标从零开始。例如:

cat - 0
dog - 1
horse - 2

在这种情况下,密集节点的数量为3。但是,在上面的示例中,标签是使用keras标记生成器生成的,该标记器从1开始标记化(因为填充通常为0)。

label_tokenizer = Tokenizer()
label_tokenizer.fit_on_texts(labels)
# {'business': 2, 'entertainment': 5, 'politics': 3, 'sport': 1, 'tech': 4}

这导致了一个奇怪的情况,如果我们有5个密集节点,则输出类别为0-4,这与预测为1-5的标签不匹配。

我通过重新运行所有标签都减少1的代码并通过5个密集节点成功训练模型来凭经验证明,因为标签现在为0-4。

我怀疑使用标签1-5和6个密集节点可以工作,因为该模型只是了解到未使用标签0并专注于1-5。

[如果有人了解类别交叉熵的内部工作原理,请随时添加!

© www.soinside.com 2019 - 2024. All rights reserved.