Estimator.train()和.predict()对于小的数据集来说太慢了

问题描述 投票:0回答:1

我正在尝试实现一个DQN,该DQN在同一模型上对Estimator.train()Estimator.predict()进行多次调用,每个示例都有少量示例。但是每个调用至少要花费几百毫秒到一秒以上的时间,这与小数字(例如1-20)的示例数无关。

我认为这些延迟是由于重建图表并在每次调用时保存检查点而引起的。有没有办法将相同的图形和参数保留在内存中,以进行快速的火车预测迭代或以其他方式加快速度?]

python tensorflow-estimator
1个回答
0
投票

转换为tf.keras.Model而不是Estimator,并使用tf.keras.Model.fit()代替Estimator.train()fit()没有train()的固定延迟。 Keras predict()也没有。

© www.soinside.com 2019 - 2024. All rights reserved.