有谁知道如何将Tensorflow中的数据集API(tf.data.Dataset)创建的数据集拆分为测试和训练?
假设您有
all_dataset
类型的 tf.data.Dataset
变量:
test_dataset = all_dataset.take(1000)
train_dataset = all_dataset.skip(1000)
测试数据集现在有前 1000 个元素,其余的用于训练。
您可以使用
Dataset.take()
和 Dataset.skip()
:
train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)
full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
full_dataset = full_dataset.shuffle()
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(val_size)
test_dataset = test_dataset.take(test_size)
为了更通用,我给出了一个使用 70/15/15 训练/验证/测试分割的示例,但如果您不需要测试或验证集,只需忽略最后两行。
拿走:
创建一个数据集,其中最多包含此数据集中的 count 个元素。
跳过:
创建一个数据集,跳过该数据集中的 count 元素。
Dataset.shard()
:
创建一个仅包含该数据集 1/num_shards 的数据集。
免责声明 我在回答这个问题后偶然发现了这个问题,所以我想我应该传播爱
take()
和
skip()
,这需要事先知道数据集的大小。这并不总是可能的,或者很难/很难确定。相反,您可以做的是将数据集进行切片,以便每 N 条记录中有 1 条成为验证记录。
为了实现这一目标,我们从 0-9 的简单数据集开始:
dataset = tf.data.Dataset.range(10)
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
现在对于我们的示例,我们将对其进行切片,以便我们有 3/1 的训练/验证分割。这意味着 3 条记录将进行训练,然后 1 条记录进行验证,然后重复。
split = 3
dataset_train = dataset.window(split, split + 1).flat_map(lambda ds: ds)
# [0, 1, 2, 4, 5, 6, 8, 9]
dataset_validation = dataset.skip(split).window(1, split + 1).flat_map(lambda ds: ds)
# [3, 7]
因此,第一个 dataset.window(split, split + 1)
表示抓取
split
个(3) 元素,然后前进
split + 1
元素,然后重复。
+ 1
有效地跳过了我们将在验证数据集中使用的 1 个元素。
flat_map(lambda ds: ds)
是因为
window()
批量返回结果,这是我们不想要的。所以我们把它压平。然后对于验证数据,我们首先
skip(split)
,它会跳过在第一个训练窗口中抓取的第一个
split
个元素(3),因此我们从第四个元素开始迭代。然后
window(1, split + 1)
抓取 1 个元素,前进
split + 1
(4),然后重复。
关于嵌套数据集的注意事项:
上面的示例适用于简单的数据集,但如果数据集嵌套,flat_map()
将生成错误。为了解决这个问题,您可以将
flat_map()
替换为可以处理简单数据集和嵌套数据集的更复杂的版本:
.flat_map(lambda *ds: ds[0] if len(ds) == 1 else tf.data.Dataset.zip(ds))
train_ds_size = int(0.64 * full_ds_size)
valid_ds_size = int(0.16 * full_ds_size)
train_ds = full_ds.take(train_ds_size)
remaining = full_ds.skip(train_ds_size)
valid_ds = remaining.take(valid_ds_size)
test_ds = remaining.skip(valid_ds_size)
使用下面的代码进行测试。
tf.enable_eager_execution()
dataset = tf.data.Dataset.range(100)
train_size = 20
valid_size = 30
test_size = 50
train = dataset.take(train_size)
remaining = dataset.skip(train_size)
valid = remaining.take(valid_size)
test = remaining.skip(valid_size)
for i in train:
print(i)
for i in valid:
print(i)
for i in test:
print(i)
您可以使用 sklearn.model_selection.train_test_split
生成训练/评估/测试数据集,然后分别创建
tf.data.Dataset
。
tf.keras.utils.split_dataset function
添加了
tf.keras.utils.split_dataset
实用程序,可将Dataset
对象或数组列表/元组拆分为两个Dataset
对象(例如训练/测试)。
from typing import Tuple
import tensorflow as tf
def split_dataset(dataset: tf.data.Dataset,
dataset_size: int,
train_ratio: float,
validation_ratio: float) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]:
assert (train_ratio + validation_ratio) < 1
train_count = int(dataset_size * train_ratio)
validation_count = int(dataset_size * validation_ratio)
test_count = dataset_size - (train_count + validation_count)
dataset = dataset.shuffle(dataset_size)
train_dataset = dataset.take(train_count)
validation_dataset = dataset.skip(train_count).take(validation_count)
test_dataset = dataset.skip(validation_count + train_count).take(test_count)
return train_dataset, validation_dataset, test_dataset
示例:
size_of_ds = 1001
train_ratio = 0.6
val_ratio = 0.2
ds = tf.data.Dataset.from_tensor_slices(list(range(size_of_ds)))
train_ds, val_ds, test_ds = split_dataset(ds, size_of_ds, train_ratio, val_ratio)
tf.strings.to_hash_bucket_fast
filename
的字典。我们根据这个键将数据分成五个桶。通过这个
add_fold
函数,我们可以在字典中添加键
"fold"
:
def add_fold(buckets: int):
def add_(sample, label):
fold = tf.strings.to_hash_bucket(sample["filename"], num_buckets=buckets)
return {**sample, "fold": fold}, label
return add_
dataset = dataset.map(add_fold(buckets=5))
现在我们可以使用 Dataset.filter
将数据集分成两个不相交的数据集:
def pick_fold(fold: int):
def filter_fn(sample, _):
return tf.math.equal(sample["fold"], fold)
return filter_fn
def skip_fold(fold: int):
def filter_fn(sample, _):
return tf.math.not_equal(sample["fold"], fold)
return filter_fn
train_dataset = dataset.filter(skip_fold(0))
val_dataset = dataset.filter(pick_fold(0))
用于散列的密钥应该是捕获数据集中的相关性的密钥。例如,如果同一个人收集的样本是相关的,并且您希望具有相同收集器的所有样本最终都位于同一个存储桶(和相同的分割)中,则应使用收集器名称或 ID 作为哈希列。当然,您可以跳过带有
dataset.map
的部分,并在一个
filter
函数中进行哈希和过滤。这是一个完整的例子:
dataset = tf.data.Dataset.from_tensor_slices([f"value-{i}" for i in range(10000)])
def to_bucket(sample):
return tf.strings.to_hash_bucket_fast(sample, 5)
def filter_train_fn(sample):
return tf.math.not_equal(to_bucket(sample), 0)
def filter_val_fn(sample):
return tf.math.logical_not(filter_train_fn(sample))
train_ds = dataset.filter(filter_train_fn)
val_ds = dataset.filter(filter_val_fn)
print(f"Length of training set: {len(list(train_ds.as_numpy_iterator()))}")
print(f"Length of validation set: {len(list(val_ds.as_numpy_iterator()))}")
打印:
Length of training set: 7995
Length of validation set: 2005
谨防惰性求值,它会产生两个重叠的管道 shuffle+take
和
shuffle+skip
。因此,一些高分答案会产生信息泄露。这是在训练和测试数据集中重复和播种洗牌的正确方法。
import tensorflow as tf
def gen_data():
return iter(range(10))
ds = tf.data.Dataset.from_generator(
gen_data,
output_signature=tf.TensorSpec(shape=(),dtype=tf.int32,name="element"))
SEED = 42 # NOTE: with no seed, you overlap train and test!
ds_train = ds.shuffle(100,seed=SEED).take(8).shuffle(100)
ds_test = ds.shuffle(100,seed=SEED).skip(8)
A = set(ds_train.as_numpy_iterator())
B = set(ds_test.as_numpy_iterator())
assert A.intersection(B)==set()
print(list(A))
print(list(B))
注意:这适用于任何确定性排序迭代器。
无法发表评论,但以上答案有重叠且不正确。将 BUFFER_SIZE 设置为 DATASET_SIZE 以实现完美的随机播放。尝试不同大小的验证/测试大小来验证。答案应该是:
DATASET_SIZE = tf.data.experimental.cardinality(full_dataset).numpy()
train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)
full_dataset = full_dataset.shuffle(BUFFER_SIZE)
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.take(val_size)
test_dataset = test_dataset.skip(val_size)