我正在使用TensorFlow 2.0学习图像分类,在实现它的同时,我收到以下警告:
tensorflow:您的输入数据用完;中断训练。确保您的数据集或生成器至少可以生成
steps_per_epoch * epochs
个批次(在这种情况下,为11280个批次)。构建数据集时,可能需要使用repeat()函数。
这是数据集大小和模型数据:
total training mango images : 752
total validation mango images : 288
history = model.fit_generator(train_generator,
validation_data=validation_generator,
steps_per_epoch=752,
epochs=15,
validation_steps=288)
我还使用图像增强技术来提高模型的效率。但是,警告仍然会中断训练,并且模型仍会过拟合。
Here is the link to my full model in github
请帮帮我!预先谢谢你
鉴于您在第24步遇到错误,并且batch_size
为32,我想问题是您的第24批只有16张图像。您可以删除这些图像,以便您的数据集与32的批处理大小对齐,或将batch_size
更改为16。
[顺便说一句,我怀疑您使用的是steps_per_epoch
和validation_steps
错误-“步骤”是指批次数,而不是样本数。仅供参考