使用Tensorflow Data API和Keras模型预测批次

问题描述 投票:0回答:1

假设我有一个数据集和一个Keras模型。使用tf Dataset API中的batch()将数据集分为几批。现在,我正在寻找一种有效且干净的方法来对所有测试样品进行批量预测。

我尝试了以下代码,它可以正常工作。

batch_size = 32
dataset = dataset.batch(batch_size)
predictions = keras_model.predict(dataset, steps=math.ceil(num_testing_samples / batch_size))

我想知道有没有更有效,更优雅的方法来实现这一目标?

tensorflow keras tensorflow-datasets
1个回答
0
投票

TF> = 1.14.0

您可以只设置steps=None。从tf.keras.Model.predict()的官方文档中:

如果x是tf.data数据集,而step为None,则预测将运行直到输入数据集用尽。

只需确保您的dataset对象未处于重复模式,就可以了:)。

TF 1.12.0&1.13.0

在这些版本中,对tf.data.Datasettf.keras的支持非常差。 tf.data.Dataset对象被转换为迭代器here,如果未设置here参数,则它将触发错误steps。在1.14.0中对此进行了修补。

热门问题
推荐问题
最新问题