How do I set training=False for a Keras model/layer outside of the __call__ method?


I am using Keras with TensorFlow Model Optimization (tf_mot) for quantization-aware training (QAT). My model is built on a pretrained backbone from keras.applications. As recommended in the transfer-learning guide, I have to call the backbone with

x = base_model(inputs, training=False)

However, tf_mot does not work with nested sub-models. The solution given in https://stackoverflow.com/a/72265777/23370406 does not go through the __call__ method, so I have no way to set the training mode to False. What can I do?

Sub-model version of the code (not compatible with tf_mot):

import keras
from keras import applications, layers, models, utils


inp = layers.Input((None, None, 3))
backbone = applications.vgg16.VGG16(include_top=False,
                                    weights=None)
x = backbone(inp, training=False)  # __call__ with training=False: the backbone runs in inference mode

backbone.trainable = False         # freeze the backbone weights

x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(10, activation='relu')(x)
out = layers.Dense(1, activation='sigmoid')(x)

model = models.Model(inp, out)

model.summary()
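(To see why tf_mot trips over this version, note that the whole backbone shows up as a single nested Functional layer in the outer model; a quick check, not part of the original question, makes that visible:)

# the backbone is a single nested layer of class Functional inside model.layers
print([layer.__class__.__name__ for layer in model.layers])
# roughly: ['InputLayer', 'Functional', 'GlobalAveragePooling2D', 'Dense', 'Dense']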

QAT version of the code (not compatible with disabling the training mode):

import keras
from keras import applications, layers, models, utils


inp = layers.Input((None, None, 3))

# building the backbone directly on the input tensor inlines its layers,
# so the outer model contains no nested sub-model
backbone = applications.vgg16.VGG16(include_top=False,
                                    input_tensor=inp,
                                    weights=None)
x = backbone.output  # no __call__ involved, so training=False cannot be passed

backbone.trainable = False  # freeze the backbone weights

x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(10, activation='relu')(x)
out = layers.Dense(1, activation='sigmoid')(x)

model = models.Model(inp, out)

model.summary()
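(For context, the QAT step this flattened model is meant to feed into would presumably look something like the sketch below; quantize_model is tf_mot's standard entry point, while the optimizer and loss settings are illustrative assumptions, not taken from the question.)

import tensorflow_model_optimization as tfmot

# rewrite the (non-nested) functional model with fake-quantization nodes for QAT;
# a nested sub-model like the first version would not be accepted here
qat_model = tfmot.quantization.keras.quantize_model(model)

qat_model.compile(optimizer='adam',            # illustrative settings
                  loss='binary_crossentropy',
                  metrics=['accuracy'])
qat_model.summary()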

I have read through the Keras source code but could not find a solution that is compatible with keras>=3.0.0. Unfortunately, keras.backend.set_learning_phase was deprecated a few versions ago.

Thanks in advance!

tensorflow keras transfer-learning quantization tfmot
1 Answer

You can use

x = backbone.call(inp, training=False)

instead of

x = backbone(inp, training=False)

which puts the individual layers directly into the model instead of nesting a sub-model. For your example (slightly shortened here), model.summary() changes from

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_2 (InputLayer)        [(None, None, None, 3)]   0         
                                                                 
 vgg16 (Functional)          (None, None, None, 512)   14714688  
                                                                 
 global_average_pooling2d (  (None, 512)               0         
 GlobalAveragePooling2D)                                         
                                                                 
 dense (Dense)               (None, 2)                 1026      
                                                                 
=================================================================
Total params: 14715714 (56.14 MB)
Trainable params: 14715714 (56.14 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________                                      

Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_3 (InputLayer)        [(None, None, None, 3)]   0         
                                                                 
 block1_conv1 (Conv2D)       (None, None, None, 64)    1792      
                                                                 
 block1_conv2 (Conv2D)       (None, None, None, 64)    36928     
                                                                 
 block1_pool (MaxPooling2D)  (None, None, None, 64)    0         
                                                                 
 block2_conv1 (Conv2D)       (None, None, None, 128)   73856     
                                                                 
 block2_conv2 (Conv2D)       (None, None, None, 128)   147584    
                                                                 
 block2_pool (MaxPooling2D)  (None, None, None, 128)   0         
                                                                 
 block3_conv1 (Conv2D)       (None, None, None, 256)   295168    
                                                                 
 block3_conv2 (Conv2D)       (None, None, None, 256)   590080    
                                                                 
 block3_conv3 (Conv2D)       (None, None, None, 256)   590080    
                                                                 
 block3_pool (MaxPooling2D)  (None, None, None, 256)   0         
                                                                 
 block4_conv1 (Conv2D)       (None, None, None, 512)   1180160   
                                                                 
 block4_conv2 (Conv2D)       (None, None, None, 512)   2359808   
                                                                 
 block4_conv3 (Conv2D)       (None, None, None, 512)   2359808   
                                                                 
 block4_pool (MaxPooling2D)  (None, None, None, 512)   0         
                                                                 
 block5_conv1 (Conv2D)       (None, None, None, 512)   2359808   
                                                                 
 block5_conv2 (Conv2D)       (None, None, None, 512)   2359808   
                                                                 
 block5_conv3 (Conv2D)       (None, None, None, 512)   2359808   
                                                                 
 block5_pool (MaxPooling2D)  (None, None, None, 512)   0         
                                                                 
 global_average_pooling2d_1  (None, 512)               0         
  (GlobalAveragePooling2D)                                       
                                                                 
 dense_1 (Dense)             (None, 2)                 1026      
                                                                 
=================================================================
Total params: 14715714 (56.14 MB)
Trainable params: 14715714 (56.14 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
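(Applied to the code from the question, that suggestion might look roughly like the sketch below; whether .call() composes this way depends on your Keras version, see the edit that follows.)

inp = layers.Input((None, None, 3))
backbone = applications.vgg16.VGG16(include_top=False, weights=None)

# .call() (instead of __call__) builds the backbone's layers directly into
# the outer functional graph, so model.layers contains no nested sub-model
x = backbone.call(inp, training=False)

backbone.trainable = False

x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(10, activation='relu')(x)
out = layers.Dense(1, activation='sigmoid')(x)

model = models.Model(inp, out)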


Edit: here is code that is compatible with Keras>=3.x (the version that ships with TF 2.16):

import keras
from keras import applications, layers, models, utils

# load resnet here for testing, because resnet has BatchNormalization layers
backbone = applications.resnet.ResNet50(include_top=False,
                                    weights=None, 
                                    input_shape=(None, None, 3))
backbone.trainable = False

x = layers.GlobalAveragePooling2D()(backbone.output)
x = layers.Dense(10, activation='relu')(x)
out = layers.Dense(1, activation='sigmoid')(x)

model = models.Model(backbone.input, out)

# unfreeze all layers except the BatchNormalization layers;
# a BatchNormalization layer with trainable=False also runs in inference mode
for layer in model.layers:
  if not isinstance(layer, keras.layers.BatchNormalization):
    layer.trainable = True

I think this is even better than

x = backbone(inp, training=False)

because a blanket training=False also puts the Dropout layers, and every other layer that behaves differently in training and inference, into inference mode. If you do not want that, you can add e.g. Dropout to the isinstance test in the for loop.
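(As a quick sanity check, not part of the original answer, you can list which layers ended up frozen; only the BatchNormalization layers should appear, and their weights show up in the non-trainable parameter count.)

# list the layers that remain frozen after the loop above
frozen = [layer.name for layer in model.layers if not layer.trainable]
print(len(frozen), "frozen layers:", frozen[:5], "...")

model.summary()  # non-trainable params should correspond to the BatchNormalization weights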
