我正在使用TensorFlow估算器来训练和保存模型,然后将其转换为.tflite。我将模型保存如下:
feat_cols = [tf.feature_column.numeric_column('feature1'),
tf.feature_column.numeric_column('feature2'),
tf.feature_column.numeric_column('feature3'),
tf.feature_column.numeric_column('feature4')]
def serving_input_receiver_fn():
"""An input receiver that expects a serialized tf.Example."""
feature_spec = tf.feature_column.make_parse_example_spec(feat_cols)
default_batch_size = 1
serialized_tf_example = tf.placeholder(dtype=tf.string, shape=[default_batch_size], name='tf_example')
receiver_tensors = {'examples': serialized_tf_example}
features = tf.parse_example(serialized_tf_example, feature_spec)
return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
dnn_regressor.export_saved_model(export_dir_base='model',
serving_input_receiver_fn=serving_input_receiver_fn)
当我尝试使用以下命令转换生成的.pb文件时:
tflite_convert --output_file=/tmp/foo.tflite --saved_model_dir=/tmp/saved_model
我得到一个例外,说TensorFlow Lite不支持ParseExample操作。
标准TensorFlow Lite运行时不支持模型中的某些运算符。如果这些是本机TensorFlow运算符,您可以通过传递--enable_select_tf_ops或通过在调用tf.lite.TFLiteConverter()时设置target_ops = TFLITE_BUILTINS,SELECT_TF_OPS来使用扩展运行时。否则,如果您有自定义实现,则可以使用--allow_custom_ops禁用此错误,或者在调用tf.lite.TFLiteConverter()时设置allow_custom_ops = True。以下是您正在使用的内置运算符列表:CONCATENATION,FULLY_CONNECTED,RESHAPE。以下是您需要自定义实现的运算符列表:ParseExample。
如果我尝试在没有序列化的情况下导出模型,当我尝试预测生成的.pb文件时,函数需要并清空set(),而不是我传递的输入的dict。
ValueError:在input_dict中得到意外的键:{'feature1','feature2','feature3','feature4'}:set()
我究竟做错了什么?以下是尝试保存模型而不进行任何序列化的代码
features = {
'feature1': tf.placeholder(dtype=tf.float32, shape=[1], name='feature1'),
'feature2': tf.placeholder(dtype=tf.float32, shape=[1], name='feature2'),
'feature3': tf.placeholder(dtype=tf.float32, shape=[1], name='feature3'),
'feature4': tf.placeholder(dtype=tf.float32, shape=[1], name='feature4')
}
def serving_input_receiver_fn():
return tf.estimator.export.ServingInputReceiver(features, features)
dnn_regressor.export_savedmodel(export_dir_base='model', serving_input_receiver_fn=serving_input_receiver_fn, as_text=True)
解决了
使用build_raw_serving_input_receiver_fn我设法导出保存的模型而不进行任何序列化:
serve_input_fun = tf.estimator.export.build_raw_serving_input_receiver_fn(
features,
default_batch_size=None
)
dnn_regressor.export_savedmodel(
export_dir_base="model",
serving_input_receiver_fn=serve_input_fun,
as_text=True
)
注意:在进行预测时,Predictor不知道默认的signature_def,因此我需要指定它:
predict_fn = predictor.from_saved_model("model/155482...", signature_def_key="predict")
同样从.pb转换为.tflite我使用了Python API,因为我还需要指定signature_def:
converter = tf.contrib.lite.TFLiteConverter.from_saved_model('model/155482....', signature_key='predict')