我训练的模型,并使用其保存:
saver = tf.train.Saver()
saver.save(session, './my_model_name')
除了检查点文件,它只是包含指向最近的模型的检查点,这会在当前路径以下3个文件:
我不知道这些文件包含的内容。
我想加载C ++这种模型并进行推理。该label_image例如使用ReadBinaryProto()
从单一.bp文件加载模型。我不知道如何可以从这些3个文件加载它。什么是下面的C ++相同呢?
new_saver = tf.train.import_meta_graph('./my_model_name.meta')
new_saver.restore(session, './my_model_name')
我目前正在与这个自己奋斗,我发现它不是很简单的目前做的。关于这个问题的两个最常被引用的教程:https://medium.com/jim-fleming/loading-a-tensorflow-graph-with-the-c-api-4caaff88463f#.goxwm1e5j和https://medium.com/@hamedmp/exporting-trained-tensorflow-models-to-c-the-right-way-cf24b609d183#.g1gak956i
相当于
new_saver = tf.train.import_meta_graph('./my_model_name.meta')
new_saver.restore(session, './my_model_name')
只是
Status load_graph_status = LoadGraph(graph_path, &session);
假设你已经“冻结图”(使用的脚本与结合了检查点值的图形文件)。另外,看到这里的讨论:Tensorflow Different ways to Export and Run graph in C++
你的保护产生被称为“检查点V2”,并在TF 0.12中引入的。
我把它很好地工作(尽管在C ++部分文档是可怕的,所以我花了一天来解决)。有人建议converting all variables to constants或freezing the graph,但实际需要没有这些。
Python的一部分(保存)
with tf.Session() as sess:
tf.train.Saver(tf.trainable_variables()).save(sess, 'models/my-model')
如果创建Saver
的tf.trainable_variables()
,你可以保存自己的一些头痛和存储空间。但也许一些更复杂的模型需要的所有数据进行保存,然后删除此参数Saver
,只要确保你所创建的Saver
创建您的图形之后。这也是非常明智地放弃所有变量/层唯一的名称,否则,你可以在不同的问题上运行。
C ++部分(推论)
需要注意的是checkpointPath
不是任何现有文件的路径,只是他们共同的前缀。如果错误地放在那里路径.index
文件,TF不会告诉你,这是错误的,但它会推断过程中,由于未初始化的变量死。
#include <tensorflow/core/public/session.h>
#include <tensorflow/core/protobuf/meta_graph.pb.h>
using namespace std;
using namespace tensorflow;
...
// set up your input paths
const string pathToGraph = "models/my-model.meta"
const string checkpointPath = "models/my-model";
...
auto session = NewSession(SessionOptions());
if (session == nullptr) {
throw runtime_error("Could not create Tensorflow session.");
}
Status status;
// Read in the protobuf graph we exported
MetaGraphDef graph_def;
status = ReadBinaryProto(Env::Default(), pathToGraph, &graph_def);
if (!status.ok()) {
throw runtime_error("Error reading graph definition from " + pathToGraph + ": " + status.ToString());
}
// Add the graph to the session
status = session->Create(graph_def.graph_def());
if (!status.ok()) {
throw runtime_error("Error creating graph: " + status.ToString());
}
// Read weights from the saved checkpoint
Tensor checkpointPathTensor(DT_STRING, TensorShape());
checkpointPathTensor.scalar<std::string>()() = checkpointPath;
status = session->Run(
{{ graph_def.saver_def().filename_tensor_name(), checkpointPathTensor },},
{},
{graph_def.saver_def().restore_op_name()},
nullptr);
if (!status.ok()) {
throw runtime_error("Error loading checkpoint from " + checkpointPath + ": " + status.ToString());
}
// and run the inference to your liking
auto feedDict = ...
auto outputOps = ...
std::vector<tensorflow::Tensor> outputTensors;
status = session->Run(feedDict, outputOps, {}, &outputTensors);
为了完整起见,这里是Python的等价的:
推理在Python
with tf.Session() as sess:
saver = tf.train.import_meta_graph('models/my-model.meta')
saver.restore(sess, tf.train.latest_checkpoint('models/'))
outputTensors = sess.run(outputOps, feed_dict=feedDict)