递归分配给Tensorflow中的可变切片

问题描述 投票:0回答:1

我想将值递归地分配给Tensorflow(1.15)变量中的切片。

为了说明,这有效:

    def test_loss():

        m = tf.Variable(1)
        n = 3

        A = tf.Variable(tf.zeros([10., 20., 30.]))
        B = tf.Variable(tf.ones([10., 20., 30.]))
        A = A[m+1:n+1, 10:12, 20:22].assign(B[m:n, 2:4, 3:5])

        return 1

    test_loss()
    Out: 1

然后我尝试:

    def test_loss():

        m = tf.Variable(1)
        #n = 3

        A = tf.Variable(tf.zeros([10., 20., 30.]))
        B = tf.Variable(tf.ones([10., 20., 30.]))

        for n in range(5):
            A = A[m+1:n+1, 10:12, 20:22].assign(B[m:n, 2:4, 3:5])

        return 1

    test_loss()

但是这会返回错误消息:

    ---> 10         A = A[m+1:n+1, 10:12, 20:22].assign(B[m:n, 2:4, 3:5])
    ...
    ValueError: Sliced assignment is only supported for variables

我知道'assign'返回的不是'Variable',因此在下一个循环中传递'A'将不再找到“变量”。

然后我尝试:

    def test_loss():

        m = tf.Variable(1)
        #n = 3

        A = tf.Variable(tf.zeros([10., 20., 30.]))
        B = tf.Variable(tf.ones([10., 20., 30.]))

        for n in range(5):
            A = tf.Variable(A[m+1:n+1, 10:12, 20:22].assign(B[m:n, 2:4, 3:5]))

        return 1

    test_loss()

然后我得到:

    InvalidArgumentError: Input 'ref' passed float expected ref type while building NodeDef...

关于我可以将值递归分配给Tensorflow变量片的任何想法吗?

tensorflow variables slice assign
1个回答
0
投票

这里有一些使用tf.Variableassign()的见解。

第一个失败的解决方案

for n in range(5):
            A = A[m+1:n+1, 10:12, 20:22].assign(B[m:n, 2:4, 3:5])

[当您执行A.assign(B)时,实际上返回一个张量(即不是tf.Variable)。因此,它适用于第一次迭代。从下一次迭代开始,您尝试将值分配给tf.Tensor,这是不允许的。

失败的第二个解决方案

for n in range(5):
            A = tf.Variable(A[m+1:n+1, 10:12, 20:22].assign(B[m:n, 2:4, 3:5]))

这再次是一个非常糟糕的主意,因为您正在循环中创建变量。这样做足够,您将耗尽内存。但这甚至无法运行,因为您最终陷入了时髦的僵局。您正在尝试创建具有某个张量的变量,该张量将在执行图时进行计算。要执行图形,您需要变量。

正确的方法

我可以想到的最好方法是让test_loss返回更新操作,然后将n设为TensorFlow占位符。在运行会话的每次迭代中,您都将一个值传递给n(这是当前迭代)。

def test_loss(n):

        m = tf.Variable(1)
        #n = 3

        A = tf.Variable(tf.zeros([10., 20., 30.]))
        B = tf.Variable(tf.ones([10., 20., 30.]))

        update = A[m+1:n+1, 10:12, 20:22].assign(B[m:n, 2:4, 3:5])        

        return update

with tf.Session() as sess:

    tf_n = tf.placeholder(shape=None, dtype=tf.int32, name='n')
    update_op = test_loss(tf_n)
    print(type(update_op))
    tf.global_variables_initializer().run()
    for n in range(5):
      print(1)
      #print(sess.run(update_op, feed_dict={tf_n: n}))

© www.soinside.com 2019 - 2024. All rights reserved.