我使用下面的代码生成一个量化的tflite模型
import tensorflow as tf
def representative_dataset_gen():
for _ in range(num_calibration_steps):
# Get sample input data as a numpy array in a method of your choosing.
yield [input]
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
tflite_quant_model = converter.convert()
但是根据训练后量化:
生成的模型将被完全量化,但仍采用 float 输入和输出 以方便使用。
为 Google Coral Edge TPU 编译 tflite 模型 我还需要量化输入和输出。
在模型中,我看到第一个网络层将浮点输入转换为
input_uint8
,最后一层将output_uint8
转换为浮点输出。
如何编辑 tflite 模型以去除第一个和最后一个浮动层?
我知道我可以在转换期间将输入和输出类型设置为 uint8,但这与任何优化都不兼容。唯一可用的选择是使用假量化,这会导致错误的模型。
您可以通过设置 inference_input_type 和 inference_output_type (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/python/lite.py #L460-L476) 到 int8.
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
#The below 3 lines performs the input - output quantization
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
tflite_model = converter.convert()
这个:
def representative_data_gen():
for input_value in tf.data.Dataset.from_tensor_slices(train_images).batch(1).take(100):
# Model has only one input so each data point has one element.
yield [input_value]
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen
tflite_model_quant = converter.convert()
生成具有 Float32 输入和输出的 Float32 模型。这个:
def representative_data_gen():
for input_value in tf.data.Dataset.from_tensor_slices(train_images).batch(1).take(100):
yield [input_value]
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen
# Ensure that if any ops can't be quantized, the converter throws an error
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# Set the input and output tensors to uint8 (APIs added in r2.3)
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
tflite_model_quant = converter.convert()
生成具有 UINT8 输入和输出的 UINT8 模型
您可以通过以下方式确保情况确实如此:
interpreter = tf.lite.Interpreter(model_content=tflite_model_quant)
input_type = interpreter.get_input_details()[0]['dtype']
print('input: ', input_type)
output_type = interpreter.get_output_details()[0]['dtype']
print('output: ', output_type)
返回:
input: <class 'numpy.uint8'>
output: <class 'numpy.uint8'>
如果你进行了完整的 UINT8 量化。您可以使用
netron
目视检查您的模型来仔细检查