我正在尝试将保存的模型从 Superpoint 的 Tensorflow 实现转换为 tflite 模型,以便在 Android 上进行测试。
我首先从 github 下载保存的模型: https://github.com/rpautrat/SuperPoint/tree/master/pretrained_models
模型采用 SavedModel 格式。使用以下方法检查模型的输入和输出时:
saved_model_cli show --dir sp_v6 --all
我得到以下输出:
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
signature_def['serving_default']:
The given SavedModel SignatureDef contains the following input(s):
inputs['image'] tensor_info:
dtype: DT_FLOAT
shape: (-1, -1, -1, 1)
name: superpoint/image:0
The given SavedModel SignatureDef contains the following output(s):
outputs['descriptors'] tensor_info:
dtype: DT_FLOAT
shape: (1, -1, -1, 256)
name: superpoint/descriptors:0
outputs['descriptors_raw'] tensor_info:
dtype: DT_FLOAT
shape: (1, -1, -1, 256)
name: superpoint/descriptors_raw:0
outputs['logits'] tensor_info:
dtype: DT_FLOAT
shape: (1, -1, -1, 65)
name: superpoint/logits:0
outputs['pred'] tensor_info:
dtype: DT_INT32
shape: (1, -1, -1)
name: superpoint/pred:0
outputs['prob'] tensor_info:
dtype: DT_FLOAT
shape: (1, -1, -1)
name: superpoint/prob:0
outputs['prob_nms'] tensor_info:
dtype: DT_FLOAT
shape: (1, -1, -1)
name: superpoint/prob_nms:0
Method name is: tensorflow/serving/predict
据我所知,Android中的tflite模型无法处理动态输入,因此我尝试使用以下代码将输入更改为固定输入:
#use tensorflow v1
import tensorflow as tf
def frozen_graph_maker(export_dir, output_graph, output_nodes=None):
    """Freeze a SavedModel's variables into constants and write the GraphDef.

    Args:
        export_dir: Directory of the SavedModel; it is loaded with the
            SERVING tag.
        output_graph: Path the serialized frozen GraphDef is written to.
        output_nodes: Node names (without the ``:0`` suffix) to keep. Only the
            part of the graph needed to compute them ends up in the output.
            Defaults to the six SuperPoint outputs of the SavedModel signature.
    """
    if output_nodes is None:
        output_nodes = ['superpoint/logits', 'superpoint/prob',
                        'superpoint/descriptors_raw', 'superpoint/descriptors',
                        'superpoint/prob_nms', 'superpoint/pred']
    with tf.Session(graph=tf.Graph()) as sess:
        tf.saved_model.loader.load(
            sess, [tf.saved_model.tag_constants.SERVING], export_dir)
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,            # the session is used to retrieve the weights
            sess.graph_def,
            output_nodes,    # the output node names select the useful nodes
        )
    # Finally, serialize and dump the output graph to the filesystem.
    with tf.gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())


if __name__ == "__main__":
    export_dir = './sp_v6/'
    output_graph = "./frozen_graph.pb"
    frozen_graph_maker(export_dir, output_graph)
这会生成一个冻结图(frozen graph),然后我使用以下代码更改它的输入大小:
#Use tensorflow v2
import tensorflow.compat.v1 as tf
output_graph = "./new_frozen_graph.pb"


def load_frozen_graph(frozen_file='frozen.pb'):
    """Read a serialized frozen GraphDef and return it imported into a new Graph.

    Nodes are imported with an empty name prefix, so tensor names match the
    names stored in ``frozen_file``.
    """
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(frozen_file, 'rb') as fid:
        od_graph_def.ParseFromString(fid.read())
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(od_graph_def, name='')
    return graph


# Output node names to keep; the matching ':0' tensor names are derived below.
output_nodes = ['superpoint/logits', 'superpoint/prob', 'superpoint/descriptors_raw',
                'superpoint/descriptors', 'superpoint/prob_nms', 'superpoint/pred']

graph = load_frozen_graph('./frozen_graph.pb')
print('Tensor shapes before import map')
input_tensor = graph.get_tensor_by_name('superpoint/image:0')
print(input_tensor)

# Rebuild the graph behind a fixed-shape placeholder that is wired in place of
# the original dynamic-shape input via input_map.
new_graph = tf.Graph()
with new_graph.as_default():
    new_input = tf.placeholder(dtype=tf.float32, shape=[1, 320, 320, 1],
                               name='superpoint/image')
    tf.import_graph_def(
        graph.as_graph_def(),
        name='',
        input_map={'superpoint/image:0': new_input},
        return_elements=['{}:0'.format(node) for node in output_nodes],
    )

with tf.Session(graph=new_graph) as sess:
    # The graph is already frozen; this call keeps only the nodes that the
    # outputs need, which drops the now-unused original input placeholder.
    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, output_nodes)

with tf.gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())
这样得到的冻结图输入大小已经改变。接着我把它转换为 SavedModel 格式,以便再将其转换为 tflite 格式:
import tensorflow as tf
import os
import shutil
from tensorflow.python import ops
def get_graph_def_from_file(graph_filepath):
tf.compat.v1.reset_default_graph()
with ops.Graph().as_default():
with tf.compat.v1.gfile.GFile(graph_filepath, 'rb') as f:
graph_def = tf.compat.v1.GraphDef()
graph_def.ParseFromString(f.read())
return graph_def
def convert_graph_def_to_saved_model(export_dir, graph_filepath, input_name, outputs):
graph_def = get_graph_def_from_file(graph_filepath)
with tf.compat.v1.Session(graph=tf.Graph()) as session:
tf.import_graph_def(graph_def, name='')
tf.compat.v1.saved_model.simple_save(
session,
export_dir,# change input_image to node.name if you know the name
inputs={input_name: session.graph.get_tensor_by_name('{}:0'.format(node.name))
for node in graph_def.node if node.op=='Placeholder'},
outputs={t.rstrip(":0"):session.graph.get_tensor_by_name(t) for t in outputs}
)
print('Graph converted to SavedModel!')
# Driver: rebuild ./saved_model from the fixed-shape frozen graph.
tf.compat.v1.enable_eager_execution()
saved_model_dir = './saved_model'
input_name = "superpoint/image"
outputs = [
    'superpoint/logits:0',
    'superpoint/prob:0',
    'superpoint/descriptors_raw:0',
    'superpoint/descriptors:0',
    'superpoint/prob_nms:0',
    'superpoint/pred:0',
]
# simple_save refuses to write into an existing directory, so clear it first.
shutil.rmtree(saved_model_dir, ignore_errors=True)
convert_graph_def_to_saved_model(saved_model_dir, './new_frozen_graph.pb', input_name, outputs)
检查此修改模型的输入和输出,我得到以下结果:
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
signature_def['serving_default']:
The given SavedModel SignatureDef contains the following input(s):
inputs['superpoint/image'] tensor_info:
dtype: DT_FLOAT
shape: (1, 320, 320, 1)
name: superpoint/image:0
The given SavedModel SignatureDef contains the following output(s):
outputs['superpoint/descriptors'] tensor_info:
dtype: DT_FLOAT
shape: (1, 320, 320, 256)
name: superpoint/descriptors:0
outputs['superpoint/descriptors_raw'] tensor_info:
dtype: DT_FLOAT
shape: (1, 40, 40, 256)
name: superpoint/descriptors_raw:0
outputs['superpoint/logits'] tensor_info:
dtype: DT_FLOAT
shape: (1, 40, 40, 65)
name: superpoint/logits:0
outputs['superpoint/pred'] tensor_info:
dtype: DT_INT32
shape: unknown_rank
name: superpoint/pred:0
outputs['superpoint/prob'] tensor_info:
dtype: DT_FLOAT
shape: (1, 320, 320)
name: superpoint/prob:0
outputs['superpoint/prob_nms'] tensor_info:
dtype: DT_FLOAT
shape: unknown_rank
name: superpoint/prob_nms:0
Method name is: tensorflow/serving/predict
输入现在看起来没问题,但某些输出的秩(rank)未知。尽管如此,我还是尝试继续进行 tflite 转换,如下所示:
import tensorflow.lite as lite
# Convert the rebuilt SavedModel to a TFLite flatbuffer and write it to disk.
saved_model_dir = './saved_model'
converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()  # bytes of the serialized TFLite model
with open('./new_frozen.tflite', 'wb') as out_file:
    out_file.write(tflite_model)
但是在转换过程中我收到以下不匹配错误:
tensorflow.lite.python.convert.ConverterError: <unknown>:0: error: type of return operand 3 ('tensor<?x?x?xi32>') doesn't match function result type ('tensor<1x?x?xi32>') in function @main
<unknown>:0: note: see current operation: "std.return"(%32, %30, %36, %47, %40, %45) : (tensor<1x320x320x256xf32>, tensor<1x40x40x256xf32>, tensor<1x40x40x65xf32>, tensor<?x?x?xi32>, tensor<1x320x320xf32>, tensor<?x?x?xf32>) -> ()
如果有人可以帮助我,告诉我我做错了什么,以及是否有一种更简单的方法将此张量流模型转换为 tflite 以进行 Android 推理,那就太好了。我确实尝试从图中删除两个形状未知的输出,然后转换起作用,但我想用这些输出来转换它。谢谢你。
从问题中“生成冻结图之后更改其输入大小”那一步开始,之后的所有代码都可以替换为下面的代码:
来源:https://github.com/tensorflow/tensorflow/issues/38388#issuecomment-642388448
(以下代码适用于 TensorFlow 版本 1 或 2)
import tensorflow as tf
# Convert the frozen graph directly; input_shapes pins the dynamic input to a
# fixed shape at conversion time, so no intermediate SavedModel is needed.
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file='./frozen_graph.pb',
    input_arrays=['superpoint/image'],
    output_arrays=['superpoint/logits', 'superpoint/prob', 'superpoint/descriptors_raw',
                   'superpoint/descriptors', 'superpoint/prob_nms', 'superpoint/pred'],
    input_shapes={'superpoint/image': [1, 320, 320, 1]},
)
tflite_model = converter.convert()
# Use a context manager so the file handle is closed (and flushed) deterministically.
with open('model.tflite', 'wb') as model_file:
    tflite_model_size = model_file.write(tflite_model)
print('TFLite Model is %d bytes' % tflite_model_size)
以防有人仍在为此苦苦挣扎。这个特定图的问题是,
superpoint/prob_nms
分支上的最后一个操作导致了形状不匹配(我认为他们把数组重塑(reshape)为形状 [1, H, W]
,而没有考虑批量大小),因此当输入签名中的批量维度为 None
时,会导致形状不匹配错误。这也可以通过直接在保存的模型上输入批量大小大于 1 的张量来重现。
一个快速的解决方案是跳过最后一个操作并强制进行非批量输入。这可以简单地通过以下代码来实现(导出到冻结图后):
# NOTE(review): `graph_def_file` is not defined in this snippet; presumably it
# is the path of the frozen graph exported earlier — define it before running.
image_input = 'superpoint/image'
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file=graph_def_file,
    input_arrays=[image_input],
    # Take 'superpoint/unstack_4' instead of the final batched output, skipping
    # the last op on that branch, as described in the answer text.
    output_arrays=['superpoint/unstack_4', 'superpoint/descriptors'],
    # Batch fixed to 1; height and width are left as None.
    input_shapes={image_input: [1, None, None, 1]},
)
superpoint/unstack_4
的输出不包含批量维度,因此模型一次只能处理单张图像,但这对于大多数情况来说应该已经足够了。
TFLite 通常只支持固定大小的输入和输出。因此,对于 SuperPoint,只应把“superpoint/prob”和“superpoint/descriptors”作为输出进行转换(logits 和 descriptors_raw 是中间输出,在 tflite 中应该用不到)。以下代码全部在 TF2 中运行:
import tensorflow as tf
export_dir = "./sp_v6"
H = 240
W = 320

# Outputs available in the SuperPoint SavedModel:
#   logits:          raw detector output
#   prob:            pixel-level keypoint score map after depth_to_space
#   prob_nms:        prob + NMS
#   pred:            points in prob with confidence > detection_threshold (0.4)
#   descriptors_raw: raw descriptor output
#   descriptors:     resize_bilinear + l2_normalize
# Only prob and descriptors are exported here.
output_nodes = ['superpoint/prob', 'superpoint/descriptors']

# 1) Load the SavedModel and freeze its variables into constants, keeping only
#    the nodes needed to compute `output_nodes`.
with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    tf.compat.v1.saved_model.loader.load(
        sess, [tf.compat.v1.saved_model.tag_constants.SERVING], export_dir)
    init_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
        sess,            # the session is used to retrieve the weights
        sess.graph_def,
        output_nodes,    # the output node names select the useful nodes
    )

# 2) Re-import the frozen graph into a *fresh* graph behind a fixed-shape input.
#    Importing into the session graph above would clash with the nodes already
#    loaded there. The new placeholder and every imported node would get
#    uniquified names (e.g. 'superpoint/prob_1'), so freezing
#    'superpoint/prob' again would just pick up the original dynamic-shape
#    subgraph.
fixed_graph = tf.Graph()
with fixed_graph.as_default():
    frozen_input = tf.compat.v1.placeholder(
        dtype=tf.float32, shape=[1, H, W, 1], name='superpoint/image')
    tf.import_graph_def(
        init_graph_def,
        name='',
        input_map={'superpoint/image:0': frozen_input},
        return_elements=['superpoint/prob:0', 'superpoint/descriptors:0'],
    )

# 3) Prune back to the outputs; this drops the original, now unused, input
#    placeholder (there are no variables left to convert).
with tf.compat.v1.Session(graph=fixed_graph) as sess:
    output_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, output_nodes)

with tf.compat.v1.gfile.GFile("frozen_graph.pb", "wb") as f:
    f.write(output_graph_def.SerializeToString())

converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file='./frozen_graph.pb',
    input_arrays=['superpoint/image'],
    output_arrays=['superpoint/prob', 'superpoint/descriptors'],
    input_shapes={'superpoint/image': [1, H, W, 1]},
)
# Float16 weights; usable on CPU/GPU delegates.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
tflite_model = converter.convert()
# Context manager closes the file deterministically.
with open('superpoint_{}x{}.tflite'.format(H, W), 'wb') as model_file:
    tflite_model_size = model_file.write(tflite_model)
print('TFLite Model is {} bytes'.format(tflite_model_size))