所以我正在使用张量流进行迁移学习,并且我希望能够运行
history = model.fit(...) # Run initial training with base_model.trainable = False
第一次训练完成后,我可以通过解冻某些层来对其进行微调,因此如果第一个会话运行 20 个周期,我的下一个代码块将是:
# Train the model again for a few epochs
fine_tune_epochs = 10
total_epochs = len(history.epoch) + fine_tune_epochs
history_tuned = model.fit(train_set, validation_data = dev_set, initial_epoch=history.epoch[-1], epochs=total_epochs,verbose=2, callbacks=callbacks)
基本上,它将从历史中获取纪元,并从上一个纪元开始继续训练,并将这些结果保存在history_tuned
但我可能想再次训练它,并使用更多未冻结的层,因此我会再次运行 history_tuned02 并继续使用每个历史记录的纪元,这样我的图表看起来如下图所示。
从图中可以看出,它们都是连接在一起的,但实际上是两个不同的训练课程。第一个模型被冻结,然后是微调会话。您甚至可以从性能的提升开始判断微调的起点。
问题是,为了做到这一点,我必须让 Jupyter 打开几天,因为如果我关闭它,所有变量都会消失,我需要再次训练所有内容,这将花费大量时间。
我尝试使用 dill 包,但它不适用于历史记录。我也尝试过使用 %store 历史记录,但由于某种原因它也不起作用,正如您从下图在我测试的虚拟笔记本上看到的那样。
那么有没有办法,将历史变量保存在磁盘上,关闭jupyter,再次打开它,恢复历史记录并继续我的工作?即使我让 jupyter 和 VS Code 打开直到模型完成,崩溃仍然会发生。
我还在张量流上使用检查点回调,因此我保存了权重,恢复这些不是问题,但如果可能的话,我也确实需要历史记录。
更新:
当我按照建议使用 CSVLogger 回调并使用
读取它时history = pd.read_csv('demo/logs/hist.log')
然后
history.head()
输出是
您可以通过两种方式保存您的历史记录:
手动方法:
只需中断您的训练并将您的历史文件保存为字典即可:
with open('/history_dict', 'wb') as file:
pickle.dump(history.history, file)
然后您可以使用以下命令重新加载它:
history = pickle.load(open('/history_dict', "rb"))
自动化方法:
您可以创建一个简单的回调,每个纪元都存储您的历史记录。所以,即使你的训练崩溃了,它也会自动保存并可以恢复。
回调可以是这样的:
from tensorflow import keras
import tensorflow.keras.backend as K
import os
import csv
my_dir = './model_dir' # where to save history
class SaveHistory(keras.callbacks.Callback):
def on_epoch_end(self, batch, logs=None):
if ('lr' not in logs.keys()):
logs.setdefault('lr', 0)
logs['lr'] = K.get_value(self.model.optimizer.lr)
if not ('history.csv' in os.listdir(my_dir)):
with open(my_dir + 'history.csv', 'a') as f:
content = csv.DictWriter(f, logs.keys())
content.writeheader()
with open(my_dir + 'history.csv','a') as f:
content = csv.DictWriter(f, logs.keys())
content.writerow(logs)
model.fit(..., callbacks=[SaveHistory()])
要重新加载保存为
.csv
的历史记录,只需执行以下操作:
import pandas as pd
history = pd.read_csv('history.csv')
此外,我认为除了自定义回调之外,您还可以使用 CSVLogger 沿着模型检查点保存历史记录,如下所示:
history = model.fit(..., callbacks=[keras.callbacks.CSVLogger('history.csv')])
这可以用 pandas 加载回来,如上所示。