如何量化优化的tflite模型的输入和输出

问题描述 投票:0回答:3

我使用下面的代码生成一个量化的tflite模型

import tensorflow as tf

def representative_dataset_gen():
  for _ in range(num_calibration_steps):
    # Get sample input data as a numpy array in a method of your choosing.
    yield [input]

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
tflite_quant_model = converter.convert()

但是根据训练后量化

生成的模型将被完全量化,但仍采用 float 输入和输出 以方便使用。

Google Coral Edge TPU 编译 tflite 模型 我还需要量化输入和输出。

在模型中,我看到第一个网络层将浮点输入转换为

input_uint8
,最后一层将
output_uint8
转换为浮点输出。 如何编辑 tflite 模型以去除第一个和最后一个浮动层?

我知道我可以在转换期间将输入和输出类型设置为 uint8,但这与任何优化都不兼容。唯一可用的选择是使用假量化,这会导致错误的模型。

python tensorflow-lite quantization google-coral
3个回答
2
投票

您可以通过设置 inference_input_type 和 inference_output_type (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/python/lite.py #L460-L476) 到 int8.


1
投票
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir
converter.optimizations = [tf.lite.Optimize.DEFAULT] 
converter.representative_dataset = representative_dataset
#The below 3 lines performs the input - output quantization
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
tflite_model = converter.convert()

1
投票

这个:

def representative_data_gen():
  for input_value in tf.data.Dataset.from_tensor_slices(train_images).batch(1).take(100):
    # Model has only one input so each data point has one element.
    yield [input_value]

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen

tflite_model_quant = converter.convert()

生成具有 Float32 输入和输出的 Float32 模型。这个:

def representative_data_gen():
  for input_value in tf.data.Dataset.from_tensor_slices(train_images).batch(1).take(100):
    yield [input_value]

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen
# Ensure that if any ops can't be quantized, the converter throws an error
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# Set the input and output tensors to uint8 (APIs added in r2.3)
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8

tflite_model_quant = converter.convert()

生成具有 UINT8 输入和输出的 UINT8 模型

您可以通过以下方式确保情况确实如此:

interpreter = tf.lite.Interpreter(model_content=tflite_model_quant)
input_type = interpreter.get_input_details()[0]['dtype']
print('input: ', input_type)
output_type = interpreter.get_output_details()[0]['dtype']
print('output: ', output_type)

返回:

input:  <class 'numpy.uint8'>
output:  <class 'numpy.uint8'>

如果你进行了完整的 UINT8 量化。您可以使用

netron

目视检查您的模型来仔细检查
© www.soinside.com 2019 - 2024. All rights reserved.