我在Windows 11上使用TensoFlow 2.4.0版本和Keras,我想添加一个
ModelCheckpoint
回调监控auc
:
import tensorflow as tf
try: # Atttempt to get rid of any model, history from previous runs
del model
del history
except:
print('No model to delete')
checkpoint_filepath = "checkpoint"
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_filepath,
save_weights_only=True,
monitor='val_auc',
mode='max',
save_best_only=True)
model.compile(optimizer=Adam(learning_rate=0.001),
loss= 'binary_crossentropy',
metrics=['accuracy', tf.keras.metrics.AUC()])
history = model.fit(..., callbacks=[model_checkpoint_callback])
一切都很好,因为
val_auc
存在于history.history.keys()
中。
但是,如果我下次再次运行代码(在 Jupyter 笔记本中),则密钥将变为:val_auc_1
。
当然
ModelCheckpoint
不起作用。
我必须重新启动内核才能摆脱密钥末尾这个烦人的
_1
。
keras
文档建议的一个解决方案是运行model.fit
仅运行1个纪元以获得history.history.keys()
,然后在ModelCheckpoint
中使用它。但这确实很笨拙。有没有办法在不运行 model.compile
的情况下获取 model.fit
之后的公制键?或者是否可以以某种方式避免这种情况_1
?
因此,第一次运行包含
tf.keras.metrics.AUC()
的代码块时,您会生成此类的一个实例,默认情况下,该实例的名称为“auc”。看起来这个实例被存储在内存中,所以当你再次运行 tf.keras.metrics.AUC()
时,它会创建这个类的第二个实例。这些类型的类的行为是为类的每个实例分配唯一的名称(为层分配名称的 keras 函数是here),因此“_1”会附加到第二个实例。如果您运行以下命令,您可以看到此行为:
a1 = tf.keras.metrics.AUC()
a2 = tf.keras.metrics.AUC()
print(f'a1 name: {a1.name}')
print(f'a2 name: {a2.name}')
正如 @elbe 的评论中指出的,您可以通过在创建实例时分配一个名称来解决这个问题,即
tf.keras.metrics.AUC(name='auc')
;或者正如您已经指出的,您可以重新启动内核。