我正在尝试使用卷积网络执行非线性拟合。我想要执行拟合的曲线如下所示:
因此,我将数据组织为一个 Numpy 数组,其中每一行都是一条曲线。每行有 512 个与曲线相对应的数据点(列),加上包含目标的另外两列(513 和 514)。目标是我希望 CNN 生成的两个拟合参数。我的训练数据集中有 19889 行(曲线)。
我的模型是使用 keras 构建的,它看起来像这样:
model = Sequential()
model.add(Conv1D(filters=64, kernel_size=3,
activation='relu',data_format="channels_last", input_shape=(512,1)))
model.add(Conv1D(filters=64, kernel_size=3, activation='relu'))
model.add(Dropout(0.5))
model.add(MaxPooling1D(pool_size=2))
model.add(Flatten())
model.add(Dense(100, activation='relu'))
model.add(Dense(2, activation='linear'))
编译和拟合线看起来像这样
# We compile the keras model
model.compile(loss='mean_squared_error', optimizer= Adam(learning_rate=0.0001), metrics=['mean_squared_error', 'mean_absolute_error','r2_score','root_mean_squared_error'])
# We define the callbacks.
early_stop = EarlyStopping(monitor='val_loss', patience = 25, verbose=2)
log_csv = CSVLogger('Model_2_IRF_extraction_lr00005.csv', separator= ',', append=False)
callbacks_list = [early_stop, log_csv]
# We fit the keras model on the dataset.
history = model.fit(X, y, validation_split=0.1, epochs=500, batch_size=10, verbose=2, callbacks=callbacks_list)
拟合运行,但验证集的 Rsquare 值和验证损失看起来非常奇怪。请看下面。
我的 CNN 有什么问题吗?我怀疑它与第一层中的“input_shape”参数有关,但我不知道该怎么做。
问题似乎确实可能与您如何处理输入形状以及如何为 CNN 准备数据有关。以下是一些需要检查和调整的事项:
X = df_array[:, :512] # Select only the first 512 columns which are the features
X = X.reshape(X.shape[0], 512, 1) # Reshape for Conv1D: (number of samples, timesteps, features)
y = df_array[:, 512:514] # Select the last two columns as targets
数组 3. 数据形状验证:重塑后,您可以打印 X 和 y 的形状以确认它们是否正确:
print(X.shape) # Should be (19889, 512, 1)
print(y.shape) # Should be (19889, 2)
import tensorflow as tf
def r2_score(y_true, y_pred):
SS_res = tf.reduce_sum(tf.square(y_true - y_pred))
SS_tot = tf.reduce_sum(tf.square(y_true - tf.reduce_mean(y_true)))
return (1 - SS_res/(SS_tot + tf.keras.backend.epsilon()))
model.compile(loss='mean_squared_error', optimizer=Adam(learning_rate=0.0001),
metrics=['mean_squared_error', 'mean_absolute_error', r2_score, 'root_mean_squared_error'])