如何使用检查点继续训练 TensorFlow Lite Model Maker 的 object_detector.EfficientDetLite4Spec 模型

问题描述 投票:0回答:1

具体来说,我在 config.yaml 中为我的 EfficientDetLite4 模型设置了 “grad_checkpoint=true”,并且训练已经成功生成了一些检查点。但是,当我想基于这些检查点继续训练时,却不知道该如何使用它们。

每次我训练模型时,它都会从头开始,而不是从我的检查点开始。

下图是我的colab文件系统结构:

下图显示了我的检查点存储的位置:

以下代码显示了我如何配置模型以及如何使用模型进行训练。

import numpy as np
import os

from tflite_model_maker.config import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import object_detector

import tensorflow as tf
assert tf.__version__.startswith('2')

# Silence TensorFlow / absl log output below ERROR level.
tf.get_logger().setLevel('ERROR')
from absl import logging
logging.set_verbosity(logging.ERROR)

# Load the annotations CSV and split it into train / validation / test loaders.
# The call is wrapped in parentheses: the original broke this assignment across
# two lines without any continuation, which is a SyntaxError.
train_data, validation_data, test_data = (
    object_detector.DataLoader.from_csv('csv_path')
)

# NOTE(review): `spec_strategy` is not defined in this snippet; presumably it is
# set in an earlier notebook cell (None, 'gpus' or 'tpu') -- confirm before running.
# NOTE: `grad_checkpoint=true` enables gradient checkpointing (recomputing
# activations to save memory). It does not make training resume from the
# checkpoint files written under `model_dir`.
spec = object_detector.EfficientDetLite4Spec(
    uri='/content/model',
    model_dir='/content/drive/MyDrive/MathSymbolRecognition/CheckPoints/',
    hparams='grad_checkpoint=true,strategy=gpus',
    epochs=50, batch_size=3,
    steps_per_execution=1, moving_average_decay=0,
    var_freeze_expr='(efficientnet|fpn_cells|resample_p6)',
    tflite_max_detections=25, strategy=spec_strategy
)

# Build and train the detector. As written, this always starts from freshly
# initialized weights; existing checkpoints in `model_dir` are not restored.
model = object_detector.create(train_data, model_spec=spec, batch_size=3,
    train_whole_model=True, validation_data=validation_data)
python tensorflow machine-learning tensorflow-lite
1个回答
9
投票

源代码就是答案!

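我遇到了同样的问题,发现我们传递给 TFLite Model Maker 对象检测器 API 的 `model_dir` 仅用于保存模型的权重:这就是 API 永远不会从检查点恢复的原因。(另外请注意:`grad_checkpoint=true` 开启的是梯度检查点(通过重新计算激活值来节省显存),与保存或恢复训练检查点无关;检查点文件是训练时写入 `model_dir` 的。)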

查看此 API 的源代码,我注意到它内部使用标准的 `model.compile` 和 `model.fit` 函数,并通过 `model.fit` 的 `callbacks` 参数保存模型的权重。
这意味着,只要我们能获取内部的 Keras 模型,就可以使用 `model.load_weights` 来恢复我们的检查点!

如果您想更多地了解我在下面使用的一些函数的用途,这些是源代码的链接:

这是代码:

# Useful imports
import tensorflow as tf
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.config import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import object_detector
from tflite_model_maker.object_detector import DataLoader

# Import the same libs that TFLite Model Maker internally uses
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import train
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import train_lib



# Setup variables
batch_size = 6  # or whatever batch size you want
epochs = 50
checkpoint_dir = "/content/..."  # whatever your checkpoint directory is



# Create whichever object detector's spec you want.
# `model_dir` is where the training callbacks write the ckpt-* weight files.
spec = object_detector.EfficientDetLite4Spec(
    model_name='efficientdet-lite4',
    uri='https://tfhub.dev/tensorflow/efficientdet/lite4/feature-vector/2',
    # e.g. hparams='grad_checkpoint=true' turns on gradient checkpointing
    # (saves memory); it is unrelated to saving/restoring checkpoint files.
    hparams='',
    model_dir=checkpoint_dir,
    epochs=epochs,
    batch_size=batch_size,
    steps_per_execution=1,
    moving_average_decay=0,
    var_freeze_expr='(efficientnet|fpn_cells|resample_p6)',
    tflite_max_detections=25,
    strategy=None,
    tpu=None,
    gcp_project=None,
    tpu_zone=None,
    use_xla=False,
    profile=False,
    debug=False,
    tf_random_seed=111111,
    verbose=1
)



# Load your datasets
train_data, validation_data, test_data = object_detector.DataLoader.from_csv('/path/to/csv.csv')




# Create the object detector without training it (do_train=False),
# so that training can be driven manually below.
detector = object_detector.create(
    train_data,
    model_spec=spec,
    batch_size=batch_size,
    train_whole_model=True,
    validation_data=validation_data,
    epochs=epochs,
    do_train=False
)



"""
From here on we use internal/"private" functions of the API,
you can tell because the methods' names begin with an underscore
"""

# Convert the datasets for training
train_ds, steps_per_epoch, _ = detector._get_dataset_and_steps(train_data, batch_size, is_training=True)
validation_ds, validation_steps, val_json_file = detector._get_dataset_and_steps(validation_data, batch_size, is_training=False)




# Get the internal keras model
model = detector.create_model()




# Copy what the API internally does as setup
config = spec.config
config.update(
    dict(
        steps_per_epoch=steps_per_epoch,
        eval_samples=batch_size * validation_steps,
        val_json_file=val_json_file,
        batch_size=batch_size
    )
)
train.setup_model(model, config)  # This is the model.compile call basically
model.summary()




"""
Here we restore the weights
"""

# Epoch to resume from; stays 0 (train from scratch) if no checkpoint is restored.
# Previously this name was only bound inside a try block, so a missing
# checkpoint led to a NameError at model.fit(initial_epoch=...).
completed_epochs = 0

# In my case:
# checkpoint_dir = "/content/drive/My Drive/Colab Notebooks/checkpoints_heavy/"
# specific_checkpoint_dir = "/content/drive/My Drive/Colab Notebooks/checkpoints_heavy/ckpt-35"

# Option A:
# load the weights from the last successfully completed epoch.
# latest_checkpoint returns None when the directory has no checkpoint.
latest = tf.train.latest_checkpoint(checkpoint_dir)

# Option B:
# load the weights from a specific checkpoint.
# Note that there's no ".index" at the end of specific_checkpoint_dir
# latest = specific_checkpoint_dir

if latest is None:
    print("Checkpoint not found in {}: training from scratch".format(checkpoint_dir))
else:
    # Load first, so a failed restore raises instead of silently resuming
    # at epoch N with freshly initialized weights.
    model.load_weights(latest)
    # Checkpoint prefixes look like ".../ckpt-35": the trailing number is the
    # epoch the training was at when it was last interrupted.
    epoch_suffix = latest.rsplit("-", 1)[-1]
    completed_epochs = int(epoch_suffix) if epoch_suffix.isdigit() else 0
    print("Checkpoint found {}".format(latest))



# Retrieve the needed default callbacks (these save a checkpoint every epoch)
all_callbacks = train_lib.get_callbacks(config.as_dict(), validation_ds)



"""
Optional step.
Add callbacks that get executed at the end of every N
epochs: in this case I want to log the training results to tensorboard.
"""
# tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=tensorboard_dir, histogram_freq=1)
# all_callbacks.append(tensorboard_callback)




"""
Train the model
"""
model.fit(
    train_ds,
    epochs=epochs,
    initial_epoch=completed_epochs,  # resume numbering after the restored epoch
    steps_per_epoch=steps_per_epoch,
    validation_data=validation_ds,
    validation_steps=validation_steps,
    callbacks=all_callbacks  # This is for saving checkpoints at the end of every epoch + running the above added callbacks
)




"""
Save/export the trained model
Tip: for integer quantization you simply have to NOT SPECIFY
the quantization_config parameter of the detector.export method.
In this case it would be:
detector.export(export_dir = export_dir, tflite_filename='model.tflite')
"""
export_dir = "/content/..."  # save the tflite wherever you want
quant_config = QuantizationConfig.for_float16()  # or whatever quantization you want
detector.model = model  # inject our trained model into the object detector
detector.export(export_dir=export_dir, tflite_filename='model.tflite', quantization_config=quant_config)
© www.soinside.com 2019 - 2024. All rights reserved.