如何减少数据的维度,通过ImageDataGenerator的flow_from_directory函数加载?

问题描述 投票:0回答:1

因为我从结构化文件夹中加载数据(图片),所以我使用了图像数据生成器的 flow_from_directory 的作用 ImageDataGenerator 类,它是由 Keras. 当我把这些数据输入到一个系统时,我没有任何问题。CNN 模型。但当涉及到一个 LSTM 模型,得到以下错误。ValueError: Error when checking input: expected lstm_1_input to have 3 dimensions, but got array with shape (64, 28, 28, 1). 如何在读取输入数据时减少数据的维度?ImageDataGenerator 对象,以便能够使用 LSTM 模型,而非 CNN?

p.s.输入图像的形状是? (28, 28) 而它们 grayscale.

train_valid_datagen = ImageDataGenerator(validation_split=0.2)

train_gen = train_valid_datagen.flow_from_directory(
    directory=TRAIN_IMAGES_PATH,
    target_size=(28, 28),
    color_mode='grayscale',
    batch_size=64,
    class_mode='categorical',
    shuffle=True,
    subset='training'
)

更新。 LSTM模型代码。

inp = Input(shape=(28, 28, 1))
inp = Lambda(lambda x: squeeze(x, axis=-1))(inp)  # from 4D to 3D
x = LSTM(num_units, dropout=dropout, recurrent_dropout=recurrent_dropout, activation=activation_fn, return_sequences=True)(inp)
x = BatchNormalization()(x)
x = Dense(128, activation=activation_fn)(x)
output = Dense(nb_classes, activation='softmax', kernel_regularizer=l2(0.001))(x)

model = Model(inputs=inp, outputs=output)
tensorflow keras deep-learning lstm
1个回答
1
投票

你开始用4D数据(比如你的图像)来喂养你的网络,以便与你的网络兼容。ImageDataGenerator 然后你必须将它们以3D格式重塑为LSTM。

这些都是可能性。

在只有一个通道的情况下,你可以简单地压缩最后一个维度。

inp = Input(shape=(28, 28, 1))
x = Lambda(lambda x: tf.squeeze(x, axis=-1))(inp) # from 4D to 3D
x = LSTM(32)(x)

如果你有多个通道(这是RGB图像的情况,或者如果你想在Conv2D之后应用RNN),一个解决方案可以是这样。

inp = Input(shape=(28, 28, 1))
x = Conv2D(32, 3, padding='same', activation='relu')(inp)
x = Reshape((28,28*32))(x)  # from 4D to 3D
x = LSTM(32)(x)

拟合度可以像以往一样用 model.fit_generator


更新:模型审查

inp = Input(shape=(28, 28, 1))
x = Lambda(lambda x: squeeze(x, axis=-1))(inp)  # from 4D to 3D
x = LSTM(32, dropout=dropout, recurrent_dropout=recurrent_dropout, activation=activation_fn, return_sequences=False)(x)
x = BatchNormalization()(x)
x = Dense(128, activation=activation_fn)(x)
output = Dense(nb_classes, activation='softmax', kernel_regularizer=l2(0.001))(x)

model = Model(inputs=inp, outputs=output)
model.summary()

在定义inp变量的时候要注意(不要覆盖)。

在LSTM中设置return_seq = False,以便有2D输出。

© www.soinside.com 2019 - 2024. All rights reserved.