tensorflow数据集from_generator()超出范围错误

问题描述 投票:1回答:1

我正在尝试使用tf.data.Dataset.from_generator()生成训练和验证数据。

我有自己的数据生成器,可以随时进行准备:

def data_iterator(self, input_file_list, ...):
    for f in input_file_list:
        X, y = get_feature(f)
        yield X, y

最初,我将其直接输入到tensorflow keras模型中,但是在第一批处理后遇到数据超出范围错误。然后我决定将其包装在tensorflow数据生成器中:

train_gen = lambda: data_iterator(train_files, ...)
valid_gen = lambda: data_iterator(valid_files, ...)

output_types = (tf.float32, tf.float32)
output_shapes = (tf.TensorShape([499, 13]), tf.TensorShape([2]))
train_dat = tf.data.Dataset.from_generator(train_gen,
                                           output_types=output_types,
                                           output_shapes=output_shapes)
valid_dat = tf.data.Dataset.from_generator(valid_gen,
                                           output_types=output_types,
                                           output_shapes=output_shapes)
train_dat = train_dat.repeat().batch(batch_size=128)
valid_dat = valid_dat.repeat().batch(batch_size=128)

然后适合:

model.fit(x=train_dat,
          validation_data=valid_dat,
          steps_per_epoch=train_steps,
          validation_steps=valid_steps,
          epochs=100,
          callbacks=callbacks)

但是,尽管生成器中有.repeat(),但仍然出现错误:

[BaseCollectiveExecutor :: StartAbort超出范围:序列结束

我的问题是:

  • 为什么.repeat()在这里不起作用?
  • 我应该在自己的迭代器中添加while True以避免这种情况吗?我觉得这可以解决它,但看起来不像是正确的方法。
python tensorflow keras generator tensorflow-datasets
1个回答
1
投票

我在自己的生成器中添加了一会儿True,这样它就永远不会用完,而且我也不再出错了:

def data_iterator(self, input_file_list, ...):
    while True;
        for f in input_file_list:
            X, y = get_feature(f)
            yield X, y

但是,我不知道为什么.repeat()无法用于.from_generator()

© www.soinside.com 2019 - 2024. All rights reserved.