我有一个顺序的keras模型,并且那里有一个类似于下面名为'CounterLayer'的示例的自定义图层。我正在使用tensorflow 2.0(渴望执行)
class CounterLayer(tf.keras.layers.Layer):
def __init__(self, stateful=False,**kwargs):
self.stateful = stateful
super(CounterLayer, self).__init__(**kwargs)
def build(self, input_shape):
self.count = tf.keras.backend.variable(0, name="count")
super(CounterLayer, self).build(input_shape)
def call(self, input):
updates = []
updates.append((self.count, self.count+1))
self.add_update(updates)
tf.print('-------------')
tf.print(self.count)
return input
例如,当我运行epoch = 5之类的东西时,self.count
的值不会随着每次运行而更新。它始终保持不变。我从https://stackoverflow.com/a/41710515/10645817获得了此示例。我需要几乎与此类似的东西,但我想知道这项工作是否在急于执行tensorflow或我需要做些什么才能获得预期的输出。
我一直在尝试实现这一点,但无法弄清楚。有人可以帮我吗。谢谢...
是,我的问题解决了。我遇到了一些内置方法来更新这种变量(这是为了在各个时期之间保持持久状态,如上述案例)。基本上我需要做的例如:
def build(self, input_shape):
self.count = tf.Variable(0, dtype=tf.float32, trainable=False)
super(CounterLayer, self).build(input_shape)
def call(self, input):
............
self.count.assign_add(1)
............
return input
一个可以用来计算call
函数中的更新值,也可以通过调用self.count.assign(some_updated_value)
来分配它。有关此类操作的详细信息,请参见https://www.tensorflow.org/api_docs/python/tf/Variable。谢谢。