哪种损失函数和指标可用于负正比非常高的多标签分类?

问题描述 投票:0回答:5

我正在训练一个多标签分类模型来检测衣服的属性。我在 Keras 中使用迁移学习,重新训练 vgg-19 模型的最后几层。

属性总数为1000个,其中约99%为0。准确性、精确度、召回率等指标都失败了,因为模型可以预测全零,但仍然获得非常高的分数。二元交叉熵、汉明损失等,在损失函数的情况下还没有起作用。

我正在使用深度时尚数据集。

那么,我可以使用哪些指标和损失函数来正确测量我的模型?

python machine-learning keras multilabel-classification vgg-net
5个回答
41
投票

分类交叉熵损失或 Softmax 损失是 Softmax 激活 加上交叉熵损失。如果我们使用此损失,我们将训练 CNN 输出每个图像的 C 类的概率。它用于多类分类

您想要的是多标签分类,因此您将使用二元交叉熵损失或Sigmoid交叉熵损失。它是 Sigmoid 激活 加上交叉熵损失。与 Softmax 损失不同,它对于每个向量分量(类)都是独立的,这意味着为每个 CNN 输出向量分量计算的损失不受其他分量值的影响。这就是为什么它被用于多标签分类,其中属于某个类的元素的洞察不应该影响另一个类的决策。

现在为了处理类别不平衡,您可以使用加权 Sigmoid 交叉熵损失。所以你会根据正例的数量/比例来惩罚错误的预测。


9
投票

实际上你应该使用

tf.nn.weighted_cross_entropy_with_logits
。 它不仅适用于多标签分类,而且还具有
pos_weight
可以像您所期望的那样关注正类。


1
投票

多类和二类分类决定了输出单元的数量,即最后一层的神经元数量。 多标签和单标签决定了最终层的激活函数和损失函数应使用哪种选择。 对于单标签,标准选择是具有分类交叉熵的 Softmax;对于多标签,请切换到具有二元交叉熵的 Sigmoid 激活。

分类交叉熵:

二元交叉熵:

C
是类的数量,
m
是当前小批量中的示例数量。
L
是损失函数,
J
是成本函数。您还可以看到这里。 在损失函数中,您正在迭代不同的类。在成本函数中,您正在迭代当前小批量中的示例。


1
投票

可以参考这个github。它们具有二元、多类、多标签以及强制模型学习接近 0 和 1 或简单学习概率的选项。

https://github.com/monkeyDemon/AI-Toolbox/blob/master/computer_vision/image_classification_keras/loss_function/focal_loss.py

史蒂夫


-1
投票

我也遇到过和你类似的情况

您可以在输出层中使用 softmax 激活函数和 categorical_crossentropy 来检查其他指标,例如精度、召回率和 f1 分数,您可以使用 sklearn 库,如下所示:

from sklearn.metrics import classification_report

y_pred = model.predict(x_test, batch_size=64, verbose=1)
y_pred_bool = np.argmax(y_pred, axis=1)

print(classification_report(y_test, y_pred_bool))

至于训练阶段,据了解,准确度指标如下

model.compile(loss='categorical_crossentropy'
              , metrics=['acc'], optimizer='adam')

如果对您有帮助,您可以使用 matplotlib 绘制训练阶段的损失和准确性的训练历史记录,如下所示:

hist = model.fit(x_train, y_train, batch_size=24, epochs=1000, verbose=2,
                 callbacks=[checkpoint],
                 validation_data=(x_valid, y_valid)

                 )
# Plot training & validation accuracy values
plt.plot(hist.history['acc'])
plt.plot(hist.history['val_acc'])
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()

# Plot training & validation loss values
plt.plot(hist.history['loss'])
plt.plot(hist.history['val_loss'])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()
© www.soinside.com 2019 - 2024. All rights reserved.