(train_dataset,validation_dataset,test_dataset) = tfds.load('fashion_mnist',
with_info=True, as_supervised=True,
split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'])
我正在尝试将fashion_mnist分为3组训练和验证。我不确定这里是什么错误,因为我根本无法解决它。
“ fashion_mnist”数据集在Tensorflow数据集中仅包含训练和测试拆分(请参见documentation,拆分部分),因此在split
参数中,它期望列表的长度最大为2,但是您正在使用长度为3的列表。为了进行训练,验证和测试拆分,可以执行以下操作:
whole_ds = tfds.load("fashion_mnist", with_info = True, split='train+test', as_supervised=True)
n = tf.data.experimental.cardinality(whole_ds) # 70 000
train_num = int(n*0.8)
val_num = int(n*0.1)
train_ds = whole_ds.take(train_num)
val_ds = whole_ds.skip(train_num).take(val_num)
test_ds = whole_ds.skip(train_num+val_num)