如何分割 Tensorflow 数据集?

问题描述 投票:0回答:3

我有一个基于一个 .tfrecord 文件的张量流数据集。如何将数据集拆分为测试数据集和训练数据集?例如。 70% 训练,30% 测试?

编辑:

我的张量流版本:1.8 我已经检查过,没有可能的重复项中提到的“split_v”函数。我也在使用 tfrecord 文件。

tensorflow tensorflow-datasets
3个回答
56
投票

您可以使用

Dataset.take()
Dataset.skip()
:

train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)

full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
full_dataset = full_dataset.shuffle()
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(test_size)
test_dataset = test_dataset.take(test_size)

为了更通用,我给出了一个使用 70/15/15 训练/验证/测试分割的示例,但如果您不需要测试或验证集,只需忽略最后两行。

拿走

创建一个数据集,其中最多包含此数据集中的 count 个元素。

跳过

创建一个数据集,跳过该数据集中的 count 元素。

您可能还想了解一下

Dataset.shard()

创建一个仅包含该数据集 1/num_shards 的数据集。


43
投票

这个问题与这个这个类似,恐怕我们还没有得到满意的答案。

  • 使用

    take()
    skip()
    需要知道数据集大小。如果我不知道或者不想知道怎么办?

  • 使用

    shard()
    仅给出数据集的
    1 / num_shards
    。如果我想要剩下的怎么办?

我尝试在下面提出一个更好的解决方案,仅在 TensorFlow 2 上进行了测试。假设您已经有一个 shuffled 数据集,则可以使用

filter()
将其分成两部分:

import tensorflow as tf

all = tf.data.Dataset.from_tensor_slices(list(range(1, 21))) \
        .shuffle(10, reshuffle_each_iteration=False)

test_dataset = all.enumerate() \
                    .filter(lambda x,y: x % 4 == 0) \
                    .map(lambda x,y: y)

train_dataset = all.enumerate() \
                    .filter(lambda x,y: x % 4 != 0) \
                    .map(lambda x,y: y)

for i in test_dataset:
    print(i)

print()

for i in train_dataset:
    print(i)

参数

reshuffle_each_iteration=False
很重要。它确保原始数据集被打乱一次,不再打乱。否则,两个结果集可能会有一些重叠。

使用

enumerate()
添加索引。

使用

filter(lambda x,y: x % 4 == 0)
从 4 个样本中取出 1 个样本。同样,
x % 4 != 0
从 4 个样本中取出 3 个。

使用

map(lambda x,y: y)
剥离索引并恢复原始样本。

此示例实现了 75/25 的分割。

x % 5 == 0
x % 5 != 0
给出 80/20 的分割。

如果你真的想要 70/30 的分割,

x % 10 < 3
x % 10 >= 3
应该可以。

更新:

从 TensorFlow 2.0.0 开始,由于 AutoGraph 的限制,上述代码可能会导致一些警告。要消除这些警告,请单独声明所有 lambda 函数:

def is_test(x, y):
    return x % 4 == 0

def is_train(x, y):
    return not is_test(x, y)

recover = lambda x,y: y

test_dataset = all.enumerate() \
                    .filter(is_test) \
                    .map(recover)

train_dataset = all.enumerate() \
                    .filter(is_train) \
                    .map(recover)

这在我的机器上没有发出警告。而让

is_train()
成为
not is_test()
绝对是一个很好的做法。


0
投票

我将首先解释为什么接受的答案是错误的,其次将提供一个简单的工作解决方案,使用

take()
skip()
seed

使用 TF/Torch 数据集等管道时,谨防惰性评估。避免:

# DONT
full_dataset = full_dataset.shuffle(10)
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)

因为 take 和skip 将同步到单个洗牌,而是分别作为

shuffle+take
shuffle+skip
执行(是的!),通常在 80%*20%=16% 的情况下重叠。所以,信息泄露

如有疑问,请使用此代码

import tensorflow as tf

def gen_data():
    return iter(range(10))

full_dataset = tf.data.Dataset.from_generator(
  gen_data, 
  output_signature=tf.TensorSpec(shape=(),dtype=tf.int32,name="element"))

train_size = 8

# WRONG WAY
full_dataset = full_dataset.shuffle(10)
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)

A = set(train_dataset.as_numpy_iterator())
B = set(test_dataset.as_numpy_iterator())

# EXPECT OVERLAP
assert A.intersection(B)==set()

print(list(A))
print(list(B))

现在,有效的是在训练和测试数据集中重复和播种洗牌,这也有利于再现性。这应该适用于任何确定性排序迭代器:

import tensorflow as tf def gen_data(): return iter(range(10)) ds = tf.data.Dataset.from_generator( gen_data, output_signature=tf.TensorSpec(shape=(),dtype=tf.int32,name="element")) SEED = 42 # NOTE: change this ds_train = ds.shuffle(100,seed=SEED).take(8).shuffle(100) ds_test = ds.shuffle(100,seed=SEED).skip(8) A = set(ds_train.as_numpy_iterator()) B = set(ds_test.as_numpy_iterator()) assert A.intersection(B)==set() print(list(A)) print(list(B))
通过使用 

SEED

,您可以检查/估计泛化(引导代替交叉验证)。

© www.soinside.com 2019 - 2024. All rights reserved.