Background: my data has shape (batch_size, ghost_dim, data_length), and I am learning to train a MultiHeadAttention-based model. After fitting the model, I feed a batch of data with the same dimensions into model(), model.predict(), or model.predict_on_batch() to get back a prediction array, and each of these calls raises a similar shape-related error.
It throws:
2 frames
/usr/local/lib/python3.10/dist-packages/keras/engine/training.py in tf__predict_function(iterator)
13 try:
14 do_return = True
---> 15 retval_ = ag__.converted_call(ag__.ld(step_function), (ag__.ld(self), ag__.ld(iterator)), None, fscope)
16 except:
17 do_return = False
ValueError: in user code:
File "/usr/local/lib/python3.10/dist-packages/keras/engine/training.py", line 2169, in predict_function *
return step_function(self, iterator)
File "/usr/local/lib/python3.10/dist-packages/keras/engine/training.py", line 2155, in step_function **
outputs = model.distribute_strategy.run(run_step, args=(data,))
File "/usr/local/lib/python3.10/dist-packages/keras/engine/training.py", line 2143, in run_step **
outputs = model.predict_step(data)
File "/usr/local/lib/python3.10/dist-packages/keras/engine/training.py", line 2111, in predict_step
return self(x, training=False)
File "/usr/local/lib/python3.10/dist-packages/keras/utils/traceback_utils.py", line 70, in error_handler
raise e.with_traceback(filtered_tb) from None
**ValueError: Exception encountered when calling layer 'layer_normalization_1' (type LayerNormalization).
Cannot take the length of shape with unknown rank.
Call arguments received by layer 'layer_normalization_1' (type LayerNormalization):
• inputs=tf.Tensor(shape=<unknown>, dtype=float16)**
If I call model() directly instead, a similar error occurs for the input shape in the MultiHeadAttention layer: the input appears to be 1-dimensional rather than the original single batch of data I passed in, i.e.:
x_predict = np.random.normal(size=(1, 1, 512))
predictions = self.model(x_predict)
predictions = self.model.predict_on_batch(x_predict)
I'm not sure what the source of all these numpy array shape errors is. Any help/ideas appreciated!
The problem is that you specify the batch size in the input layer's shape:
decoder_inputs = layers.Input(batch_input_shape=shape, dtype=tensorflow.float16)
With this, your model will always assume that your training and test data come in the defined batch_size = 32, so an x_predict with batch_size = 1 raises the error. I suggest removing the batch size from the input layer:
decoder_inputs = layers.Input(shape=(ghost_dim, embed_dim), dtype=tensorflow.float16)
Alternatively, you could sample your x_predict data so that it always has shape (32, 1, 512).
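As a minimal sketch of the first fix: below, a batch-size-agnostic Input is built with only the per-sample shape, so prediction works for any batch size, including 1. The dimensions ghost_dim = 1 and embed_dim = 512 are assumed from the (1, 1, 512) example in the question, and a single Dense layer stands in for your actual attention stack.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

# Assumed dims, matching the question's x_predict of shape (1, 1, 512).
ghost_dim, embed_dim = 1, 512

# No batch dimension in shape: the model accepts any batch size.
decoder_inputs = layers.Input(shape=(ghost_dim, embed_dim), dtype=tf.float16)
outputs = layers.Dense(embed_dim)(decoder_inputs)  # stand-in for the attention layers
model = tf.keras.Model(decoder_inputs, outputs)

# A single-sample batch now goes through without a shape error.
x_predict = np.random.normal(size=(1, ghost_dim, embed_dim)).astype(np.float16)
predictions = model.predict_on_batch(x_predict)
print(predictions.shape)  # (1, 1, 512)
```

The same model also accepts a batch of 32 (or any other size), which is exactly what batch_input_shape=(32, ...) forbids for other batch sizes.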