我正在尝试使用tf.data.Dataset.from_generator()
生成训练和验证数据。
我有自己的数据生成器,可以随时进行准备:
def data_iterator(self, input_file_list, ...):
for f in input_file_list:
X, y = get_feature(f)
yield X, y
最初,我将其直接输入到tensorflow keras模型中,但是在第一批处理后遇到数据超出范围错误。然后我决定将其包装在tensorflow数据生成器中:
train_gen = lambda: data_iterator(train_files, ...)
valid_gen = lambda: data_iterator(valid_files, ...)
output_types = (tf.float32, tf.float32)
output_shapes = (tf.TensorShape([499, 13]), tf.TensorShape([2]))
train_dat = tf.data.Dataset.from_generator(train_gen,
output_types=output_types,
output_shapes=output_shapes)
valid_dat = tf.data.Dataset.from_generator(valid_gen,
output_types=output_types,
output_shapes=output_shapes)
train_dat = train_dat.repeat().batch(batch_size=128)
valid_dat = valid_dat.repeat().batch(batch_size=128)
然后适合:
model.fit(x=train_dat,
validation_data=valid_dat,
steps_per_epoch=train_steps,
validation_steps=valid_steps,
epochs=100,
callbacks=callbacks)
但是,尽管生成器中有.repeat()
,但仍然出现错误:
[BaseCollectiveExecutor :: StartAbort超出范围:序列结束
我的问题是:
.repeat()
在这里不起作用?while True
以避免这种情况吗?我觉得这可以解决它,但看起来不像是正确的方法。我在自己的生成器中添加了一会儿True,这样它就永远不会用完,而且我也不再出错了:
def data_iterator(self, input_file_list, ...):
while True;
for f in input_file_list:
X, y = get_feature(f)
yield X, y
但是,我不知道为什么.repeat()
无法用于.from_generator()