Keras模型摘要不正确

问题描述 投票:0回答:1

我正在使用]进行数据扩充>

data_gen=image.ImageDataGenerator(rotation_range=20,width_shift_range=0.2,height_shift_range=0.2,
                                  zoom_range=0.15,horizontal_flip=False)

iter=data_gen.flow(X_train,Y_train,batch_size=64)

data_gen.flow()需要4级数据矩阵,因此X_train的形状为(60000, 28, 28, 1)。在定义模型的架构时,我们需要传递相同的形状,即(60000, 28, 28, 1),如下所示;

model=Sequential()
model.add(Dense(units=64,activation='relu',kernel_initializer='he_normal',input_shape=(28,28,1)))
model.add(Flatten())    
model.add(Dense(units=10,activation='relu',kernel_initializer='he_normal'))
model.summary()

model.add(Flatten())用于处理等级2问题。现在问题出在model.summary()上。它给出了错误的输出,如下所示;

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_1 (Dense)              (None, 28, 28, 64)        128       
_________________________________________________________________
flatten_1 (Flatten)          (None, 50176)             0         
_________________________________________________________________
dense_2 (Dense)              (None, 10)                501770    
=================================================================
Total params: 501,898
Trainable params: 501,898
Non-trainable params: 0

Output Shapedense_1 (Dense)应该是(None,64)Param #应该是(28*28*64)+64,即50240Output Shapedense_2 (Dense)是正确的,但Param #应为(64*10)+10,即650

为什么会这样,如何解决这个问题?

我正在使用data_gen = image.ImageDataGenerator(rotation_range = 20,width_shift_range = 0.2,height_shift_range = 0.2,zoom_range = 0.15,horizo​​ntal_flip = ...]进行数据增强]

python-3.x machine-learning keras deep-learning keras-layer
1个回答
0
投票

摘要不正确。 keras Dense层始终在输入的最后一个维度上起作用。

ref:https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dense

© www.soinside.com 2019 - 2024. All rights reserved.