我目前正在构建一个多模态情感识别模型,我试图在下面使用自定义类添加一个注意力机制:
class Attention(tf.keras.layers.Layer):
def __init__(self, **kwargs):
super(Attention, self).__init__(**kwargs)
def build(self, input_shape):
self.We = self.add_weight(shape=(input_shape[-1], 1), initializer='random_normal', trainable=True)
self.b = self.add_weight(shape=(input_shape[1],1), initializer='zeros', trainable=True)
super(Attention, self).build(input_shape)
def call(self, x):
q = tf.nn.tanh(tf.linalg.matmul(x, self.We) + self.b)
a = tf.nn.softmax(q, axis=1)
return tf.reduce_sum(a * x, axis=1)
这个类在 lstm 模型中使用:
self.features_audio_dim = self.train_x_audio.shape[2] #1611
audio_input = Input(shape=(self.sequence_length, self.features_audio_dim), dtype='float32')
lstm_audio = LSTM(128, return_sequences=True,dropout=0.3,recurrent_dropout=0.2)(audio_input)
attention_audio = Attention()(lstm_audio)
我试图修复错误,但无济于事,问题出在注意力层
ValueError:调用层“attention_8”(注意类型)时遇到异常。
注意层必须在输入列表上调用,即[查询,值]或[查询,值,键]。收到:Tensor("Placeholder:0", shape=(None, 33, 128), dtype=float32).