使用 image_dataset_from_directory 时是否可以将张量流数据集拆分为训练、验证和测试数据集?

问题描述 投票:0回答:3

我正在使用

tf.keras.utils.image_dataset_from_directory
加载包含 4575 张图像的数据集。虽然此函数允许将数据拆分为两个子集(使用
validation_split
参数),但我想将其拆分为训练、测试和验证子集。

我尝试使用

dataset.skip()
dataset.take()
进一步分割结果子集之一,但这些函数分别返回
SkipDataset
TakeDataset
(顺便说一句,与 文档 相反,它是声称这些函数返回一个
Dataset
)。这会导致拟合模型时出现问题 - 在验证集上计算的指标(val_loss、val_accuracy)从模型历史记录中消失。

所以,我的问题是:有没有办法将

Dataset
分成三个子集进行训练、验证和测试,以便所有三个子集也是
Dataset
对象?

用于加载数据的代码

def load_data_tf(data_path: str, img_shape=(256,256), batch_size: int=8):
    train_ds = tf.keras.utils.image_dataset_from_directory(
        data_path,
        validation_split=0.2,
        subset="training",
        label_mode='categorical',
        seed=123,
        image_size=img_shape,
        batch_size=batch_size)
    val_ds = tf.keras.utils.image_dataset_from_directory(
        data_path,
        validation_split=0.3,
        subset="validation",
        label_mode='categorical',
        seed=123,
        image_size=img_shape,
        batch_size=batch_size)
    return train_ds, val_ds

train_dataset, test_val_ds = load_data_tf('data_folder', img_shape = (256,256), batch_size=8)
test_dataset = test_val_ds.take(686)
val_dataset = test_val_ds.skip(686)

模型编译与拟合

model.compile(optimizer='sgd',
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
              metrics=['accuracy'])
history = model.fit(train_dataset, epochs=50, validation_data=val_dataset, verbose=1)

使用普通

Dataset
时,
val_accuracy
val_loss
存在于模型历史记录中:

但是当使用

SkipDataset
时,它们不是:

python tensorflow keras tensorflow-datasets tf.keras
3个回答
4
投票

问题是,当您执行

test_val_ds.take(686)
test_val_ds.skip(686)
时,您并不是在采集和跳过样本,而是实际上是在进行批次。尝试运行
print(val_dataset.cardinality())
,您将看到您确实保留了多少批次用于验证。我猜测
val_dataset
是空的,因为您没有 686 个批次进行验证。这是一个工作示例:

import tensorflow as tf
import pathlib

dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)

batch_size = 32

train_ds = tf.keras.utils.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(180, 180),
  batch_size=batch_size)

val_ds = tf.keras.utils.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="validation",
  seed=123,
  image_size=(180, 180),
  batch_size=batch_size)

test_dataset = val_ds.take(5)
val_ds = val_ds.skip(5)

print('Batches for testing -->', test_dataset.cardinality())
print('Batches for validating -->', val_ds.cardinality())

model = tf.keras.Sequential([
  tf.keras.layers.Rescaling(1./255, input_shape=(180, 180, 3)),
  tf.keras.layers.Conv2D(16, 3, padding='same', activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Conv2D(32, 3, padding='same', activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(5)
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

epochs=1
history = model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=1
)
Found 3670 files belonging to 5 classes.
Using 2936 files for training.
Found 3670 files belonging to 5 classes.
Using 734 files for validation.
Batches for testing --> tf.Tensor(5, shape=(), dtype=int64)
Batches for validating --> tf.Tensor(18, shape=(), dtype=int64)
92/92 [==============================] - 96s 1s/step - loss: 1.3516 - accuracy: 0.4489 - val_loss: 1.1332 - val_accuracy: 0.5645

在此示例中,

batch_size
为32,您可以清楚地看到验证集保留了23个批次。之后,将 5 批分配给测试集,剩下 18 批作为验证集。


1
投票

我无法发表评论,所以必须回答JeffreyShran,关于我们如何确定

take
skip
在该街区拍摄相同的照片。这是检查代码:

dataset = tf.data.Dataset.range(10)
take = int(len(dataset)/2)

test = dataset.take(take)
print('test:', list(test.as_numpy_iterator()))
dataset = dataset.skip(take)
print('valid:', list(dataset.as_numpy_iterator()))

我们得到:

test: [0, 1, 2, 3, 4]
valid: [5, 6, 7, 8, 9]

我是个新人,所以如果我写的不合适的地方,我很抱歉。但我想上面的考虑一定已经被证明了。


0
投票

用这个程序训练之后,我们如何制作混淆矩阵?

谢谢

© www.soinside.com 2019 - 2024. All rights reserved.