将 Tensorflow 数据集 API 创建的数据集拆分为训练和测试?

问题描述 投票:0回答:11

有谁知道如何将Tensorflow中的数据集API(tf.data.Dataset)创建的数据集拆分为测试和训练?

tensorflow tensorflow-datasets
11个回答
95
投票

假设您有

all_dataset
类型的
tf.data.Dataset
变量:

test_dataset = all_dataset.take(1000) 
train_dataset = all_dataset.skip(1000)

测试数据集现在有前 1000 个元素,其余的用于训练。


58
投票

您可以使用

Dataset.take()
Dataset.skip()
:

train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)

full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
full_dataset = full_dataset.shuffle()
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(val_size)
test_dataset = test_dataset.take(test_size)

为了更通用,我给出了一个使用 70/15/15 训练/验证/测试分割的示例,但如果您不需要测试或验证集,只需忽略最后两行。

拿走

创建一个数据集,其中最多包含此数据集中的 count 个元素。

跳过

创建一个数据集,跳过该数据集中的 count 元素。

您可能还想了解一下

Dataset.shard()

创建一个仅包含该数据集 1/num_shards 的数据集。


免责声明 我在回答这个问题后偶然发现了这个问题,所以我想我应该传播爱


32
投票
这里的大多数答案都使用

take()

skip()
,这需要事先知道数据集的大小。这并不总是可能的,或者很难/很难确定。

相反,您可以做的是将数据集进行切片,以便每 N 条记录中有 1 条成为验证记录。

为了实现这一目标,我们从 0-9 的简单数据集开始:

dataset = tf.data.Dataset.range(10) # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
现在对于我们的示例,我们将对其进行切片,以便我们有 3/1 的训练/验证分割。这意味着 3 条记录将进行训练,然后 1 条记录进行验证,然后重复。

split = 3 dataset_train = dataset.window(split, split + 1).flat_map(lambda ds: ds) # [0, 1, 2, 4, 5, 6, 8, 9] dataset_validation = dataset.skip(split).window(1, split + 1).flat_map(lambda ds: ds) # [3, 7]
因此,第一个 

dataset.window(split, split + 1)

 表示抓取 
split
(3) 元素,然后前进 split + 1
 元素,然后重复。 
+ 1
 有效地跳过了我们将在验证数据集中使用的 1 个元素。

flat_map(lambda ds: ds)
是因为
window()
批量返回结果,这是我们不想要的。所以我们把它压平。

然后对于验证数据,我们首先

skip(split)

,它会跳过在第一个训练窗口中抓取的第一个 
split
 个元素 
(3),因此我们从第四个元素开始迭代。然后 window(1, split + 1)
 抓取 1 个元素,前进 
split + 1
 
(4),然后重复。

 

关于嵌套数据集的注意事项:

上面的示例适用于简单的数据集,但如果数据集嵌套,
flat_map()

 将生成错误。为了解决这个问题,您可以将 
flat_map()
 替换为可以处理简单数据集和嵌套数据集的更复杂的版本:

.flat_map(lambda *ds: ds[0] if len(ds) == 1 else tf.data.Dataset.zip(ds))
    

8
投票
@ted 的回答会造成一些重叠。试试这个。

train_ds_size = int(0.64 * full_ds_size) valid_ds_size = int(0.16 * full_ds_size) train_ds = full_ds.take(train_ds_size) remaining = full_ds.skip(train_ds_size) valid_ds = remaining.take(valid_ds_size) test_ds = remaining.skip(valid_ds_size)

使用下面的代码进行测试。

tf.enable_eager_execution() dataset = tf.data.Dataset.range(100) train_size = 20 valid_size = 30 test_size = 50 train = dataset.take(train_size) remaining = dataset.skip(train_size) valid = remaining.take(valid_size) test = remaining.skip(valid_size) for i in train: print(i) for i in valid: print(i) for i in test: print(i)
    

4
投票
现在 Tensorflow 不包含任何工具。

您可以使用
sklearn.model_selection.train_test_split

 生成训练/评估/测试数据集,然后分别创建 
tf.data.Dataset


4
投票
您可以使用

shard

dataset = dataset.shuffle() # optional trainset = dataset.shard(2, 0) testset = dataset.shard(2, 1)

参见:

https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shard


4
投票
即将推出的 TensorFlow 2.10.0 将有

tf.keras.utils.split_dataset function

,请参阅 rc3 发行说明

添加了

tf.keras.utils.split_dataset

 实用程序,可将 
Dataset
 对象或数组列表/元组拆分为两个 
Dataset
 对象(例如训练/测试)。


0
投票
如果数据集的大小已知:

from typing import Tuple import tensorflow as tf def split_dataset(dataset: tf.data.Dataset, dataset_size: int, train_ratio: float, validation_ratio: float) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]: assert (train_ratio + validation_ratio) < 1 train_count = int(dataset_size * train_ratio) validation_count = int(dataset_size * validation_ratio) test_count = dataset_size - (train_count + validation_count) dataset = dataset.shuffle(dataset_size) train_dataset = dataset.take(train_count) validation_dataset = dataset.skip(train_count).take(validation_count) test_dataset = dataset.skip(validation_count + train_count).take(test_count) return train_dataset, validation_dataset, test_dataset

示例:

size_of_ds = 1001 train_ratio = 0.6 val_ratio = 0.2 ds = tf.data.Dataset.from_tensor_slices(list(range(size_of_ds))) train_ds, val_ds, test_ds = split_dataset(ds, size_of_ds, train_ratio, val_ratio)
    

0
投票
将数据集分成两部分的一种可靠方法是首先确定性地将数据集中的每个项目映射到一个存储桶中,例如,

tf.strings.to_hash_bucket_fast

。然后,您可以通过按存储桶过滤将数据集分成两部分。如果您将数据分成 5 个桶,假设拆分均匀,您会得到 80-20 的拆分。

举个例子,假设您的数据集包含带有键

filename

 的字典。我们根据这个键将数据分成五个桶。通过这个 
add_fold
 函数,我们可以在字典中添加键 
"fold"

def add_fold(buckets: int): def add_(sample, label): fold = tf.strings.to_hash_bucket(sample["filename"], num_buckets=buckets) return {**sample, "fold": fold}, label return add_ dataset = dataset.map(add_fold(buckets=5))
现在我们可以使用 

Dataset.filter

 将数据集分成两个不相交的数据集:

def pick_fold(fold: int): def filter_fn(sample, _): return tf.math.equal(sample["fold"], fold) return filter_fn def skip_fold(fold: int): def filter_fn(sample, _): return tf.math.not_equal(sample["fold"], fold) return filter_fn train_dataset = dataset.filter(skip_fold(0)) val_dataset = dataset.filter(pick_fold(0))
用于散列的密钥应该是捕获数据集中的相关性的密钥。例如,如果同一个人收集的样本是相关的,并且您希望具有相同收集器的所有样本最终都位于同一个存储桶(和相同的分割)中,则应使用收集器名称或 ID 作为哈希列。

当然,您可以跳过带有

dataset.map

 的部分,并在一个 
filter
 函数中进行哈希和过滤。这是一个完整的例子:

dataset = tf.data.Dataset.from_tensor_slices([f"value-{i}" for i in range(10000)]) def to_bucket(sample): return tf.strings.to_hash_bucket_fast(sample, 5) def filter_train_fn(sample): return tf.math.not_equal(to_bucket(sample), 0) def filter_val_fn(sample): return tf.math.logical_not(filter_train_fn(sample)) train_ds = dataset.filter(filter_train_fn) val_ds = dataset.filter(filter_val_fn) print(f"Length of training set: {len(list(train_ds.as_numpy_iterator()))}") print(f"Length of validation set: {len(list(val_ds.as_numpy_iterator()))}")
打印:

Length of training set: 7995 Length of validation set: 2005
    

0
投票

谨防惰性求值,它会产生两个重叠的管道 shuffle+take

shuffle+skip
。因此,一些高分答案会产生信息泄露。这是在训练和测试数据集中重复和播种洗牌的正确方法。
import tensorflow as tf def gen_data(): return iter(range(10)) ds = tf.data.Dataset.from_generator( gen_data, output_signature=tf.TensorSpec(shape=(),dtype=tf.int32,name="element")) SEED = 42 # NOTE: with no seed, you overlap train and test! ds_train = ds.shuffle(100,seed=SEED).take(8).shuffle(100) ds_test = ds.shuffle(100,seed=SEED).skip(8) A = set(ds_train.as_numpy_iterator()) B = set(ds_test.as_numpy_iterator()) assert A.intersection(B)==set() print(list(A)) print(list(B))

注意:这适用于任何
确定性排序
迭代器。

无法发表评论,但以上答案有重叠且不正确。将 BUFFER_SIZE 设置为 DATASET_SIZE 以实现完美的随机播放。尝试不同大小的验证/测试大小来验证。答案应该是:

-2
投票
DATASET_SIZE = tf.data.experimental.cardinality(full_dataset).numpy() train_size = int(0.7 * DATASET_SIZE) val_size = int(0.15 * DATASET_SIZE) test_size = int(0.15 * DATASET_SIZE) full_dataset = full_dataset.shuffle(BUFFER_SIZE) train_dataset = full_dataset.take(train_size) test_dataset = full_dataset.skip(train_size) val_dataset = test_dataset.take(val_size) test_dataset = test_dataset.skip(val_size)


© www.soinside.com 2019 - 2024. All rights reserved.