使用张量流的估计器API在RNN的每个时期中的权重矩阵和成本

问题描述 投票:0回答:1

我使用Estimator API训练RNN模型,我想绘制成本/纪元数字并获得最佳模型权重矩阵。在Estimator API中有可能吗?这是代码:

   classifier.train(input_fn=lambda: input_fn_train(train_x, label_train, batch_size),steps=train_steps)


   eval_result = classifier.evaluate(input_fn=lambda: input_fn_eval(test_x, label_test, batch_size))
tensorflow tensorflow-estimator
1个回答
0
投票

有可能的。您需要做的是配置您的Estimator以生成相关信息,这些信息对您决定要保留哪些权重非常有用。这可以通过检查站完成。这是模型的“保存”。传递给Estimator config=一些配置会很有用。

以下是自定义Estimator的示例:

def model_fn(features, labels, mode, params):
    #Some code is here that gives you the output of your model from where
    #you get your predictions.
    if mode == tf.estimator.ModeKeys.TRAIN or tf.estimator.ModeKeys.EVAL:
        #Some more code is here
        loss = #your loss function here
        tf.summary.scalar('loss', loss)
    if mode == tf.estimator.ModeKeys.TRAIN:
        #More code here that train your model
    if mode == tf.estimator.ModeKeys.EVAL:
        #Again more code that you use to get some evaluation metrics
    if mode == tf.estimator.ModeKeys.PREDICT:
        #Code...
    return tf.estimator.EstimatorSpec(mode=mode,
                                      predictions=predictions,
                                      loss=loss,
                                      train_op=train_op,
                                      eval_metric_ops=eval_metric_ops)


configuration = tf.estimator.RunConfig(save_summary_steps=10,
                                       keep_checkpoint_max=30,
                                       save_checkpoints_steps=10,
                                       log_step_count_steps=10)

custom_estimator = tf.estimator.Estimator(model_fn=model_fn,
                               model_dir='model_dir',
                               config=configuration)

custom_estimator.train(input_fn=input_fn_train, steps=10000)

save_summary_steps:实际上,您可以在估算器更新摘要的步数之后想到这一点。这可能很有用,因此您可以每10步绘制一次损失。

save_checkpoints_steps:在目前状态下,您的估算器将被保存多少步。

你可以在model_dir找到这些检查站。

如果您使用的是预装Estimator,我认为摘要是预定义的,但损失功能已经存在,因此您只需配置打印摘要的频率以及保存模型状态的频率。

© www.soinside.com 2019 - 2024. All rights reserved.