如何使用具有类别权重的张量流(tflite)模型制作器

问题描述 投票:0回答:1

我正在使用 tflite 模型制作器 进行对象检测,我真的很想将类别权重应用到我的训练过程中以对抗类别不平衡。我正在使用 Efficient-Det-Lite 型号系列。

但是,我不知道该怎么做,或者这是否可能。

我检查了 hparams 配置的 github 源代码,但我在这里找不到任何说明类权重的内容。

我尝试直接在 object_ detector_spec.py 中设置 class_weight 参数,如下所示:

model.fit(
    train_dataset,
    epochs=epochs,
    class_weight = class_weights
    steps_per_epoch=steps_per_epoch,
    callbacks=train_lib.get_callbacks(config.as_dict(), val_dataset),
    validation_data=val_dataset,
    validation_steps=validation_steps)

但输出是错误

ValueError: `class_weight` is only supported for Models with a single output.

所以我的问题是,是否可以通过较小的努力将类别权重应用于使用 TF Lite Model Maker 的训练?如果是,我该怎么做?

我当前的训练代码如下所示

import os

from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.config import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import object_detector

import tensorflow as tf
assert tf.__version__.startswith('2')

tf.get_logger().setLevel('ERROR')
from absl import logging
logging.set_verbosity(logging.ERROR)

model_name = 'efficientdet-lite0'
custom_model_dir_name = 'model_'+"1907"
epochs = 50
batch_size = 16
model_dir = 'models/'+model_name+'/'+custom_model_dir_name+'_e'+str(epochs)+'_b'+str(batch_size)
spec = object_detector.EfficientDetLite0Spec(
    model_name = model_name,
    model_dir='/home/alex/checkpoints/',
    hparams='grad_checkpoint=true,strategy=gpus',
    epochs=epochs, batch_size=batch_size,
    steps_per_execution=1, moving_average_decay=0,
    var_freeze_expr='(efficientnet|fpn_cells|resample_p6)',
    tflite_max_detections=25
)

train_data, validation_data, test_data = object_detector.DataLoader.from_csv(file_path)

model = object_detector.create(train_data, model_spec=spec, train_whole_model=True, validation_data=validation_data)
tensorflow tensorflow-lite
1个回答
0
投票

我的回答:

不,没有简单的方法。 问题

ValueError: `class_weight` is only supported for Models with a single output.
是一个仍在 tensorflow 存储库中公开的 issue 中讨论的问题。

解决方法之一是大幅降级张量流,但不希望降级到那么远。然而,当然可以自己修改应用于模型的损失函数并对它们应用权重以获得所需的结果。

我最终所做的是在我的模型中使用预定义的focal-loss,它可以解释参数γα的类不平衡。 α-balanced-focal-loss 是处理数据集中不平衡的好方法,例如在训练样本不足的情况下。

© www.soinside.com 2019 - 2024. All rights reserved.