tf.gradient表现得像tfp.math.diag_jacobian

问题描述 投票:1回答:2

我尝试使用输入数据中的损失函数的梯度来计算输入数据的噪声:

my_grad = tf.gradients(丢失,输入)

loss是一个大小的数组(n x 1),其中n是数据集的数量,m是数据集的大小,input是(n x m)的数组,其中m是单个数据集的大小。

我需要my_grad的大小(n x m) - 所以对于每个数据集,梯度都是计算的。但根据定义,i!= j的渐变为零 - 但是tf.gradients分配了大量的内存并且运行得非常好......

一个只计算i = j的梯度的版本会很棒 - 任何想法如何到达那里?

python tensorflow diagonal gradients
2个回答
0
投票

我想我找到了一个解决方案:

my_grad = tf.gradients(tf.reduce_sum(loss),输入)

确保忽略交叉依赖关系i!= j - 这非常有效且快速。


0
投票

这是一种可能的方法:

import tensorflow as tf

x = tf.placeholder(tf.float32, [20, 50])
# Break X into its parts
x_parts = tf.unstack(x)
# Recompose
x = tf.stack(x_parts)
# Compute Y however
y = tf.reduce_sum(x, axis=1)
# Break Y into parts
y_parts = tf.unstack(y)
# Compute gradient part-wise
g_parts = [tf.gradients(y_part, x_part)[0] for x_part, y_part in zip(x_parts, y_parts)]
# Recompose gradient
g = tf.stack(g_parts)
print(g)
# Tensor("stack_1:0", shape=(20, 50), dtype=float32)

但是这至少有两个问题:

  • 它要求你使用固定大小的n(虽然不是m)。
  • 它将在图中创建O(n)节点,如果您打算使用非常大的n,这可能是一个问题。

理论上,可以使用TensorFlow while循环,但是张量数组或循环中的某些东西不会根据需要传播渐变。

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, None])
n = tf.shape(x)[0]
element_shape = x.shape[1:]
x_parts = tf.TensorArray(x.dtype, size=n, dynamic_size=False,
                         element_shape=element_shape, clear_after_read=False)
_, x_parts, _ = tf.while_loop(lambda i, x_parts, x: i < n,
                              lambda i, x_parts, x: (i + 1, x_parts.write(i, x[i]), x),
                              [tf.constant(0, n.dtype), x_parts, x])
x = x_parts.stack()
y = tf.reduce_sum(x, axis=1)
g_parts = tf.TensorArray(y.dtype, size=n, dynamic_size=False,
                         element_shape=element_shape, clear_after_read=True)
_, g_parts, _ = tf.while_loop(lambda i, g_parts, x_parts, y: i < n,
                              lambda i, g_parts, x_parts, y:
                                (i + 1, g_parts.write(i, tf.gradients(y[i], x_parts.read(i))[0]), x_parts, y),
                              [tf.constant(0, n.dtype), g_parts, x_parts, y])
# Fails due to None gradients
g = g_parts.stack()
print(g)
© www.soinside.com 2019 - 2024. All rights reserved.