Keras:如何得到预测标签超过两班

问题描述 投票:3回答:1

我在Keras实现的图像分类,使用TensorFlow后端。与两个输出类的数据集,我查了预测标签如下:

if  result[0][0] == 1:
    prediction ='adathodai'
else:
    prediction ='thamarathtai'

完整的代码链接:here

有三个班,我得到[[0. 0. 1.]]输出作为结果。我如何检查两个以上类别的预测标签中,如果别的格式?

python numpy tensorflow keras
1个回答
4
投票

对于多类分类问题,其中k标签,您可以通过使用model.predict_classes()检索预测类的指标。玩具例子:

import keras
import numpy as np

# Simpel model, 3 output nodes
model = keras.Sequential()
model.add(keras.layers.Dense(3, input_shape=(10,), activation='softmax'))

# 10 random input data points
x = np.random.rand(10, 10)
model.predict_classes(x)
> array([1, 1, 2, 1, 2, 1, 2, 1, 1, 1])

如果您在列表中有标签,你可以使用预测类得到的预测标签:

labels = ['label1', 'label2', 'label3']
[labels[i] for i in model.predict_classes(x)]
> ['label2', 'label2', 'label3', 'label2', 'label3', 'label2', 'label3', 'label2', 'label2', 'label2']

引擎盖下,model.predict_classes返回最大预测类概率在预测各行的索引:

model.predict_classes(x)
> array([1, 1, 2, 1, 2, 1, 2, 1, 1, 1])
model.predict(x).argmax(axis=-1) # same thing
> array([1, 1, 2, 1, 2, 1, 2, 1, 1, 1])
© www.soinside.com 2019 - 2024. All rights reserved.