我有一个基于一个 .tfrecord 文件的张量流数据集。如何将数据集拆分为测试数据集和训练数据集?例如。 70% 训练,30% 测试?
编辑:
我的张量流版本:1.8 我已经检查过,没有可能的重复项中提到的“split_v”函数。我也在使用 tfrecord 文件。
您可以使用
Dataset.take()
和 Dataset.skip()
:
train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)
full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
full_dataset = full_dataset.shuffle()
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(test_size)
test_dataset = test_dataset.take(test_size)
为了更通用,我给出了一个使用 70/15/15 训练/验证/测试分割的示例,但如果您不需要测试或验证集,只需忽略最后两行。
拿走:
创建一个数据集,其中最多包含此数据集中的 count 个元素。
跳过:
创建一个数据集,跳过该数据集中的 count 元素。
Dataset.shard()
:
创建一个仅包含该数据集 1/num_shards 的数据集。
使用
take()
和 skip()
需要知道数据集大小。如果我不知道或者不想知道怎么办?使用
shard()
仅给出数据集的1 / num_shards
。如果我想要剩下的怎么办?我尝试在下面提出一个更好的解决方案,仅在 TensorFlow 2 上进行了测试。假设您已经有一个 shuffled 数据集,则可以使用
filter()
将其分成两部分:
import tensorflow as tf
all = tf.data.Dataset.from_tensor_slices(list(range(1, 21))) \
.shuffle(10, reshuffle_each_iteration=False)
test_dataset = all.enumerate() \
.filter(lambda x,y: x % 4 == 0) \
.map(lambda x,y: y)
train_dataset = all.enumerate() \
.filter(lambda x,y: x % 4 != 0) \
.map(lambda x,y: y)
for i in test_dataset:
print(i)
print()
for i in train_dataset:
print(i)
参数
reshuffle_each_iteration=False
很重要。它确保原始数据集被打乱一次,不再打乱。否则,两个结果集可能会有一些重叠。
使用
enumerate()
添加索引。
使用
filter(lambda x,y: x % 4 == 0)
从 4 个样本中取出 1 个样本。同样,x % 4 != 0
从 4 个样本中取出 3 个。
使用
map(lambda x,y: y)
剥离索引并恢复原始样本。
此示例实现了 75/25 的分割。
x % 5 == 0
和 x % 5 != 0
给出 80/20 的分割。
如果你真的想要 70/30 的分割,
x % 10 < 3
和 x % 10 >= 3
应该可以。
更新:
从 TensorFlow 2.0.0 开始,由于 AutoGraph 的限制,上述代码可能会导致一些警告。要消除这些警告,请单独声明所有 lambda 函数:
def is_test(x, y):
return x % 4 == 0
def is_train(x, y):
return not is_test(x, y)
recover = lambda x,y: y
test_dataset = all.enumerate() \
.filter(is_test) \
.map(recover)
train_dataset = all.enumerate() \
.filter(is_train) \
.map(recover)
这在我的机器上没有发出警告。而让
is_train()
成为 not is_test()
绝对是一个很好的做法。
我将首先解释为什么接受的答案是错误的,其次将提供一个简单的工作解决方案,使用
take()
,skip()
和seed
。
使用 TF/Torch 数据集等管道时,谨防惰性评估。避免:
# DONT
full_dataset = full_dataset.shuffle(10)
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
因为 take 和skip 将同步到单个洗牌,而是分别作为
shuffle+take
和 shuffle+skip
执行(是的!),通常在 80%*20%=16% 的情况下重叠。所以,信息泄露。
如有疑问,请使用此代码
import tensorflow as tf
def gen_data():
return iter(range(10))
full_dataset = tf.data.Dataset.from_generator(
gen_data,
output_signature=tf.TensorSpec(shape=(),dtype=tf.int32,name="element"))
train_size = 8
# WRONG WAY
full_dataset = full_dataset.shuffle(10)
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
A = set(train_dataset.as_numpy_iterator())
B = set(test_dataset.as_numpy_iterator())
# EXPECT OVERLAP
assert A.intersection(B)==set()
print(list(A))
print(list(B))
现在,有效的是在训练和测试数据集中重复和播种洗牌,这也有利于再现性。这应该适用于任何确定性排序迭代器:
import tensorflow as tf
def gen_data():
return iter(range(10))
ds = tf.data.Dataset.from_generator(
gen_data,
output_signature=tf.TensorSpec(shape=(),dtype=tf.int32,name="element"))
SEED = 42 # NOTE: change this
ds_train = ds.shuffle(100,seed=SEED).take(8).shuffle(100)
ds_test = ds.shuffle(100,seed=SEED).skip(8)
A = set(ds_train.as_numpy_iterator())
B = set(ds_test.as_numpy_iterator())
assert A.intersection(B)==set()
print(list(A))
print(list(B))
通过使用 SEED
,您可以检查/估计泛化(引导代替交叉验证)。