在Tensor Flow中保存和恢复训练有素的LSTM

问题描述 投票:6回答:6

我使用BasicLSTMCell训练了LSTM分类器。如何保存模型并将其恢复以用于以后的分类?

tensorflow recurrent-neural-network lstm
6个回答
3
投票

我自己也在想这个。正如其他人所指出的,在TensorFlow中保存模型的常用方法是使用tf.train.Saver(),但我相信这样可以保存tf.Variables的值。我不确定tf.Variables实现中是否存在BasicLSTMCell,当你这样做时会自动保存,或者如果还有其他步骤需要采取,但是如果其他所有步骤都失败了,那么BasicLSTMCell可以很容易地保存和加载在一个pickle文件中。


5
投票

我们发现了同样的问题。我们不确定内部变量是否已保存。我们发现在创建/定义BasicLSTMCell之后必须创建保护程序。在那里,它没有被保存。


4
投票

保存和恢复模型的最简单方法是使用tf.train.Saverobject。构造函数将保存和恢复操作添加到图形中变量的所有或指定列表的图形中。 saver对象提供了运行这些ops的方法,指定了要写入或读取的检查点文件的路径。

参考:

https://www.tensorflow.org/versions/r0.11/how_tos/variables/index.html

检查点文件

变量保存在二进制文件中,大致包含从变量名到张量值的映射。

创建Saver对象时,可以选择为检查点文件中的变量选择名称。默认情况下,它为每个变量使用Variable.name属性的值。

要了解检查点中的变量,可以使用inspect_checkpoint库,特别是print_tensors_in_checkpoint_file函数。

保存变量

使用tf.train.Saver()创建一个Saver来管理模型中的所有变量。

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
init_op = tf.initialize_all_variables()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  ..
  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in file: %s" % save_path)

恢复变量

相同的Saver对象用于恢复变量。请注意,从文件还原变量时,您不必事先初始化它们。

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")
  print("Model restored.")
  # Do some work with the model
  ...

1
投票

是的,LSTM细胞内部存在重量和偏差变量(实际上,所有神经网络细胞都必须在某处具有重量变量)。正如在其他答案中已经提到的那样,使用Saver对象似乎是要走的路......以合理方便的方式保存变量和(元)图。如果你想要恢复整个模型,你需要元图,而不仅仅是一些tf.Variables孤立地坐在那里。它确实需要知道它必须保存的所有变量,因此在创建图形后创建保护程序。

处理任何“有变量吗?”/“它是否正确地重用权重?”/“我是如何实际查看LSTM中的权重,它没有绑定到任何python变量?”/等一个有用的小技巧。情况是这个小片段:

for i in tf.global_variables():
    print(i)

对于vars和

for i in my_graph.get_operations():
    print (i)

对于操作。如果要查看未绑定到python var的张量,

tf.Graph.get_tensor_by_name('name_of_op:N')

其中op的名称是生成张量的操作的名称,N是您之后(可能是几个)输出张量的索引。

如果您的图表有大量操作,则张量图显示有助于查找操作名称...最常见的...


1
投票

我已经为LSTM保存和恢复制作了示例代码。我也花了很多时间来解决这个问题。请参阅此网址:https://github.com/MareArts/rnn_save_restore_test我希望能帮助您解决此问题。


-1
投票

您可以实例化tf.train.Saver对象并在训练期间调用save传递当前会话并输出检查点文件(* .ckpt)路径。只要您认为合适,您就可以调用save(例如,每隔几个时期,当验证错误消失时):

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
init_op = tf.initialize_all_variables()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  ..
  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in file: %s" % save_path)

在分类/推理期间,您实例化另一个tf.train.Saver并调用restore传递当前会话和检查点文件以进行恢复。您可以在使用模型进行分类之前通过调用restore来调用session.run

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")
  print("Model restored.")
  # Do some work with the model
  ...

参考:https://www.tensorflow.org/versions/r0.11/how_tos/variables/index.html#saving-and-restoring

© www.soinside.com 2019 - 2024. All rights reserved.