使用tf.data批处理来自多个TFRecord文件的顺序数据

问题描述 投票:2回答:1

让我们考虑将数据集分成多个TFRecord文件:

  • 1.tfrecord
  • 2.tfrecord

我想从同一TFRecord文件生成大小为t(例如3)的序列,该序列由连续元素组成,我不希望序列具有属于不同TFRecord文件的元素。

例如,如果我们有两个包含如下数据的TFRecord文件:

  • [1.tfrecord{0, 1, 2, ..., 7}
  • [2.tfrecord{1000, 1001, 1002, ..., 1007}

没有任何改组,我想得到以下批次:

  • 第一批:0, 1, 2
  • 第二批:1, 2, 3
  • ...
  • 第i批:5, 6, 7
  • [(i + 1)批:1000, 1001, 1002
  • [(i + 2)批:1001, 1002, 1003
  • ...
  • 第j批:1005, 1006, 1007
  • [(j + 1)批:0, 1, 2

[我知道如何使用tf.data.Dataset.windowtf.data.Dataset.batch生成序列数据,但我不知道如何防止序列包含来自不同文件的元素。

我正在寻找可扩展的解决方案,即该解决方案应该可以处理数百个TFRecord文件。

下面是我的失败尝试(完全可重复的示例:):>

import tensorflow as tf

# ****************************
# Generate toy TF Record files

def _create_example(i):
    example = tf.train.Features(feature={'data': tf.train.Feature(int64_list=tf.train.Int64List(value=[i]))})
    return tf.train.Example(features=example)

def parse_fn(serialized_example):
    return tf.parse_single_example(serialized_example, {'data': tf.FixedLenFeature([], tf.int64)})['data']


num_tf_records = 2
records_per_file = 8
options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.GZIP)
for i in range(num_tf_records):
    with tf.python_io.TFRecordWriter('%i.tfrecord' % i, options=options) as writer:
        for j in range(records_per_file):
            example = _create_example(j + 1000 * i)
            writer.write(example.SerializeToString())
# ****************************
# ****************************


data = tf.data.TFRecordDataset(['0.tfrecord', '1.tfrecord'], compression_type='GZIP')\
            .map(lambda x: parse_fn(x))

data = data.window(3, 1, 1, True)\
           .repeat(-1)\
           .flat_map(lambda x: x.batch(3))\
           .batch(16)

data_it = data.make_initializable_iterator()
next_element = data_it.get_next()

with tf.Session() as sess:
    sess.run(data_it.initializer)
    print(sess.run(next_element))

输出:

[[   0    1    2]   # good
 [   1    2    3]   # good
 [   2    3    4]   # good
 [   3    4    5]   # good
 [   4    5    6]   # good
 [   5    6    7]   # good
 [   6    7 1000]   # bad – mix of elements from 0.tfrecord and 1.tfrecord
 [   7 1000 1001]   # bad
 [1000 1001 1002]   # good
 [1001 1002 1003]   # good
 [1002 1003 1004]   # good
 [1003 1004 1005]   # good
 [1004 1005 1006]   # good
 [1005 1006 1007]   # good
 [   0    1    2]   # good
 [   1    2    3]]  # good

让我们考虑将数据集分成多个TFRecord文件:1.tfrecord,2.tfrecord等。我想生成大小为t(例如3)的序列,该序列由相同的连续元素组成。

python tensorflow tensorflow-datasets
1个回答
0
投票
我认为您只需要flat_map该功能即可制作windo数据集:

def make_dataset_from_filename(filename): data = tf.data.TFRecordDataset(filename, compression_type='GZIP')\ .map(lambda x: parse_fn(x)) data = data.window(3, 1, 1, True)\ .repeat(-1)\ .flat_map(lambda x: x.batch(3))\ .batch(16) tf.data.Dataset.list_files('*.tfrecord').flat_map(make_dataset_from_filename)

© www.soinside.com 2019 - 2024. All rights reserved.