我通过keras_metrics和sklearn.metrics得到非常不同的结果

问题描述 投票:0回答:1

我正在尝试对文本数据进行分类。我正在使用keras_metrics获得精度,召回率和f1分数。这是我的架构代码

model = Sequential()
model.add(Embedding(input_dim=500,output_dim=50,input_length=280))
model.add(Bidirectional(CuDNNLSTM(32, return_sequences = True)))
model.add(GlobalMaxPool1D())
model.add(Dense(20, activation="relu"))
model.add(Dropout(0.05))
model.add(Dense(1, activation="sigmoid"))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy',km.binary_precision()])

model.fit(sequences_matrix,y_train,batch_size=128,epochs=10,
          validation_split=0.2,verbose=2)

当我使用以下代码检查测试数据时

test_sequences = tokenize.texts_to_sequences(corpus_test)
test_sequences_matrix = sequence.pad_sequences(test_sequences,maxlen=max_len)
print(model.evaluate(test_sequences_matrix,y_test))

结果如下

[0.5238178644069406, 0.7686046519944835, 0.8109305759511182]

但是当我使用sklearn.metrics检查时,结果变得更糟:

y_pred = model.predict(test_sequences_matrix, batch_size=128, verbose=1)
y_pred_bool = np.argmax(y_pred, axis=1)
print(classification_report(y_test, y_pred_bool))

sklearn的结果如下

              precision    recall  f1-score   support

           0       0.28      1.00      0.44       240
           1       0.00      0.00      0.00       620

    accuracy                           0.28       860
   macro avg       0.14      0.50      0.22       860
weighted avg       0.08      0.28      0.12       860

混淆矩阵如下。

     0     1
0   143    97
1   98    522

混淆矩阵代码

y_pred = model.predict_classes(test_sequences_matrix, batch_size=128, verbose=1)
from sklearn.metrics import confusion_matrix
cm = confusion_matrix(y_test, y_pred)
python keras scikit-learn confusion-matrix precision-recall
1个回答
0
投票

似乎您尝试比较不同的方式。分类报告为您提供PrecisionRecallF1-score。Keras的evaluate()方法为您提供了[[标量测试损失,如docs中所述:]]

返回

标量测试损失(如果模型只有一个输出而没有指标)或标量列表(如果模型有多个输出和/或指标)。属性model.metrics_names将为您提供标量输出的显示标签。

因此,由于您比较不同的值,所以看起来值自然是不同的。要知道model.evaluate(test_sequences_matrix,y_test)给出了哪些值,您可以使用print(model.metrics_names),它应该给您类似['loss', 'dense_1_loss', 'dense_2_loss']的信息。有关此的更多信息,您也可以阅读此post

希望这会有所帮助,随时问。

© www.soinside.com 2019 - 2024. All rights reserved.