将Pix2Pix生成器导出到tflite模型。

问题描述 投票:1回答:1

我训练了一个 Pix2Pix 生成器从Tensorflow 2.0教程和我导出它在tflite这种方式。

converter = tf.lite.TFLiteConverter.from_keras_model(generator)
tflite_model = converter.convert()
open("facades.tflite", "wb").write(tflite_model)

不幸的是,我的问题似乎来自于 tf.keras.layer.BatchNormalization 当我尝试推断时。

首先,推断的结果只返回Nan值。这可以通过禁用 熔断 的实现。

其次,BatchNormalization层的行为取决于我们是在训练还是在预测。本教程 明确指出 预测 培训=True 模式。我不知道如何用tflite模式做这件事。

有一个解决方案是关于更换 批量归一化 层层 InstanceNormalization,可参见 tensorflow_addons.转换为tflite没有任何问题,但是在推理方面还是有问题,当我在解释器上调用invoke时,它崩溃了,给我返回了一个SEGFAULT.根据stackcall,它应该来自SquaredDifference操作符InstanceNormalization层。 根据stackcall,它将来自InstanceNormalization层的SquaredDifference操作符。

有没有人设法将这个TensorFlow 2.0模型转换为tflite并正确推断?怎么做?谢谢你。

PS : 我更倾向于使用BatchNormalization的解决方案,因为它是Keras中的一个标准层,因此也可以与TensorFlow javascript一起工作。

python tensorflow2.0 tensorflow-lite tensorflow.js
1个回答
© www.soinside.com 2019 - 2024. All rights reserved.