Tensorflow的Estimator停止训练

问题描述 投票:0回答:1

我正在使用Tensorflow的Estimator训练模型,并且在执行评估后突然停止训练2600步后。是不是应该继续训练直到最后一个时代结束?

def train():
    train_input_func = lambda: input_fn(mode='train')
    eval_input_func = lambda: input_fn(mode='eval')

    est_conf = tf.estimator.RunConfig(cfg.model_dir, save_checkpoints_secs=120)
    estimator = tf.estimator.Estimator(model_fn, cfg.model_dir, est_conf)


    Path(estimator.eval_dir()).mkdir(parents=True, exist_ok=True)
    train_spec = tf.estimator.TrainSpec(input_fn=train_input_func)
    eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_func, throttle_secs=120)
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

if __name__ == '__main__':
    train()

这是input_fn功能:

def input_fn(mode=None):
        data_generator = lambda: data_loader.data_generator(mode=mode)

        dataset = tf.data.Dataset.from_generator(data_generator,
                                                 output_types=(tf.int32, tf.int32),
                                                 output_shapes=([None], [None]))

        if mode is 'train':
            dataset.shuffle(cfg.shuffle_buffer).repeat(1000)

        dataset = dataset.padded_batch(cfg.batch_size, padded_shapes=([None],[None])).prefetch(1)

        return dataset
tensorflow
1个回答
0
投票

当使用tf.estimator.train_and_evaluate,使max_steps工作,你不应该使用repeat(1000),请使用repeat(),它将无限期重复输入,并不会抛出OutOfRangeError

© www.soinside.com 2019 - 2024. All rights reserved.