sequential8 层的输入 0 与该层不兼容:预期输入形状的轴 -1 值为 1,但收到的输入形状为 (32, 1, 21)

问题描述 投票:0回答:1

我正在使用 Keras 构建 CNN,但它显示错误。代码如下:

input_shape = (21,)  # For your tabular data, add a channel dimension of 1

model = models.Sequential()

# Convolutional layers
model.add(layers.Conv1D(32, 3, activation='relu', input_shape=(21, 1)))
model.add(layers.MaxPooling1D(2))
model.add(layers.Conv1D(64, 3, activation='relu'))
model.add(layers.MaxPooling1D(2))
model.add(layers.Conv1D(64, 3, activation='relu'))

# Flatten the output and add dense layers
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))  # 10 output units for 10 classes


model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Train the model
model.fit(X_train_reshaped, Y_train, epochs=10, validation_data=(X_test, Y_test))

# Predict using the model
y_pred = model.predict(X_test_reshaped)
y_pred_classes = np.argmax(y_pred, axis=1)

输入有 500 万行,有 21 个特征。该代码应该对 10 个不同的类别进行多类别分类。代码有什么问题吗?

感谢您的帮助。

keras conv-neural-network sequential
1个回答
0
投票

您的

X_train_reshaped
的形状是
(32, 1, 21)
,但您已使用
input_shape=(21, 1)
定义了模型。因此,您需要将
X_train_reshaped
重塑为形状
(32, 21, 1)
以匹配模型的输入形状。

这里有一个简单的例子供您参考:

model = models.Sequential()

# Convolutional layers
model.add(layers.Conv1D(32, 3, activation='relu', input_shape=(21, 1)))
model.add(layers.MaxPooling1D(2))
model.add(layers.Conv1D(64, 3, activation='relu'))
model.add(layers.MaxPooling1D(2))
model.add(layers.Conv1D(64, 3, activation='relu'))

# Flatten the output and add dense layers
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))  # 10 output units for 10 classes


model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
X_train_reshaped = tf.random.normal([32,21,1])
Y_train = tf.random.normal([32,10])
X_test = tf.random.normal([32,21,1])
Y_test = tf.random.normal([32,10])
# Train the model
model.fit(X_train_reshaped, Y_train, epochs=10, validation_data=(X_test, Y_test))
© www.soinside.com 2019 - 2024. All rights reserved.