Hi, I wrote this code and it works fine, but I don't know how to compare ypred with ytest:
import numpy as np
import pandas as pd
from keras.models import Sequential
from keras.layers import Dense
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn import datasets
from keras.utils import to_categorical
data=datasets.load_iris()
x=data.data
y=to_categorical(data.target)
xtrain, xtest, ytrain, ytest=train_test_split(x, y,test_size=1/3)
sc=StandardScaler()
xtrain=sc.fit_transform(xtrain)
xtest=sc.transform(xtest)
ann_model=Sequential()
ann_model.add(Dense(units=4,activation='relu', kernel_initializer='uniform', input_dim=4))
ann_model.add(Dense(units=4, activation='relu', kernel_initializer='uniform'))
ann_model.add(Dense(units=3, activation='softmax', kernel_initializer='uniform'))
ann_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
ann_model.fit(xtrain, ytrain,batch_size=8,epochs=800)
ypred=ann_model.predict(xtest)
After that, I get ypred as normalized probabilities like this:
[9.9993205e-01, 6.7994297e-05, 1.4203579e-19],
[5.3556296e-12, 4.2303108e-02, 9.5769691e-01],
[3.1650116e-04, 9.9964631e-01, 3.7194797e-05],
[1.4751430e-05, 9.9975187e-01, 2.3338773e-04],
[9.9994361e-01, 5.6439614e-05, 6.4687055e-20],
[2.6651847e-04, 9.9968839e-01, 4.5110301e-05],
[1.6542191e-06, 9.9968910e-01, 3.0929857e-04],
[9.9991632e-01, 8.3733095e-05, 3.4217699e-19],
[5.8562500e-07, 9.9891603e-01, 1.0833564e-03],
[2.7507697e-06, 9.9960250e-01, 3.9476002e-04],
[9.9997449e-01, 2.5457492e-05, 2.2423828e-21],
[7.1067189e-14, 5.0079697e-03, 9.9499208e-01],
But I want my ypred to be ones and zeros, like ytest:
[0., 1., 0.],
[0., 0., 1.],
[1., 0., 0.],
[0., 0., 1.],
[0., 1., 0.],
[1., 0., 0.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
How can I solve this problem? Thanks for your help.
You can combine np.argmax and keras.utils.to_categorical: take the index of the largest probability in each row, then convert those class indices back to one-hot vectors.
import numpy as np
from tensorflow.keras.utils import to_categorical
arr = np.array([[9.9993205e-01, 6.7994297e-05, 1.4203579e-19],
[5.3556296e-12, 4.2303108e-02, 9.5769691e-01],
[3.1650116e-04, 9.9964631e-01, 3.7194797e-05],
[1.4751430e-05, 9.9975187e-01, 2.3338773e-04],
[9.9994361e-01, 5.6439614e-05, 6.4687055e-20],
[2.6651847e-04, 9.9968839e-01, 4.5110301e-05],
[1.6542191e-06, 9.9968910e-01, 3.0929857e-04],
[9.9991632e-01, 8.3733095e-05, 3.4217699e-19],
[5.8562500e-07, 9.9891603e-01, 1.0833564e-03],
[2.7507697e-06, 9.9960250e-01, 3.9476002e-04],
[9.9997449e-01, 2.5457492e-05, 2.2423828e-21],
[7.1067189e-14, 5.0079697e-03, 9.9499208e-01]])
# Pick the most probable class per row, then one-hot encode those indices
ypred = to_categorical(np.argmax(arr, axis=1), num_classes=3)
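Alternatively, the same conversion works with plain NumPy (no Keras needed) by indexing an identity matrix with each row's argmax. A minimal sketch using a few of the predicted rows above:

```python
import numpy as np

# A few of the predicted probability rows from the question
arr = np.array([[9.9993205e-01, 6.7994297e-05, 1.4203579e-19],
                [5.3556296e-12, 4.2303108e-02, 9.5769691e-01],
                [3.1650116e-04, 9.9964631e-01, 3.7194797e-05]])

# np.eye(3) is the 3x3 identity matrix; indexing it with the argmax
# of each row turns class indices into one-hot rows
onehot = np.eye(3)[np.argmax(arr, axis=1)]
print(onehot)
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]]
```

Once ypred is one-hot like ytest, you can compare them directly, e.g. row-wise accuracy is np.mean(np.all(ypred == ytest, axis=1)); equivalently, just compare np.argmax(ypred, axis=1) with np.argmax(ytest, axis=1).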