Tensorflow 对象检测:继续训练

问题描述 投票:0回答:3

假设我训练了一个像 ResNet 这样的预训练网络,并将其设置为在

pipeline.config file
中检测
fine_tune_checkpoint_type
属性。据我了解,这意味着我们采用模型的预训练权重,除了分类和框预测头。此外,这意味着我们可以创建自己的标签类型,然后将其作为我们要创建/训练的模型的分类和框预测头。

现在,假设我对这个网络进行了 25000 步的训练,并且想在以后继续训练,而模型不会忘记任何东西。我应该将

fine_tune_checkpoint_type
中的
pipeline.config
更改为
full
以继续训练(当然还要加载正确的检查点文件)还是我仍然应该将其设置为
detection

编辑:

这是基于这里找到的信息https://github.com/tensorflow/models/blob/master/research/object_detection/protos/train.proto:

  //   1. "classification": Restores only the classification backbone part of
  //        the feature extractor. This option is typically used when you want
  //        to train a detection model starting from a pre-trained image
  //        classification model, e.g. a ResNet model pre-trained on ImageNet.
  //   2. "detection": Restores the entire feature extractor. The only parts
  //        of the full detection model that are not restored are the box and
  //        class prediction heads. This option is typically used when you want
  //        to use a pre-trained detection model and train on a new dataset or
  //        task which requires different box and class prediction heads.
  //   3. "full": Restores the entire detection model, including the
  //        feature extractor, its classification backbone, and the prediction
  //        heads. This option should only be used when the pre-training and
  //        fine-tuning tasks are the same. Otherwise, the model's parameters
  //        may have incompatible shapes, which will cause errors when
  //        attempting to restore the checkpoint.

因此,

classification
仅提供特征提取器的分类主干部分。这意味着该模型将在网络的许多部分从头开始。

detection
恢复了整个特征提取器,但“最终结果”将被遗忘,这意味着我们可以添加自己的类并从头开始学习这些分类。

full
恢复一切,甚至是类和框预测权重。但是,只要我们不添加或删除任何类/标签,就可以了。

这是正确的吗?

tensorflow tensorflow2.0 object-detection object-detection-api
3个回答
3
投票

是的,你没看错。

fine_tune_checkpoint_type: full
中设置
piepline.config
以保留该模型学到的所有内容,直到最后一个检查点。


3
投票

是的,你可以通过设置fine_tune_checkpoint_type来配置要恢复的变量,选项是检测和分类。通过将它设置为 detection 基本上你可以从检查点恢复几乎所有变量,通过将它设置为分类,只有来自 feature_extractor 范围的变量被恢复,(骨干网络中的所有层,如 VGG,Resnet,MobileNet,它们被称为特征提取器)。

点击这里了解更多信息。


-1
投票

迪普和劳尔,你们的答案是正确的。整个星期都在寻找这个解决方案。

© www.soinside.com 2019 - 2024. All rights reserved.