TensorFlow:负载检查点,但仅部分检查(卷积层)

问题描述 投票:3回答:1

是否可以仅从一个检查点文件中加载特定层(卷积层)?

我已经对一些CNN进行了完全监督,并保存了我的进度(我正在进行对象本地化)。为了进行自动标记,我考虑从当前模型中构建一个弱监督的CNN ...但是由于弱监督的版本具有不同的完全连接层,因此我只想选择TensorFlow检查点文件的卷积过滤器。

当然,我可以手动保存相应层的权重,但是由于它们已经包含在TensorFlow的检查点文件中,因此我希望将其提取到那里,以便有一个单独的存储文件。

tensorflow store restore convolution conv-neural-network
1个回答
0
投票

TensorFlow 2.1具有许多用于加载检查点的公共设施(model.saveCheckpointsaved_model等),但是据我所知,它们都没有过滤API。因此,让我为使用TF2.1内部开发测试中的工具的困难案例提供一个摘要。

checkpoint_filename = '/path/to/our/weird/checkpoint.ckpt'
model = tf.keras.Model( ... ) # TF2.0 Model to initialize with the above checkpoint
variables_to_load = [ ... ] # List of model weight names to update.

from tensorflow.python.training.checkpoint_utils import load_checkpoint, list_variables

reader = load_checkpoint(checkpoint_filename)
for w in model.weights:
    name=w.name.split(':')[0] # See (b/29227106)
    if name in variables_to_load:
        print(f"Updating {name}")
        w.assign(reader.get_tensor(
            # (Optional) Handle variable renaming 
            {'/var_name1/in/model':'/var_name1/in/checkpoint',
             '/var_name2/in/model':'/var_name2/in/checkpoint',
             # ...  and so on
             }.get(name,name)))

注意:model.weightslist_variables可能有助于检查模型和检查点中的变量

另请注意,此方法将不会恢复模型的优化器状态。

© www.soinside.com 2019 - 2024. All rights reserved.