keras访问预先训练模型的访问层参数以进行冻结

问题描述 投票:1回答:2

我保存了一个多层的LSTM。现在,我想加载它并调整最后一个LSTM层。如何定位此图层并更改其参数?

训练和保存的简单模型示例:

model = Sequential()
# first layer  #neurons 
model.add(LSTM(100, return_sequences=True, input_shape=(X.shape[1], 
X.shape[2])))
model.add(LSTM(50, return_sequences=True))
model.add(LSTM(25))
model.add(Dense(1))
model.compile(loss='mae', optimizer='adam')

我可以加载和重新训练它,但我找不到一种方法来定位特定的图层并冻结所有其他图层。

python machine-learning keras lstm keras-layer
2个回答
0
投票

如果您之前已经构建并保存了模型,并且现在想要加载它并仅微调最后一个LSTM图层,那么您需要将其他图层的trainable属性设置为False。首先,使用model.summary()方法找到图层的名称(或从顶部开始从零开始计算图层的索引)。例如,这是我的一个模型的输出:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_10 (InputLayer)        (None, 400, 16)           0         
_________________________________________________________________
conv1d_2 (Conv1D)            (None, 400, 32)           4128      
_________________________________________________________________
lstm_2 (LSTM)                (None, 32)                8320      
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 33        
=================================================================
Total params: 12,481
Trainable params: 12,481
Non-trainable params: 0
_________________________________________________________________

然后将除LSTM图层以外的所有图层的可训练参数设置为False

方法1:

for layer in model.layers:
    if layer.name != `lstm_2`
        layer.trainable = False

方法2:

for layer in model.layers:
    layer.trainable = False

model.layers[2].trainable = True  # set lstm to be trainable

# to make sure 2 is the index of the layer
print(model.layers[2].name)    # prints 'lstm_2'

不要忘记再次编译模型以应用这些更改。


1
投票

一个简单的解决方案是命名每一层,即

model.add(LSTM(50, return_sequences=True, name='2nd_lstm'))

然后,在加载模型时,您可以迭代图层并冻结与名称条件匹配的图层:

for layer in model.layers:
    if layer.name == '2nd_lstm':
        layer.trainable = False

然后,您需要重新编译模型才能使更改生效,之后您可以像往常一样恢复培训。

© www.soinside.com 2019 - 2024. All rights reserved.