Adding an attention mechanism in Keras


I am currently building a multimodal emotion recognition model, and I am trying to add an attention mechanism using the custom class below:

import tensorflow as tf

class Attention(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(Attention, self).__init__(**kwargs)

    def build(self, input_shape):
        # Scoring weights: a (features, 1) projection plus a per-timestep bias
        self.We = self.add_weight(shape=(input_shape[-1], 1), initializer='random_normal', trainable=True)
        self.b = self.add_weight(shape=(input_shape[1], 1), initializer='zeros', trainable=True)
        super(Attention, self).build(input_shape)

    def call(self, x):
        # x: (batch, timesteps, features) -> scores q: (batch, timesteps, 1)
        q = tf.nn.tanh(tf.linalg.matmul(x, self.We) + self.b)
        a = tf.nn.softmax(q, axis=1)  # attention weights over the timestep axis
        return tf.reduce_sum(a * x, axis=1)  # weighted sum: (batch, features)
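
In isolation the class behaves as intended; a quick sanity check on a dummy tensor (shapes are illustrative, matching the 33-step, 128-unit LSTM output below, and assuming the name Attention refers to the custom class above):

x = tf.random.normal((4, 33, 128))  # dummy batch: 4 samples, 33 timesteps, 128 features
out = Attention()(x)
print(out.shape)                    # (4, 128): one pooled vector per sample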

This class is used in an LSTM model:

self.features_audio_dim = self.train_x_audio.shape[2]  # 1611 audio features per timestep
audio_input = Input(shape=(self.sequence_length, self.features_audio_dim), dtype='float32')
lstm_audio = LSTM(128, return_sequences=True, dropout=0.3, recurrent_dropout=0.2)(audio_input)
attention_audio = Attention()(lstm_audio)

I have tried to fix the error, but to no avail; the problem is in the attention layer:

ValueError: Exception encountered when calling layer "attention_8" (type Attention).

Attention layer must be called on a list of inputs, namely [query, value] or [query, value, key]. Received: Tensor("Placeholder:0", shape=(None, 33, 128), dtype=float32).
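
For what it's worth, this wording matches the error raised by the built-in tf.keras.layers.Attention, which must be called on a list of inputs like [query, value]. A plausible cause (an assumption, since the imports are not shown above) is a name collision: an import such as from tensorflow.keras.layers import * would bind the name Attention to the built-in layer rather than to the custom class. A minimal sketch of a fix under that assumption is to give the custom layer a name that cannot clash; SoftAttention below is an illustrative name, not an API name:

import tensorflow as tf
from tensorflow.keras.layers import Input, LSTM

# Same layer as above, renamed so it cannot be shadowed by a wildcard
# import of tf.keras.layers.Attention ('SoftAttention' is illustrative).
class SoftAttention(tf.keras.layers.Layer):
    def build(self, input_shape):
        self.We = self.add_weight(shape=(input_shape[-1], 1),
                                  initializer='random_normal', trainable=True)
        self.b = self.add_weight(shape=(input_shape[1], 1),
                                 initializer='zeros', trainable=True)

    def call(self, x):
        q = tf.nn.tanh(tf.linalg.matmul(x, self.We) + self.b)
        a = tf.nn.softmax(q, axis=1)
        return tf.reduce_sum(a * x, axis=1)

audio_input = Input(shape=(33, 1611), dtype='float32')  # 33 timesteps, 1611 features
lstm_audio = LSTM(128, return_sequences=True)(audio_input)
attention_audio = SoftAttention()(lstm_audio)  # single-tensor call now resolves to the custom layer

Keeping the class in a separate module and calling it as, say, my_layers.Attention() (module name hypothetical) would avoid the shadowing without a rename.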

python keras lstm attention-model