如何在tensorflow中加载自己的数据集来制作DCGAN?

问题描述 投票:0回答:1

我试着按照教程去做 此处 来创建一个DCGAN,只改变了我的输入数据,我无法将我的数据加载到合适的形状,以便我的模型进行训练。这是我的代码。

import pathlib
data_dir = pathlib.Path(data_dir)

list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*'))

def decode_img(img):
  # convert the compressed string to a 3D uint8 tensor
  img = tf.image.decode_jpeg(img, channels=1)
  # Use `convert_image_dtype` to convert to floats in the [0,1] range.
  img = tf.image.convert_image_dtype(img, tf.float32)
  # resize the image to the desired size.
  return tf.image.resize(img, [IMG_WIDTH, IMG_HEIGHT])
def process_path(file_path):
  # load the raw data from the file as a string
  img = tf.io.read_file(file_path)
  img = decode_img(img)
  return img
labeled_ds = list_ds.map(process_path)

从那里,我几乎复制了教程, 修改参数,以满足我的数据集, 我得到这个最终结果,

train(labeled_ds, EPOCHS)

但当我这样做时,我得到了以下错误。

ValueError: Input 0 of layer sequential_1 is incompatible with the layer: expected ndim=4, found ndim=3. Full shape received: [224, 320, 1]

我怎样才能给tensorflow数据集增加一个额外的维度?或者说,相对于教程中MNIST数据集的加载方式,我是不是做错了什么?谢谢你

python tensorflow mnist gan
1个回答
0
投票

看起来你缺少了批量大小(第0维)。

最简单和最灵活的方法是使用 .batch() 在你 tf.data.Dataset:

batch_size = 1  # Or something else.
labeled_ds = list_ds.map(process_path).batch(batch_size)

这应该修改张量的形状,从 [height, width, channels][batch_size, height, width, channels] 这是模型所期望的。希望以上内容对你有用!

© www.soinside.com 2019 - 2024. All rights reserved.