我正在尝试将生成器包装到tf.data.Dataset中(只是为了了解这一点)。这是我的片段。希望有人能发现我做错了什么。
img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
gen = img_gen.flow_from_directory(data_path, target_size=(224, 224), batch_size=32)
dataset = tf.data.Dataset.from_generator(
lambda: gen,
output_types = (tf.float32, tf.float32),
output_shapes = ([32, 224, 224, 3], [32, 6]),
)
model.fit(dataset,
steps_per_epoch = gen.n // 32,
epochs=10)
[ValueError:generator
产生了形状为(11,224,224,3)的元素,其中期望形状为(32,224,224,3)的元素。
如果我更改了此设置,我“出现”来解决此问题:
dataset = tf.data.Dataset.from_generator(
lambda: gen,
output_types = (tf.float32, tf.float32),
# output_shapes = ([32, 224, 224, 3], [32, 6]),
output_shapes = ([None, 224, 224, 3], [None, 6]),
)
即而不是在这段代码中显式的批处理大小为32,而是将它替换为None。它显然不再抱怨(11,224,224,3),剩余的批处理位。但是,是否有一种方法可以使数据集即使在最后也可以采样32?只需翻到火车的“开始”位置即可。我还是有些可疑。
我从.fit开始,显然它正在运行,损失减少且准确性在每个时期都有提高。
如果有人有更好的方法或解释,请告诉我。