如何将检查点文件转换为tensorflow.js?

问题描述 投票:0回答:1

因为我还是初学者,所以需要一步一步的详细说明。

我尝试输入以下代码:

import tensorflow.compat.v1 as tf

meta_path = './newcheckpoint/.meta' # Your .meta file
output_node_names = ['name_of_the_output_node']    # Output nodes

with tf.Session() as sess:
    # Restore the graph
    saver = tf.train.import_meta_graph(meta_path)

    # Load weights
    saver.restore(sess,tf.train.latest_checkpoint('./newcheckpoint/'))

    # Freeze the graph
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph_def,
        output_node_names)

    # Save the frozen graph
    with open('./freeze/output_graph.pb', 'wb') as f:
      f.write(frozen_graph_def.SerializeToString())

但是我不知道在哪里编写这段代码(python?命令提示符?)并且我必须更改这段代码中的任何内容吗?就像我为 name_of_the_output_node 和 ./newcheckpoint/.meta 添加什么?

tensorflow machine-learning type-conversion tensorflow2.0 checkpointing
1个回答
0
投票

您提到的代码已过时,因为它属于 Tensorflow 1。

对于 Tensorflow 2.x,请使用此代码进行模型检查点:

model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=log_dir)

model.fit(x=x_train,
          y=y_train,
          epochs=10,
          validation_data=(x_test, y_test),
          callbacks=[model_checkpoint_callback])

关于加载模型,请使用取自here

的代码

您还可以单独保存模型权重以供将来使用。

© www.soinside.com 2019 - 2024. All rights reserved.