ValueError:维度必须等于 ResNet-50 迁移学习 TF

问题描述 投票:0回答:1

我正在尝试在 keras 中微调 ResNet-50,以实现 wikiart 数据集上的艺术作品风格分类器。 我有一个具有以下形状的训练和测试数据集:

 <_MapDataset element_spec=(TensorSpec(shape=(32, 224, 224, 3), dtype=tf.int64, name=None), TensorSpec(shape=(32,), dtype=tf.int64, name=None))>)

这是模型:

tf_input = tf.keras.layers.Input(shape=(img_height, img_height, 3))
base_model = tf.keras.applications.resnet50.ResNet50(input_tensor=tf_input, include_top=False)
x = base_model.output
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(512, activation='relu')(x)
predictions = tf.keras.layers.Dense(n_classes, activation='softmax')(x)
model = tf.keras.Model(inputs=base_model.input, outputs=predictions)
base_learning_rate = 0.001
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(),
              metrics=[tf.keras.metrics.Accuracy(name='accuracy')])

当我尝试使用

model.fit(train, epochs=epochs, validation_data=test)
训练它时,出现以下错误:

ValueError: Dimensions must be equal, but are 32 and 27 for '{{node Equal}} = Equal[T=DT_FLOAT, incompatible_shape_error=true](Cast_1, functional_49_1/dense_43_1/Softmax)' with input shapes: [32], [32,27].

我真的迷失了,任何帮助将不胜感激。

tensorflow keras deep-learning transfer-learning
1个回答
0
投票

当您使用 SparseCategoricalCrossentropy 作为损失时,我假设您将标签作为一维缩放器数组。

但是您还应该更改准确性指标以进行此类预测,即您应该将代码更改为:

tf_input = tf.keras.layers.Input(shape=(img_height, img_height, 3))
base_model = tf.keras.applications.resnet50.ResNet50(input_tensor=tf_input, include_top=False)
x = base_model.output
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(512, activation='relu')(x)
predictions = tf.keras.layers.Dense(n_classes, activation='softmax')(x)
model = tf.keras.Model(inputs=base_model.input, outputs=predictions)
base_learning_rate = 0.001
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(),
              metrics=[tf.keras.metrics.(name='accuracy')])
© www.soinside.com 2019 - 2024. All rights reserved.