我想和Keras一起训练我的模型。我正在使用一个巨大的数据集,其中一个训练纪元具有超过30000个步骤。我的问题是我不想等待一个纪元,然后再检查验证数据集上的模型改进。有什么方法可以使Keras每隔1000步训练数据评估一次验证数据?我认为一种选择是使用回调,但是Keras是否有任何内置解决方案?
if train:
log('Start training')
history = model.fit(train_dataset,
steps_per_epoch=train_steps,
epochs=50,
validation_data=val_dataset,
validation_steps=val_steps,
callbacks=[
keras.callbacks.EarlyStopping(
monitor='loss',
patience=10,
restore_best_weights=True,
),
keras.callbacks.ModelCheckpoint(
filepath=f'model.h5',
monitor='val_loss',
save_best_only=True,
save_weights_only=True,
),
keras.callbacks.ReduceLROnPlateau(
monitor = "val_loss",
factor = 0.5,
patience = 3,
min_lr=0.001,
),
],
)
使用内置回调,您无法执行此操作。您需要实现自定义回调。
class MyCustomCallback(tf.keras.callbacks.Callback):
def on_train_batch_begin(self, batch, logs=None):
print('Training: batch {} begins at {}'.format(batch, datetime.datetime.now().time()))
def on_train_batch_end(self, batch, logs=None):
print('Training: batch {} ends at {}'.format(batch, datetime.datetime.now().time()))
def on_test_batch_begin(self, batch, logs=None):
print('Evaluating: batch {} begins at {}'.format(batch, datetime.datetime.now().time()))
def on_test_batch_end(self, batch, logs=None):
print('Evaluating: batch {} ends at {}'.format(batch, datetime.datetime.now().time()))
这来自TensorFlow文档。
您可以覆盖on_train_batch_end()
函数,并且由于batch参数是整数,因此您可以根据需要验证batch % 100 == 0
,然后是self.model.predict(val_data)
等。
[请在此处检查我的答案:How to get other metrics in Tensorflow 2.0 (not only accuracy)?,以全面了解如何覆盖自定义回调函数。请注意,在您的情况下,重要的是on_train_batch_end()
而不是on_epoch_end()
。