如何使用tf.lite将预训练的JAX模型量化为TfLite模型?

问题描述 投票:0回答:1

我有一个针对

MAXIM:图像增强
预先训练的 JAX 模型。现在,为了减少运行时间并将其用于生产,我必须量化权重。由于无法直接转换为 ONNX,我有 2 个选择。

  1. JAX -> Tensorflow -> ONNX(帮助线程
  2. JAX -> TFLite

选择第二个选项,有这个功能

tf.lite.TFLiteConverter.experimental_from_jax

这个官方示例,代码块

serving_func = functools.partial(predict, params)
x_input = jnp.zeros((1, 28, 28))
converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [serving_func], [[('input1', x_input)]])
tflite_model = converter.convert()
with open('jax_mnist.tflite', 'wb') as f:
  f.write(tflite_model)

它似乎正在使用模型中的

params
和函数
predict
,以防在模型构建和训练本身时定义为

预测

init_random_params, predict = stax.serial(
    stax.Flatten,
    stax.Dense(1024), stax.Relu,
    stax.Dense(1024), stax.Relu,
    stax.Dense(10), stax.LogSoftmax)

参数

opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)
params = get_params(opt_state)

我的问题是,如何为我的预训练模型获得这两个必需的

params
predict
,以便我可以尝试为我自己的模型复制示例?

tensorflow tensorflow2.0 tensorflow-lite jax tflite
1个回答
0
投票

所以我在官方仓库上得到了答案。这是代码:

import tensorflow as tf
from jax.experimental import jax2tf


def predict(input_img):
  '''
  Function to predict the output from the JAX model
  '''
  return model.apply({'params': flax.core.freeze(params)}, input_img)


tf_predict = tf.function(
    jax2tf.convert(predict, enable_xla=False),
    input_signature=[
        tf.TensorSpec(shape=[1, 704, 1024, 3], dtype=tf.float32, name='input_image')
    ],
    autograph=False)

converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [tf_predict.get_concrete_function()], tf_predict)

converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
]
tflite_float_model = converter.convert()

with open('float_model.tflite', "wb") as f: f.write(tflite_float_model)


converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quantized_model = converter.convert()

with open('./quantized.tflite', 'wb') as f: f.write(tflite_quantized_model)

您现在可以使用

tf.lite.Interpreter

轻松加载和运行模型
© www.soinside.com 2019 - 2024. All rights reserved.