将SavedModel转换为TFLite时不支持Operation ParseExample

问题描述 投票:1回答:1

我正在使用TensorFlow估算器来训练和保存模型,然后将其转换为.tflite。我将模型保存如下:

feat_cols = [tf.feature_column.numeric_column('feature1'),
             tf.feature_column.numeric_column('feature2'),
             tf.feature_column.numeric_column('feature3'),
             tf.feature_column.numeric_column('feature4')]

def serving_input_receiver_fn():
    """An input receiver that expects a serialized tf.Example."""
    feature_spec = tf.feature_column.make_parse_example_spec(feat_cols)
    default_batch_size = 1
    serialized_tf_example = tf.placeholder(dtype=tf.string, shape=[default_batch_size], name='tf_example')
    receiver_tensors = {'examples': serialized_tf_example}
    features = tf.parse_example(serialized_tf_example, feature_spec)
    return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)


dnn_regressor.export_saved_model(export_dir_base='model',
                                 serving_input_receiver_fn=serving_input_receiver_fn)

当我尝试使用以下命令转换生成的.pb文件时:

tflite_convert --output_file=/tmp/foo.tflite --saved_model_dir=/tmp/saved_model

我得到一个例外,说TensorFlow Lite不支持ParseExample操作。

标准TensorFlow Lite运行时不支持模型中的某些运算符。如果这些是本机TensorFlow运算符,您可以通过传递--enable_select_tf_ops或通过在调用tf.lite.TFLiteConverter()时设置target_ops = TFLITE_BUILTINS,SELECT_TF_OPS来使用扩展运行时。否则,如果您有自定义实现,则可以使用--allow_custom_ops禁用此错误,或者在调用tf.lite.TFLiteConverter()时设置allow_custom_ops = True。以下是您正在使用的内置运算符列表:CONCATENATION,FULLY_CONNECTED,RESHAPE。以下是您需要自定义实现的运算符列表:ParseExample。

如果我尝试在没有序列化的情况下导出模型,当我尝试预测生成的.pb文件时,函数需要并清空set(),而不是我传递的输入的dict。

ValueError:在input_dict中得到意外的键:{'feature1','feature2','feature3','feature4'}:set()

我究竟做错了什么?以下是尝试保存模型而不进行任何序列化的代码

features = {
    'feature1': tf.placeholder(dtype=tf.float32, shape=[1], name='feature1'),
    'feature2': tf.placeholder(dtype=tf.float32, shape=[1], name='feature2'),
    'feature3': tf.placeholder(dtype=tf.float32, shape=[1], name='feature3'),
    'feature4': tf.placeholder(dtype=tf.float32, shape=[1], name='feature4')
}

def serving_input_receiver_fn():
    return tf.estimator.export.ServingInputReceiver(features, features)


dnn_regressor.export_savedmodel(export_dir_base='model', serving_input_receiver_fn=serving_input_receiver_fn, as_text=True)
tensorflow tensorflow-estimator tensorflow-lite
1个回答
0
投票

解决了

使用build_raw_serving_input_receiver_fn我设法导出保存的模型而不进行任何序列化:

serve_input_fun = tf.estimator.export.build_raw_serving_input_receiver_fn(
    features,
    default_batch_size=None
)

dnn_regressor.export_savedmodel(
    export_dir_base="model",
    serving_input_receiver_fn=serve_input_fun,
    as_text=True
)

注意:在进行预测时,Predictor不知道默认的signature_def,因此我需要指定它:

predict_fn = predictor.from_saved_model("model/155482...", signature_def_key="predict")

同样从.pb转换为.tflite我使用了Python API,因为我还需要指定signature_def:

converter = tf.contrib.lite.TFLiteConverter.from_saved_model('model/155482....', signature_key='predict')
© www.soinside.com 2019 - 2024. All rights reserved.