如何在Tensorflow 2.0中使用gradient_override_map?

问题描述 投票:2回答:1

我正在尝试使用Qazxswpoi和Tensorflow 2.0。有一个gradient_override_map,我也将在这里作为例子。

在2.0中,example in the documentation可用于计算梯度,如下所示:

GradientTape

还有import tensorflow as tf print(tf.version.VERSION) # 2.0.0-alpha0 x = tf.Variable(5.0) with tf.GradientTape() as tape: s_1 = tf.square(x) print(tape.gradient(s_1, x)) 装饰器,可用于定义新函数的渐变(同样,使用tf.custom_gradient):

example from the docs

但是,我想替换标准函数的梯度,如import tensorflow as tf print(tf.version.VERSION) # 2.0.0-alpha @tf.custom_gradient def log1pexp(x): e = tf.exp(x) def grad(dy): return dy * (1 - 1 / (1 + e)) return tf.math.log(1 + e), grad x = tf.Variable(100.) with tf.GradientTape() as tape: y = log1pexp(x) print(tape.gradient(y, x)) 。我试着使用以下代码:

tf.square

但是,有两个问题:渐变替换似乎不起作用(它被评估为@tf.RegisterGradient("CustomSquare") def _custom_square_grad(op, grad): return tf.constant(0) with tf.Graph().as_default() as g: x = tf.Variable(5.0) with g.gradient_override_map({"Square": "CustomSquare"}): with tf.GradientTape() as tape: s_2 = tf.square(x, name="Square") with tf.compat.v1.Session() as sess: sess.run(tf.compat.v1.global_variables_initializer()) print(sess.run(tape.gradient(s_2, x))) 而不是10.0),我需要求助于0.0来执行图形。有没有办法在“原生”TensorFlow 2.0中实现这一目标?

在TensorFlow 1.12.0中,以下内容产生所需的输出:

session.run()
python tensorflow tensorflow2.0
1个回答
3
投票

TensorFlow 2.0中没有内置机制来覆盖范围内内置运算符的所有渐变。但是,如果您能够为每次调用内置运算符修改调用站点,则可以使用import tensorflow as tf print(tf.__version__) # 1.12.0 @tf.RegisterGradient("CustomSquare") def _custom_square_grad(op, grad): return tf.constant(0) x = tf.Variable(5.0) g = tf.get_default_graph() with g.gradient_override_map({"Square": "CustomSquare"}): s_2 = tf.square(x, name="Square") grad = tf.gradients(s_2, x) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) print(sess.run(grad)) 装饰器,如下所示:

tf.custom_gradient
© www.soinside.com 2019 - 2024. All rights reserved.