Tensorflow object_detection保存和加载微调模型的正确方法

问题描述 投票:0回答:2

我正在使用 colabs 教程中的 this 示例来微调模型,训练后我想使用以下方法保存模型并加载到本地计算机上:

ckpt_manager = tf.train.CheckpointManager(ckpt, directory="test_data/checkpoint/", max_to_keep=5)
...
...
print('Done fine-tuning!')

ckpt_manager.save()
print('Checkpoint saved!')

但是在我的本地计算机上使用检查点文件恢复后没有检测到任何对象(分数太低)

我也尝试过

tf.saved_model.save(detection_model, '/content/new_model/')

并加载这个:

detection_model = tf.saved_model.load('/saved_model_20201226/')

input_tensor = tf.convert_to_tensor(image, dtype=tf.float32)
detections = detection_model(input_tensor)

给我这个错误: 类型错误:“_UserObject”对象不可调用

保存和加载微调模型的正确方法是什么?

编辑1: 正在等待保存新的管道配置,之后终于成功了! 这是我的回答:

# Save new pipeline config
new_pipeline_proto = config_util.create_pipeline_proto_from_configs(configs)
config_util.save_pipeline_config(new_pipeline_proto, '/content/new_config')
exported_ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
ckpt_manager = tf.train.CheckpointManager(
exported_ckpt, directory="test_data/checkpoint/", max_to_keep=5)
...
...
print('Done fine-tuning!')

ckpt_manager.save()
print('Checkpoint saved!')
tensorflow machine-learning deep-learning object-detection
2个回答
1
投票

等待保存新的管道配置,之后终于成功了!这是我的回答:

# Save new pipeline config
new_pipeline_proto = config_util.create_pipeline_proto_from_configs(configs)
config_util.save_pipeline_config(new_pipeline_proto, '/content/new_config')

exported_ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
ckpt_manager = tf.train.CheckpointManager(
exported_ckpt, directory="test_data/checkpoint/", max_to_keep=5)
...
...
print('Done fine-tuning!')

ckpt_manager.save()
print('Checkpoint saved!')

0
投票

但是当您尝试从经过训练的检查点构建模型时,模型性能与模型在本地笔记本上的原始性能不同。为什么这是一个问题?

© www.soinside.com 2019 - 2024. All rights reserved.