张量流中的条件打印节点

问题描述 投票:2回答:1

我正在寻找一种在张量流中有条件打印节点的方法,使用下面的代码示例行,其中每10个循环计数,它应该在控制台中打印一些内容。但这对我不起作用。有人可以建议吗?

谢谢,Hamidreza,

epsilon = tf.cond(tf.constant(counter % 10 == 0, dtype=tf.bool), true_fn=lambda:tf.Print(epsilon, [counter, epsilon], 'batch: ', summarize=10), false_fn=lambda:epsilon)
tensorflow printing conditional
1个回答
0
投票

我遇到了类似的问题,并通过以下难看的解决方案解决了它。

由于解密,我使用的是tf.print而不是'tf.Print':

由于true_func和false_func应该返回相同的类型和形状,所以我返回了无意义的const。

def true_func():
    printfunc = tf.print(TENSOR_TO_PRINT,summarize=-1)
    with tf.control_dependencies([printfunc]):
        return tf.constant(1) 

def false_func():
    return tf.constant(1)

shouldprint = tf.cond(YOUR_CONDITION ,true_fn = true_func ,false_fn=false_func)

然后您可以运行shouldprint操作或将其添加为其他操作的依赖项。

© www.soinside.com 2019 - 2024. All rights reserved.