How to iterate indefinitely over the sums of elements of two tf.data.Datasets with tf.data.Dataset.map?


I want to write a mixup data augmentation [1] function in a tf.data-based pipeline.
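For reference, [1] defines mixup as training on convex combinations of pairs of examples $(x_i, y_i)$ and $(x_j, y_j)$, where the labels are one-hot encoded and $\alpha > 0$ is a hyperparameter:

$$\tilde{x} = \lambda x_i + (1 - \lambda)\, x_j, \qquad \tilde{y} = \lambda y_i + (1 - \lambda)\, y_j, \qquad \lambda \sim \mathrm{Beta}(\alpha, \alpha)$$

In the setting described here, $(x_i, y_i)$ would come from the training dataset and $(x_j, y_j)$ from the augmentation dataset.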

I create one tf.data.Dataset from my training examples and another from the examples I want to use to augment my training examples.

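I want to map the elements feat_train, label_train of dataset_train to feat_train + feat_aug, label_train + label_aug, where feat_aug, label_aug are elements of dataset_aug, such that both datasets are iterated indefinitely. For example, with a dataset_train of 3 elements and a dataset_aug of 2 elements: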

feat_train[0], label_train[0] -> feat_train[0] + feat_aug[0], label_train[0] + label_aug[0]
feat_train[1], label_train[1] -> feat_train[1] + feat_aug[1], label_train[1] + label_aug[1]
feat_train[2], label_train[2] -> feat_train[2] + feat_aug[0], label_train[2] + label_aug[0]
feat_train[0], label_train[0] -> feat_train[0] + feat_aug[1], label_train[0] + label_aug[1]
feat_train[1], label_train[1] -> feat_train[1] + feat_aug[0], label_train[1] + label_aug[0]
...
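The pairing above is positional: step k uses training element k mod 3 and augmentation element k mod 2. One way to reproduce that order is to repeat each dataset on its own and combine them with `tf.data.Dataset.zip`. The sketch below uses `tf.data.Dataset.range` as a stand-in for the two real datasets and assumes TensorFlow 2.1+ (for `as_numpy_iterator`):

```python
import tensorflow as tf

# Stand-ins: indices 0..2 play the role of dataset_train, 0..1 of dataset_aug.
train_idx = tf.data.Dataset.range(3).repeat()
aug_idx = tf.data.Dataset.range(2).repeat()

# zip stops at its shortest input, so each side is repeated independently
# before zipping; take(6) just bounds the otherwise infinite stream.
pairs = tf.data.Dataset.zip((train_idx, aug_idx)).take(6)
pairs_list = [(int(t), int(a)) for t, a in pairs.as_numpy_iterator()]
print(pairs_list)  # [(0, 0), (1, 1), (2, 0), (0, 1), (1, 0), (2, 1)]
```

Because the pairing is purely positional, shuffling dataset_train (as the pipeline in the question does) changes which training example meets which augmentation example in each epoch, while the augmentation examples still cycle in order.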

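How can I get this behavior in my mixup function? Is there another recommended way to iterate indefinitely over 2 tf.data.Datasets?

[1] Zhang, Hongyi, et al. "mixup: Beyond Empirical Risk Minimization." arXiv preprint arXiv:1710.09412 (2017).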

```python
import tensorflow as tf


# Hypothetical function header: the question shows only the body of the
# function that builds the input pipeline.
def make_dataset(files_train, files_aug, feature_shape, class_list, batch_size):
    # files_train and files_aug are lists of TFRecord files.

    # Parse TFRecords to get training example features and
    # one-hot encoded labels.
    dataset_train = tf.data.TFRecordDataset(files_train)
    dataset_train = dataset_train.map(
        lambda x: serialized2data(x, feature_shape, class_list))
    dataset_train = dataset_train.shuffle(10000)
    dataset_train = dataset_train.repeat()  # Repeat indefinitely.

    # Parse TFRecords to get augmentation example features and
    # one-hot encoded labels.
    dataset_aug = tf.data.TFRecordDataset(files_aug)
    dataset_aug = dataset_aug.map(
        lambda x: serialized2data(x, feature_shape, class_list))
    dataset_aug = dataset_aug.repeat()  # Repeat indefinitely.

    # Augment data (mixup).
    # Here, how can I write a map function so that the features of every item
    # of dataset_train are mixed with an item of dataset_aug?
    # Something like
    # dataset_train = dataset_train.map(
    #     lambda feat_train, label_train: mixup(
    #         feat_train, label_train, feat_aug, label_aug)
    # )
    # ?
    # But how can I iterate dataset_aug to get feat_aug and label_aug?

    # Make batches.
    dataset_train = dataset_train.batch(batch_size, drop_remainder=True)
    return dataset_train  # The original returned an undefined name `dataset`.


def mixup(feat_train, label_train, feat_aug, label_aug):
    # Shown as an example. This will be more complicated...
    return (feat_train + feat_aug, label_train + label_aug)


def serialized2data(
        serialized_data, feature_shape, class_list,
        data_format='channels_first', training=True):
    """Generate features, labels and, if training is False, filenames and times.

    Labels are indices of original label in class_list.

    Args:
        serialized_data: data serialized using utils.tf_utils.serialize_data
        feature_shape: shape of the features. Can be obtained with
            feature_extractor.feature_shape (see utils.feature_utils)
        class_list: list of class ids (used for one-hot encoding the labels)
        data_format: 'channels_first' (NCHW) or 'channels_last' (NHWC).
            Default is set to 'channels_first' because it is faster on GPU
            (https://www.tensorflow.org/guide/performance/overview#data_formats).
    """
    features = {
        'filename': tf.io.FixedLenFeature([], tf.string),
        'times': tf.io.FixedLenFeature([2], tf.float32),
        'data': tf.io.FixedLenFeature(feature_shape, tf.float32),
        'labels': tf.io.FixedLenFeature([], tf.string),
    }
    example = tf.io.parse_single_example(serialized_data, features)

    # Reshape data to channels_first (or channels_last) format.
    if data_format == 'channels_first':
        data = tf.reshape(example['data'],
                          (1, feature_shape[0], feature_shape[1]))
    else:
        data = tf.reshape(example['data'],
                          (feature_shape[0], feature_shape[1], 1))

    # One-hot encode labels. tf.string_split is the TF 1.x API
    # (tf.compat.v1.string_split in TF 2.x).
    labels = tf.strings.to_number(
        tf.string_split([example['labels']], '#').values,
        out_type=tf.int32
    )
    # Get intersection of class_list and labels.
    labels = tf.squeeze(
        tf.sparse.to_dense(
            tf.sets.intersection(
                tf.expand_dims(labels, axis=0),
                tf.expand_dims(class_list, axis=0)
            )
        ),
        axis=0
    )

    # Sort class_list and get indices of labels in class_list.
    class_list = tf.sort(class_list)
    labels = tf.where(
        tf.equal(tf.expand_dims(labels, axis=1), class_list)
    )[:, 1]

    # myprint is a helper of the asker's that is not shown in the question.
    tf.cond(
        tf.math.logical_and(training, tf.equal(tf.size(labels), 0)),
        true_fn=lambda: myprint(
            tf.strings.format('File {} has no label', example['filename'])),
        false_fn=lambda: 1
    )

    one_hot = tf.cond(
        tf.equal(tf.size(labels), 0),
        true_fn=lambda: tf.zeros(tf.size(class_list)),
        false_fn=lambda: tf.reduce_max(
            tf.one_hot(labels, tf.size(class_list)), 0)
    )

    if training:
        return (data, one_hot)
    else:
        return (data, one_hot, example['filename'], example['times'])
```
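The question's `mixup` is only a placeholder ("This will be more complicated..."). Below is a sketch of the λ-weighted version from [1], assuming TensorFlow 2.x, one-hot float32 labels as produced by `serialized2data`, and a hypothetical default `alpha=0.2`. TensorFlow core has no Beta sampler, so the hypothetical helper `sample_beta` draws λ as G1 / (G1 + G2) with G1, G2 ~ Gamma(α), which is Beta(α, α)-distributed. The commented lines show one possible way to plug it into the placeholder in the pipeline with `tf.data.Dataset.zip`:

```python
import tensorflow as tf


def sample_beta(alpha):
    """Hypothetical helper: one scalar lambda ~ Beta(alpha, alpha) from two Gamma draws."""
    g1 = tf.random.gamma(shape=[], alpha=alpha)
    g2 = tf.random.gamma(shape=[], alpha=alpha)
    return g1 / (g1 + g2)


def mixup(feat_train, label_train, feat_aug, label_aug, alpha=0.2):
    """Mix one training example with one augmentation example, as in [1]."""
    lam = sample_beta(alpha)
    # The same lambda weights both the features and the one-hot labels.
    feat = lam * feat_train + (1.0 - lam) * feat_aug
    label = lam * label_train + (1.0 - lam) * label_aug
    return feat, label


# Possible wiring inside the pipeline (both datasets already repeat()):
# dataset_train = tf.data.Dataset.zip((dataset_train, dataset_aug))
# dataset_train = dataset_train.map(lambda train, aug: mixup(*train, *aug))

# Quick check on toy tensors: the two label entries are lam and 1 - lam.
feat, label = mixup(tf.ones([2, 2]), tf.constant([1.0, 0.0]),
                    tf.zeros([2, 2]), tf.constant([0.0, 1.0]))
print(float(label[0]), float(label[1]))
```

Since the random op runs inside `map`, each training/augmentation pair gets its own λ rather than one value shared by the whole dataset.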

1 Answer
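I am providing sample code that shows how to achieve what you are asking for. I created train_dataset and aug_dataset, of length 3 and 2 respectively. Both contain images and labels. The images have shape (64, 64, 3). The labels of train are [10, 20, 30] and those of aug are [1, 2].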
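A sketch of the setup this answer describes, assuming TensorFlow 2.1+. The image contents (zeros for train, ones for aug) are hypothetical placeholders chosen so the sums are easy to check, and `mixup` is the question's toy element-wise sum:

```python
import tensorflow as tf

# train: 3 images of shape (64, 64, 3) with labels [10, 20, 30].
train_images = tf.zeros([3, 64, 64, 3])
train_labels = tf.constant([10, 20, 30])
# aug: 2 images of shape (64, 64, 3) with labels [1, 2].
aug_images = tf.ones([2, 64, 64, 3])
aug_labels = tf.constant([1, 2])

# Repeat each dataset on its own so both can be iterated indefinitely.
train_dataset = tf.data.Dataset.from_tensor_slices(
    (train_images, train_labels)).repeat()
aug_dataset = tf.data.Dataset.from_tensor_slices(
    (aug_images, aug_labels)).repeat()


def mixup(feat_train, label_train, feat_aug, label_aug):
    # Same toy "mixup" as in the question: element-wise sum.
    return feat_train + feat_aug, label_train + label_aug


# Each zipped element is ((image, label), (image, label)); map unpacks the
# outer tuple into the two lambda arguments.
mixed = tf.data.Dataset.zip((train_dataset, aug_dataset)).map(
    lambda train, aug: mixup(*train, *aug))

mixed_labels = [int(label) for _, label in mixed.take(5).as_numpy_iterator()]
print(mixed_labels)  # [11, 22, 31, 12, 21]
```

The resulting labels follow the pattern requested in the question (10 + 1, 20 + 2, 30 + 1, 10 + 2, 20 + 1), and batching can be applied to `mixed` afterwards as in the original pipeline.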