因为我还是初学者,所以需要一步一步的详细说明。
我尝试输入以下代码:
import tensorflow.compat.v1 as tf
meta_path = './newcheckpoint/.meta' # Your .meta file
output_node_names = ['name_of_the_output_node'] # Output nodes
with tf.Session() as sess:
# Restore the graph
saver = tf.train.import_meta_graph(meta_path)
# Load weights
saver.restore(sess,tf.train.latest_checkpoint('./newcheckpoint/'))
# Freeze the graph
frozen_graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph_def,
output_node_names)
# Save the frozen graph
with open('./freeze/output_graph.pb', 'wb') as f:
f.write(frozen_graph_def.SerializeToString())
但是我不知道在哪里编写这段代码(python?命令提示符?)并且我必须更改这段代码中的任何内容吗?就像我为 name_of_the_output_node 和 ./newcheckpoint/.meta 添加什么?
您提到的代码已过时,因为它属于 Tensorflow 1。
对于 Tensorflow 2.x,请使用此代码进行模型检查点:
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=log_dir)
model.fit(x=x_train,
y=y_train,
epochs=10,
validation_data=(x_test, y_test),
callbacks=[model_checkpoint_callback])
关于加载模型,请使用取自here
的代码您还可以单独保存模型权重以供将来使用。