将多个 TensorFlow 数据集交错在一起

问题描述 投票:0回答:4

当前的 TensorFlow 数据集交错功能基本上是一个以单个数据集作为输入的交错平面地图。考虑到当前的 API,将多个数据集交错在一起的最佳方法是什么?假设它们已经建成,并且我有一份清单。我想交替地从它们中生成元素,并且我想支持具有超过 2 个数据集的列表(即,堆叠的 zip 和交错会非常难看)。

谢谢! :)

@mrry 或许能帮上忙。

tensorflow tensorflow-datasets
4个回答
11
投票

另请参阅:


尽管这并不“干净”,但这是我想出的唯一解决方法。

datasets = [tf.data.Dataset...]

def concat_datasets(datasets):
    ds0 = tf.data.Dataset.from_tensors(datasets[0])
    for ds1 in datasets[1:]:
        ds0 = ds0.concatenate(tf.data.Dataset.from_tensors(ds1))
    return ds0

ds = tf.data.Dataset.zip(tuple(datasets)).flat_map(
    lambda *args: concat_datasets(args)
)

2
投票

扩展 user2781994 answer(经过编辑),这是我的实现方式:

import tensorflow as tf

ds11 = tf.data.Dataset.from_tensor_slices([1,2,3])
ds12 = tf.data.Dataset.from_tensor_slices([4,5,6])
ds13 = tf.data.Dataset.from_tensor_slices([7,8,9])
all_choices_ds = [ds11, ds12, ds13]

choice_dataset = tf.data.Dataset.range(len(all_choices_ds)).repeat()
ds14 = tf.contrib.data.choose_from_datasets(all_choices_ds, choice_dataset)

# alternatively:
# ds14 = tf.contrib.data.sample_from_datasets(all_choices_ds)

iterator = ds14.make_initializable_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer)
    while True:
        try:
            value=sess.run(next_element)
        except tf.errors.OutOfRangeError:
            break
        print(value)

输出为:

1
4
7
2
5
8
3
6
9

2
投票

在 Tensorflow 2.0 中

tot_imm_dataset1 = 105
tot_imm_dataset2 = 55
e = tf.data.Dataset.from_tensor_slices(tf.cast([1,0,1],tf.int64)).repeat(int(tot_imm_dataset1/2)) 
f=tf.data.Dataset.range(1).repeat(int(tot_imm_dataset2-tot_imm_dataset1/2))
choice=e.concatenate(f)
datasets=[dataset2,dataset1]
dataset_rgb_compl__con_patch= tf.data.experimental.choose_from_datasets(datasets, choice)

这对我有用


0
投票

基本思想是

  1. 使用文件创建数据集(超级数据集)
  2. 交错每个文件(子数据集):使用 from_tensor_slice 加载并包装
  3. 批量处理
    import glob

    batch_size = 32
    files = glob.glob('datasets/*.npy')
    
    def read_npy_file(file_path):
      return np.load(file_path.numpy().decode())
    
    def create_dataset(files):
      dataset = tf.data.Dataset.from_tensor_slices(files)
      dataset = dataset.interleave(
        lambda x: tf.data.Dataset.from_tensor_slices(
          tf.py_function(read_npy_file, [x], tf.float32)
        ),
        cycle_length=len(files),
        num_parallel_calls=tf.data.AUTOTUNE
      )
      dataset = dataset.batch(batch_size)
      return dataset
    
    dataset = create_dataset(files)
    
    for batch in dataset:
      print(batch.shape)

比较棘手的部分是使用 tf.py_function 来包装 read_npy_file,否则传入的 arg 是一个 Tensor,它没有 numpy() 函数。

© www.soinside.com 2019 - 2024. All rights reserved.