Incompatible shapes error when fitting an LSTM encoder-decoder


I am using the following code to create an LSTM encoder-decoder for signal prediction.

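import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input, LSTM, Dense, Lambda
from tensorflow.keras.models import Model

def create_model_ED(numberOfLSTMunits, batch_size, n_timesteps_in, n_features):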

  encoder_inputs = Input(shape=(n_timesteps_in, n_features), name='encoder_inputs')
  encoder_lstm = LSTM(numberOfLSTMunits, return_state=True,  name='encoder_lstm')
  encoder_outputs, state_h, state_c = encoder_lstm(encoder_inputs)
  
  #context vector 
  states = [state_h, state_c]
  
  # Decoder receives 1 token, outputs 1 token at a time 
  
  decoder_lstm = LSTM(numberOfLSTMunits, return_sequences=True, return_state=True, name='decoder_lstm')
  decoder_dense = Dense(n_features, activation='tanh',  name='decoder_dense')

  all_outputs = []

  decoder_input_data = np.zeros((batch_size, 1, n_features))
  decoder_input_data[:, 0, 0] = 0.2
  
  inputs = decoder_input_data

  for _ in range(n_timesteps_in):

      # Run the decoder on one time step
      outputs, state_h, state_c = decoder_lstm(inputs, initial_state=states)
      outputs = decoder_dense(outputs)

      all_outputs.append(outputs)

      # Reinject the outputs as inputs for the next loop iteration
      # as well as update the states

      inputs = outputs  
      states = [state_h, state_c]

  # Concatenate all predictions such as [batch_size, timesteps, features]
  decoder_outputs = Lambda(lambda x: K.concatenate(x, axis=1))(all_outputs)

  # Define and compile model 
  model = Model(encoder_inputs, decoder_outputs, name='model_encoder_decoder')
  model.compile(optimizer='adam', loss='mean_squared_error', metrics=[tf.keras.metrics.RootMeanSquaredError()])

  return model

However, when I build and fit the model, I get the error shown below.

batch_size = 6
model_encoder_decoder = create_model_ED(100, batch_size, 100, 1)

model_encoder_decoder.fit(X_train, y_train, batch_size=batch_size, epochs=2, validation_split=0.2)

Error:

............

lstm.py", line 967, in step
      z += backend.dot(h_tm1, recurrent_kernel)
Node: 'while/add'
Incompatible shapes: [6,400] vs. [4,400]
     [[{{node while/add}}]]
     [[model_encoder_decoder/decoder_lstm/PartitionedCall]] [Op:__inference_train_function_746223]
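
A possible reading of this message, assuming the failing batch holds 4 samples (for example a final partial batch): `decoder_input_data` is a NumPy constant whose leading dimension is fixed at `batch_size = 6`, while the encoder states take their batch dimension from the data actually fed in. Inside the LSTM step the input projection would then have 6 rows and the recurrent projection 4. The NumPy sketch below reproduces only that shape arithmetic. The weights are random stand-ins rather than the model's, and `lstm_step_preactivation` is a hypothetical helper, not a Keras function.

```python
import numpy as np

units, n_features, batch_size = 100, 1, 6

# Random stand-ins for the LSTM weights; the 4 * units columns hold the four gates.
rng = np.random.default_rng(0)
kernel = rng.normal(size=(n_features, 4 * units))        # (1, 400)
recurrent_kernel = rng.normal(size=(units, 4 * units))   # (100, 400)

def lstm_step_preactivation(inputs, h_tm1):
    """Hypothetical helper: the gate pre-activation sum of one LSTM step."""
    z = inputs @ kernel                # rows follow the decoder input's batch
    z = z + h_tm1 @ recurrent_kernel   # rows follow the state's batch
    return z

# Constant decoder seed, built once with a fixed batch size as in the question.
seed = np.zeros((batch_size, n_features))
seed[:, 0] = 0.2

# A full batch of 6: both terms have 6 rows, so the sum works.
full = lstm_step_preactivation(seed, np.zeros((6, units)))
print(full.shape)

# A batch of 4: the state has 4 rows but the seed still has 6.
try:
    lstm_step_preactivation(seed, np.zeros((4, units)))
except ValueError as err:
    print("shape mismatch:", err)
```

If this is what happens, the number of samples in the training or validation part is probably not a multiple of 6.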

Can someone help me figure out what mistake I am making here?

I could not find any existing solution to try.
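
One workaround to try, assuming the mismatch comes from batches smaller than `batch_size`: split the data explicitly, trim both parts to a multiple of `batch_size`, and pass the validation part through `validation_data` instead of `validation_split`. This is a minimal sketch, not a definitive fix. `split_and_trim` is a hypothetical helper, and the array shapes below are made up for illustration.

```python
import numpy as np

def split_and_trim(X, y, batch_size, validation_fraction=0.2):
    """Hold out the last samples for validation and drop partial batches."""
    n_val = int(len(X) * validation_fraction)
    n_train = len(X) - n_val
    X_tr, y_tr = X[:n_train], y[:n_train]
    X_val, y_val = X[n_train:], y[n_train:]

    # Keep only as many samples as fill complete batches.
    keep_tr = (len(X_tr) // batch_size) * batch_size
    keep_val = (len(X_val) // batch_size) * batch_size
    return (X_tr[:keep_tr], y_tr[:keep_tr]), (X_val[:keep_val], y_val[:keep_val])

# Made-up data with the shape (samples, timesteps, features) used in the question.
X = np.zeros((100, 100, 1), dtype="float32")
y = np.zeros((100, 100, 1), dtype="float32")

(X_tr, y_tr), (X_val, y_val) = split_and_trim(X, y, batch_size=6)
print(len(X_tr) % 6, len(X_val) % 6)
```

The trimmed arrays would then go into `model.fit(X_tr, y_tr, batch_size=batch_size, epochs=2, validation_data=(X_val, y_val))`. Prediction would be subject to the same constraint, so a more general fix is probably to build the decoder's first input inside the graph from `encoder_inputs`, so that its batch dimension follows the data.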

lstm recurrent-neural-network encoder decoder activation