除了在Keras序列到序列模型

问题描述 投票:1回答:1

我想建立使用LSTM和密集的神经网络Keras一个序列序列模型。所述编码器编码的输入,将编码的状态,然后输入被连接在一起,并送入解码器,其是一种LSTM +密集的神经网络在时间上输出分类标签。下面是我的代码看起来像

from keras.utils import to_categorical
from keras.layers import Embedding, Bidirectional, GRU, Dense, TimeDistributed, LSTM, Input, Lambda
from keras.models import  Sequential, Model
import numpy as np
from keras import preprocessing
import keras

encoder_inputs_seq = Input(shape=(114,))
encoder_inputs = Embedding(input_dim= 1000 + 1, output_dim = 20)(encoder_inputs_seq)

x, state_h, state_c = LSTM(32, return_state=True)(encoder_inputs)
states = [state_h, state_c]

decoder_lstm = LSTM(32, return_sequences=True, return_state=True)
decoder_dense = Dense(9, activation='softmax')

all_outputs = []

input_state = keras.layers.RepeatVector(1)(state_h)


for i in range(5):
    # Run the decoder on one timestep
    new_input = keras.layers.concatenate([input_state, keras.layers.RepeatVector(1)(encoder_inputs[:, 1, :])], axis = -1)

    outputs, state_h, state_c = decoder_lstm(new_input,
                                             initial_state=states)
    outputs = decoder_dense(outputs)
    # Store the current prediction (we will concatenate all predictions later)
    all_outputs.append(outputs)
    # Reinject the outputs as inputs for the next loop iteration
    # as well as update the states
    states = [state_h, state_c]
    input_state = keras.layers.RepeatVector(1)(state_h)

decoder_outputs = Lambda(lambda x: keras.layers.concatenate(x, axis=1))(all_outputs)

model = Model(encoder_inputs_seq, decoder_outputs)

model.summary()

我碰到以下情况例外

AttributeError: 'NoneType' object has no attribute '_inbound_nodes'

我要去哪里错了吗?

keras keras-layer keras-2 sequence-to-sequence
1个回答
0
投票

问题是,你没有一个lambda层包裹它切片张量(encoder_inputs[:, 1, :])。您在Keras模型做的每个操作都将在一个层。您可以通过更换for循环具有以下内您的第一行代码解决这个问题:

slice = Lambda(lambda x: x[:, 1, :])(encoder_inputs)
new_input = keras.layers.concatenate(
    [input_state, keras.layers.RepeatVector(1)(slice)], 
    axis = -1)
© www.soinside.com 2019 - 2024. All rights reserved.