# Windowing configuration forwarded to univariate_data (a helper defined
# elsewhere in this file). Presumably the history length is the number of past
# observations per sample and the target offset selects which future value to
# predict -- confirm against univariate_data.
univariate_past_history = 100
univariate_future_target = 0

# Both splits are built with exactly the same window arguments, so share them.
_window_args = (univariate_past_history, univariate_future_target)

# Training split: start index 0, end index TRAIN_SPLIT.
x_train_uni, y_train_uni = univariate_data(
    uni_data, 0, TRAIN_SPLIT, *_window_args
)

# Validation split: start index TRAIN_SPLIT, end index None
# (presumably "until the end of the data" -- confirm against univariate_data).
x_val_uni, y_val_uni = univariate_data(
    uni_data, TRAIN_SPLIT, None, *_window_args
)
要预测的值是 y[0].numpy()
而预测值为 simple_lstm_model.predict(x)[0]
我如何计算它们之间的均方误差?
# Input-pipeline settings: batch size and shuffle-buffer size.
BATCH_SIZE = 256
BUFFER_SIZE = 10000

# Training pipeline: cached, shuffled, batched and repeated indefinitely, so
# fit() must be given steps_per_epoch.
train_univariate = tf.data.Dataset.from_tensor_slices((x_train_uni, y_train_uni))
train_univariate = train_univariate.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()

# Validation pipeline: batched and repeated (no shuffling), so fit() must be
# given validation_steps.
val_univariate = tf.data.Dataset.from_tensor_slices((x_val_uni, y_val_uni))
val_univariate = val_univariate.batch(BATCH_SIZE).repeat()

# One LSTM layer with 8 units followed by a single-unit Dense output layer.
# The input shape is taken from the last two dimensions of the training data.
simple_lstm_model = tf.keras.models.Sequential([
    tf.keras.layers.LSTM(8, input_shape=x_train_uni.shape[-2:]),
    tf.keras.layers.Dense(1),
])
simple_lstm_model.compile(optimizer='adam', loss='mae')

# Sanity check: print the prediction shape for one validation batch.
for x, y in val_univariate.take(1):
    print(simple_lstm_model.predict(x).shape)

EVALUATION_INTERVAL = 200
EPOCHS = 10

simple_lstm_model.fit(train_univariate, epochs=EPOCHS,
                      steps_per_epoch=EVALUATION_INTERVAL,
                      validation_data=val_univariate, validation_steps=50)

for x, y in val_univariate.take(3):
    # Predict once per batch and reuse the result instead of calling
    # predict() separately for the plot, the print and the metric.
    batch_predictions = simple_lstm_model.predict(x)

    plot = show_plot([x[0].numpy(), y[0].numpy(), batch_predictions[0]],
                     0, 'Simple LSTM model')
    plot.show()
    print(batch_predictions[0])

    # mean_squared_error expects array-likes with a sample axis. y[0].numpy()
    # is a bare scalar and triggers
    # "TypeError: Singleton array ... cannot be considered a valid collection",
    # so slice with y[:1] to keep the first label as a one-element array.
    expected = y[:1].numpy()
    predicted = batch_predictions[0]
    print(mean_squared_error(expected, predicted))
如果我像上面那样做,会得到这个错误:`TypeError: Singleton array 0.05017540446704798 cannot be considered a valid collection.`
先谢谢你
两个标量之间的平方误差为 `(y_true - y_pred)**2`。这是一个标量,我们可以从这个标量中了解误差有多大或模型有多好。
两个向量的平方误差是一个向量,但由于我们想要一个标量来了解两个向量之间的偏差,我们取平均值,因此取MSE。
你可以使用下面的代码来计算MSE。
# Squared error / MSE between the model's predictions for the batch `x` and
# the labels `y` that belong to that same batch (the (x, y) pair yielded by the
# dataset loop). Labels must come from the same batch as the inputs; comparing
# against the full training labels would pair unrelated samples.
#
# Keras `predict` already returns a NumPy array, so no `.numpy()` call is
# needed. Both sides are flattened so a (batch, 1) prediction lines up with
# (batch,) labels instead of broadcasting into a (batch, batch) matrix.
# NOTE(review): assumes the model emits one value per sample -- confirm.
y_pred = np.asarray(simple_lstm_model.predict(x)).ravel()
y_true = np.asarray(y).ravel()

# Error for single sample
print(np.power(y_true[0] - y_pred[0], 2))

# Error for multiple samples
print(np.power(y_true - y_pred, 2).mean())