是否可以仅从一个检查点文件中加载特定层(卷积层)?
我已经对一些CNN进行了完全监督,并保存了我的进度(我正在进行对象本地化)。为了进行自动标记,我考虑从当前模型中构建一个弱监督的CNN ...但是由于弱监督的版本具有不同的完全连接层,因此我只想选择TensorFlow检查点文件的卷积过滤器。
当然,我可以手动保存相应层的权重,但是由于它们已经包含在TensorFlow的检查点文件中,因此我希望将其提取到那里,以便有一个单独的存储文件。
TensorFlow 2.1具有许多用于加载检查点的公共设施(model.save
,Checkpoint
,saved_model
等),但是据我所知,它们都没有过滤API。因此,让我为使用TF2.1内部开发测试中的工具的困难案例提供一个摘要。
checkpoint_filename = '/path/to/our/weird/checkpoint.ckpt'
model = tf.keras.Model( ... ) # TF2.0 Model to initialize with the above checkpoint
variables_to_load = [ ... ] # List of model weight names to update.
from tensorflow.python.training.checkpoint_utils import load_checkpoint, list_variables
reader = load_checkpoint(checkpoint_filename)
for w in model.weights:
name=w.name.split(':')[0] # See (b/29227106)
if name in variables_to_load:
print(f"Updating {name}")
w.assign(reader.get_tensor(
# (Optional) Handle variable renaming
{'/var_name1/in/model':'/var_name1/in/checkpoint',
'/var_name2/in/model':'/var_name2/in/checkpoint',
# ... and so on
}.get(name,name)))
注意:model.weights
和list_variables
可能有助于检查模型和检查点中的变量
另请注意,此方法将不会恢复模型的优化器状态。