如何使数据集丢失三重态

问题描述 投票:1回答:1

我正在尝试制作能够提供批量TFRecords的数据集,其中一个批次将有来自一个类别的2个随机记录,其余来自其他随机类别。

要么

批量数据集,每个类中有2个随机记录,适合该批次。

我尝试用tf.data.Dataset.from_generatortf.data.experimental.choose_from_datasets这样做,但没有成功。你对如何做到了吗?

编辑:今天我认为我实施了第二个变种。这是我正在测试它的代码。

def input_fn():
  partial1 = tf.data.Dataset.from_tensor_slices(tf.range(0, 10)).repeat().shuffle(2)
  partial2 = tf.data.Dataset.from_tensor_slices(tf.range(20, 30)).repeat().shuffle(2)
  partial3 = tf.data.Dataset.from_tensor_slices(tf.range(60, 70)).repeat().shuffle(2)
  l = [partial1, partial2, partial3]

  def gen(x):
    return tf.data.Dataset.range(x,x+1).repeat(2)

  dataset = tf.data.Dataset.range(3).flat_map(gen).repeat(10)

  choice = tf.data.experimental.choose_from_datasets(l, dataset).batch(4)
  return choice

当被唤醒时返回

[ 0  2 21 22]
[60 61  1  4]
[20 23 62 63]
[ 3  5 24 25]
[64 66  6  7]
[26 27 65 68]
[ 8  0 28 29]
[67 69  9  2]
[20 22 60 62]
[ 3  1 23 24]
[63 61  4  6]
[25 26 65 64]
[ 7  5 27 28]
[67 66  9  8]
[21 20 69 68]
tensorflow tensorflow-datasets tensorflow-estimator
1个回答
0
投票

好的,我明白了。数据集生成成功,数据随机性似乎不错。由于三胞胎是随机的而不是半硬的,因此它不是三重态损失的理想解决方案。

def input_fn(self, params):
    batch_size = params['batch_size']

    assert self.data_dir, 'data_dir is required'
    shuffle = self.is_training

    dirs = list(map(lambda x: os.path.join(x, 'train-*' if self.is_training else 'validation-*')), self.dirs)

    def prefetch_dataset(filename): 
      dataset = tf.data.TFRecordDataset( 
          filename, buffer_size=FLAGS.prefetch_dataset_buffer_size)
      return dataset

    datasets = []
    for glob in dirs:
      dataset = tf.data.Dataset.list_files(glob)
      dataset = dataset.apply( 
        tf.contrib.data.parallel_interleave( 
            prefetch_dataset, 
            cycle_length=FLAGS.num_files_infeed, 
            sloppy=True)) # if order is important 
      dataset = dataset.shuffle(batch_size, None, True).repeat().prefetch(batch_size)
      datasets.append(dataset)

    def gen(x):
      return tf.data.Dataset.range(x,x+1).repeat(2)

    choice = tf.data.Dataset.range(len(datasets)).repeat().flat_map(gen)

    dataset = tf.data.experimental.choose_from_datasets(datasets, choice).map( # apply function to each element of the dataset in parallel
        self.dataset_parser, num_parallel_calls=FLAGS.num_parallel_calls)

    dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(8)

    return dataset
© www.soinside.com 2019 - 2024. All rights reserved.