在TF 2.0中使用tf.keras,如何定义依赖于学习阶段的自定义层?

问题描述 投票:0回答:2

我想使用tf.keras构建自定义图层。为简单起见,假设它应该在训练期间返回输入* 2并在测试期间输入* 3。这样做的正确方法是什么?

我试过这种方法:

class CustomLayer(Layer):
    @tf.function
    def call(self, inputs, training=None):
        if training:
            return inputs*2
        else:
            return inputs*3

然后我可以像这样使用这个类:

>>> layer = CustomLayer()
>>> layer(10)
tf.Tensor(30, shape=(), dtype=int32)
>>> layer(10, training=True)
tf.Tensor(20, shape=(), dtype=int32)

它工作正常!但是,当我在模型中使用这个类,并且我调用它的fit()方法时,似乎training没有设置为True。我尝试在call()方法的开头添加以下代码,但training始终等于0。

if training is None:
    training = K.learning_phase()

我错过了什么?

编辑

我找到了一个解决方案(请参阅我的回答),但我仍在寻找使用@tf.function的更好的解决方案(我更喜欢亲笔签名到这个smart_cond()业务)。不幸的是,看起来K.learning_phase()@tf.function不相称(我的猜测是当call()函数被跟踪时,学习阶段被硬编码到图中:因为这发生在调用fit()方法之前,学习阶段是总是0)。这可能是一个错误,或者在使用@tf.function时可能还有另一种方法可以进入学习阶段。

tensorflow keras tf.keras
2个回答
1
投票

FrançoisChollet确认使用@tf.function时的正确解决方案是:

class CustomLayer(Layer):
    @tf.function
    def call(self, inputs, training=None):
        if training is None:
            training = K.learning_phase()
        if training:
            return inputs * 2
        else:
            return inputs * 3

目前有一个错误(截至2019年2月15日),使training总是等于0,但这很快就会修复。


0
投票

以下代码不使用@tf.function,因此它看起来不太好(因为它不使用签名),但它工作正常:

from tensorflow.python.keras.utils.tf_utils import smart_cond

class CustomLayer(Layer):
    def call(self, inputs, training=None):
        if training is None:
            training = K.learning_phase()
        return smart_cond(training, lambda: inputs * 2, lambda: inputs * 3)
© www.soinside.com 2019 - 2024. All rights reserved.