I have a JAX model pretrained for MAXIM (image enhancement). Now, to reduce the runtime and use it in production, I have to quantize the weights. Since a direct conversion to ONNX is not possible, I have two options, one of which is `tf.lite.TFLiteConverter.experimental_from_jax`.

Looking at this official example, the code block is:
```python
serving_func = functools.partial(predict, params)
x_input = jnp.zeros((1, 28, 28))
converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [serving_func], [[('input1', x_input)]])
tflite_model = converter.convert()
with open('jax_mnist.tflite', 'wb') as f:
    f.write(tflite_model)
```
It seems to use the model's `params` and the `predict` function, which in that example are defined during model building and training itself as:
```python
init_random_params, predict = stax.serial(
    stax.Flatten,
    stax.Dense(1024), stax.Relu,
    stax.Dense(1024), stax.Relu,
    stax.Dense(10), stax.LogSoftmax)
```
and `params` as:
```python
opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)
params = get_params(opt_state)
```
My question is: how do I obtain these two required pieces, `params` and `predict`, for my pretrained model, so that I can try to replicate the example with my own model?
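In other words, I would need something like the following minimal sketch (assuming a Flax model whose weights can be restored with `flax.training.checkpoints`; `MaximModel` and the checkpoint path are hypothetical placeholders):

```python
import flax
from flax.training import checkpoints

# Hypothetical: the pretrained architecture and checkpoint location.
model = MaximModel()
params = checkpoints.restore_checkpoint(ckpt_dir='./maxim_ckpt', target=None)

def predict(input_img):
    # Counterpart of the stax `predict(params, x)` pair in the MNIST example.
    return model.apply({'params': flax.core.freeze(params)}, input_img)
```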
So I got an answer on the official repository. Here is the code:
```python
import tensorflow as tf
import flax
from jax.experimental import jax2tf

def predict(input_img):
    '''
    Function to predict the output from the JAX model
    (`model` and `params` are the pretrained Flax model and its weights).
    '''
    return model.apply({'params': flax.core.freeze(params)}, input_img)

tf_predict = tf.function(
    jax2tf.convert(predict, enable_xla=False),
    input_signature=[
        tf.TensorSpec(shape=[1, 704, 1024, 3], dtype=tf.float32, name='input_image')
    ],
    autograph=False)

converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [tf_predict.get_concrete_function()], tf_predict)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS     # enable TensorFlow ops.
]

# Unquantized float32 model.
tflite_float_model = converter.convert()
with open('float_model.tflite', 'wb') as f:
    f.write(tflite_float_model)

# Dynamic-range quantized model.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quantized_model = converter.convert()
with open('quantized.tflite', 'wb') as f:
    f.write(tflite_quantized_model)
```
You can now easily load and run the model with `tf.lite.Interpreter`.
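For example, a minimal sketch of running the quantized model (the input shape matches the `TensorSpec` fixed at conversion time):

```python
import numpy as np
import tensorflow as tf

# Load the quantized model and allocate its tensors.
interpreter = tf.lite.Interpreter(model_path='quantized.tflite')
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Dummy input with the shape used during conversion: (1, 704, 1024, 3).
img = np.zeros((1, 704, 1024, 3), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], img)
interpreter.invoke()
enhanced = interpreter.get_tensor(output_details[0]['index'])
```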