在tensorlfow数据集中,如何混合2个数据集,从原始数据中获取75%的数据集,从增强数据中获取25%?
d = tf.data.Dataset.list_files("raw_data/")\
.flat_map(tf.data.TFRecordDataset)
ad = tf.data.Dataset.list_files("augmented_data/")\
.flat_map(tf.data.TFRecordDataset)
问题是你不能在数据集对象上使用len()
,所以在迭代一个完整的纪元之前,有时很难知道确切的数量。但你可以使用take
和skip
方法来估计这个。
train_dataset = dataset.take(number_examples_for_train)
test_dataset = dataset.skip(number_examples_for_train)
这些方法是彼此的直接替代方案。 https://www.tensorflow.org/api_docs/python/tf/data/Dataset#take