我正在尝试制作能够提供批量TFRecords的数据集,其中一个批次将有来自一个类别的2个随机记录,其余来自其他随机类别。
要么
批量数据集,每个类中有2个随机记录,适合该批次。
我尝试用tf.data.Dataset.from_generator
和tf.data.experimental.choose_from_datasets
这样做,但没有成功。你对如何做到了吗?
编辑:今天我认为我实施了第二个变种。这是我正在测试它的代码。
def input_fn():
partial1 = tf.data.Dataset.from_tensor_slices(tf.range(0, 10)).repeat().shuffle(2)
partial2 = tf.data.Dataset.from_tensor_slices(tf.range(20, 30)).repeat().shuffle(2)
partial3 = tf.data.Dataset.from_tensor_slices(tf.range(60, 70)).repeat().shuffle(2)
l = [partial1, partial2, partial3]
def gen(x):
return tf.data.Dataset.range(x,x+1).repeat(2)
dataset = tf.data.Dataset.range(3).flat_map(gen).repeat(10)
choice = tf.data.experimental.choose_from_datasets(l, dataset).batch(4)
return choice
当被唤醒时返回
[ 0 2 21 22]
[60 61 1 4]
[20 23 62 63]
[ 3 5 24 25]
[64 66 6 7]
[26 27 65 68]
[ 8 0 28 29]
[67 69 9 2]
[20 22 60 62]
[ 3 1 23 24]
[63 61 4 6]
[25 26 65 64]
[ 7 5 27 28]
[67 66 9 8]
[21 20 69 68]
好的,我明白了。数据集生成成功,数据随机性似乎不错。由于三胞胎是随机的而不是半硬的,因此它不是三重态损失的理想解决方案。
def input_fn(self, params):
batch_size = params['batch_size']
assert self.data_dir, 'data_dir is required'
shuffle = self.is_training
dirs = list(map(lambda x: os.path.join(x, 'train-*' if self.is_training else 'validation-*')), self.dirs)
def prefetch_dataset(filename):
dataset = tf.data.TFRecordDataset(
filename, buffer_size=FLAGS.prefetch_dataset_buffer_size)
return dataset
datasets = []
for glob in dirs:
dataset = tf.data.Dataset.list_files(glob)
dataset = dataset.apply(
tf.contrib.data.parallel_interleave(
prefetch_dataset,
cycle_length=FLAGS.num_files_infeed,
sloppy=True)) # if order is important
dataset = dataset.shuffle(batch_size, None, True).repeat().prefetch(batch_size)
datasets.append(dataset)
def gen(x):
return tf.data.Dataset.range(x,x+1).repeat(2)
choice = tf.data.Dataset.range(len(datasets)).repeat().flat_map(gen)
dataset = tf.data.experimental.choose_from_datasets(datasets, choice).map( # apply function to each element of the dataset in parallel
self.dataset_parser, num_parallel_calls=FLAGS.num_parallel_calls)
dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(8)
return dataset