What is the difference between tf.GraphKeys.GLOBAL_VARIABLES and tf.GraphKeys.TRAINABLE_VARIABLES in TensorFlow?


From https://www.tensorflow.org/api_docs/python/tf/GraphKeys:

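GLOBAL_VARIABLES: the default collection of Variable objects, shared across the distributed environment (model variables are a subset of these). See tf.compat.v1.global_variables for more details. Commonly, all TRAINABLE_VARIABLES variables will be in MODEL_VARIABLES, and all MODEL_VARIABLES variables will be in GLOBAL_VARIABLES.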

TRAINABLE_VARIABLES: the subset of Variable objects that will be trained by an optimizer. See tf.compat.v1.trainable_variables for more details.
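A quick way to see the distinction the two definitions describe: in graph mode, a variable created with trainable=False is registered in GLOBAL_VARIABLES but not in TRAINABLE_VARIABLES. This is a minimal sketch, assuming the TF 1.x graph API accessed through tf.compat.v1 (so it should also run on TF 2.x); the variable names w and counter are hypothetical.

```python
import tensorflow as tf

# Assumption: TF 1.x graph-mode API, reached through tf.compat.v1 so it also runs on TF 2.x.
tf1 = tf.compat.v1

with tf.Graph().as_default():
    # Hypothetical variables for illustration.
    w = tf1.get_variable('w', shape=[3])  # trainable=True by default
    counter = tf1.get_variable('counter', shape=[],
                               initializer=tf1.zeros_initializer(),
                               trainable=False)  # e.g. a step counter the optimizer must not touch

    global_names = [v.name for v in tf1.get_collection(tf1.GraphKeys.GLOBAL_VARIABLES)]
    trainable_names = [v.name for v in tf1.get_collection(tf1.GraphKeys.TRAINABLE_VARIABLES)]

print(global_names)     # ['w:0', 'counter:0']
print(trainable_names)  # ['w:0']
```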

As I understand it, TRAINABLE_VARIABLES is a subset of GLOBAL_VARIABLES, so what else does GLOBAL_VARIABLES contain?
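One common source of extra entries is the optimizer. As far as I know, TF 1.x optimizers such as tf.compat.v1.train.AdamOptimizer create their slot variables (Adam's moment estimates) and bookkeeping accumulators as non-trainable variables, so these land in GLOBAL_VARIABLES only; a global step counter and batch-normalization moving averages are other typical examples. A minimal sketch, assuming the tf.compat.v1 graph API, a hypothetical weight w and a toy loss:

```python
import tensorflow as tf

tf1 = tf.compat.v1  # assumption: TF 1.x graph-mode API via tf.compat.v1

with tf.Graph().as_default():
    w = tf1.get_variable('w', shape=[3])  # hypothetical model weight
    loss = tf.reduce_sum(tf.square(w))    # toy loss, for illustration only
    # Building the Adam update op creates additional, non-trainable variables.
    train_op = tf1.train.AdamOptimizer(learning_rate=0.1).minimize(loss)

    global_names = {v.name for v in tf1.global_variables()}
    trainable_names = {v.name for v in tf1.trainable_variables()}

# Variables in GLOBAL_VARIABLES that the optimizer itself will never update.
extra = sorted(global_names - trainable_names)
print(extra)
```

On TF 1.x I would expect `extra` to contain names along the lines of w/Adam:0, w/Adam_1:0, beta1_power:0 and beta2_power:0, but the exact names may differ between TensorFlow versions.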

And in this simple example, the statement "Commonly, all TRAINABLE_VARIABLES variables will be in MODEL_VARIABLES, and all MODEL_VARIABLES variables will be in GLOBAL_VARIABLES" does not hold:

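import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.Session, tf.layers)

IMAGE_HEIGHT = 5
IMAGE_WIDTH = 5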
with tf.Graph().as_default():
    with tf.variable_scope('my_scope', reuse=tf.AUTO_REUSE):
        x_ph = tf.placeholder(
                dtype=tf.float32,
                shape=[None, IMAGE_HEIGHT, IMAGE_WIDTH, 3],
                name='input'
            )

        # 1x1 convolution, 32 filters: kernel (1, 1, 3, 32) + bias (32,) = 96 + 32 = 128 parameters
        x_tf = tf.layers.conv2d(x_ph, 32, 1, 1, padding='valid')

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        x_np = np.random.rand(1, IMAGE_HEIGHT, IMAGE_WIDTH, 3)

        out_np = sess.run(x_tf, {x_ph:x_np})

        print('out_np.shape', out_np.shape)

        print('-'*60)
        global_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        print('len(global_vars)', len(global_vars))
        print('global_vars params:', sum([np.prod(var.shape) for var in global_vars]))
        print(global_vars)

        print('-'*60)
        model_vars = tf.get_collection(tf.GraphKeys.MODEL_VARIABLES)
        print('len(model_vars)', len(model_vars))
        print('model_vars params:', sum([np.prod(var.shape) for var in model_vars]))
        print(model_vars)

        print('-'*60)
        trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        print('len(trainable_vars)', len(trainable_vars))
        print('trainable_vars params:', sum([np.prod(var.shape) for var in trainable_vars]))
        print(trainable_vars)

Output:

out_np.shape (1, 5, 5, 32)
------------------------------------------------------------
len(global_vars) 2
global_vars params: 128
[<tf.Variable 'my_scope/conv2d/kernel:0' shape=(1, 1, 3, 32) dtype=float32_ref>, <tf.Variable 'my_scope/conv2d/bias:0' shape=(32,) dtype=float32_ref>]
------------------------------------------------------------
len(model_vars) 0
model_vars params: 0
[]
------------------------------------------------------------
len(trainable_vars) 2
trainable_vars params: 128
[<tf.Variable 'my_scope/conv2d/kernel:0' shape=(1, 1, 3, 32) dtype=float32_ref>, <tf.Variable 'my_scope/conv2d/bias:0' shape=(32,) dtype=float32_ref>]
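Regarding the empty MODEL_VARIABLES collection in the output above: the documentation only says "commonly", and the result suggests that tf.layers.conv2d simply does not register its kernel and bias in MODEL_VARIABLES. That collection appears to be populated only when code explicitly opts in (to my knowledge, the model_variable helper in tf.contrib.slim does this). The sketch below opts in through the collections argument of tf.compat.v1.get_variable; it assumes the TF 1.x graph API via tf.compat.v1, and the names w, b and the helper names() are hypothetical.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # assumption: TF 1.x graph-mode API via tf.compat.v1

def names(key):
    """Hypothetical helper: variable names in a collection of the default graph."""
    return [v.name for v in tf1.get_collection(key)]

with tf.Graph().as_default():
    # Explicitly listing MODEL_VARIABLES; TRAINABLE_VARIABLES is still appended
    # automatically because trainable defaults to True.
    w = tf1.get_variable('w', shape=[3],
                         collections=[tf1.GraphKeys.GLOBAL_VARIABLES,
                                      tf1.GraphKeys.MODEL_VARIABLES])
    # Default collections: GLOBAL_VARIABLES plus TRAINABLE_VARIABLES, nothing else.
    b = tf1.get_variable('b', shape=[3])

    model_names = names(tf1.GraphKeys.MODEL_VARIABLES)
    trainable_names = names(tf1.GraphKeys.TRAINABLE_VARIABLES)
    global_names = names(tf1.GraphKeys.GLOBAL_VARIABLES)

print(model_names)      # ['w:0']
print(trainable_names)  # ['w:0', 'b:0']
```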

So my questions are:

  1. Why does "Commonly, all TRAINABLE_VARIABLES variables will be in MODEL_VARIABLES, and all MODEL_VARIABLES variables will be in GLOBAL_VARIABLES" not hold in this example?

  2. Besides TRAINABLE_VARIABLES, what other variables does GLOBAL_VARIABLES contain? Is TRAINABLE_VARIABLES always a subset of GLOBAL_VARIABLES, or can the two only partially intersect?
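On the second question, the subset relationship looks like a convention rather than a guarantee. In the TF 1.x variable constructor, the collections argument defaults to [GLOBAL_VARIABLES], and TRAINABLE_VARIABLES is appended whenever trainable=True, so passing a list without GLOBAL_VARIABLES should yield a trainable variable that is not global. A minimal sketch, assuming the tf.compat.v1 graph API; the names w and orphan are hypothetical.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # assumption: TF 1.x graph-mode API via tf.compat.v1

with tf.Graph().as_default():
    # Default behaviour: registered in GLOBAL_VARIABLES and TRAINABLE_VARIABLES.
    w = tf1.Variable(tf.zeros([3]), name='w')
    # collections=[] drops the default GLOBAL_VARIABLES registration; because
    # trainable defaults to True, TRAINABLE_VARIABLES is still appended.
    orphan = tf1.Variable(tf.zeros([3]), name='orphan', collections=[])

    global_names = {v.name for v in tf1.global_variables()}
    trainable_names = {v.name for v in tf1.trainable_variables()}

print(sorted(trainable_names - global_names))  # ['orphan:0']
```

Such a variable is also skipped by tf.compat.v1.global_variables_initializer(), which only initializes GLOBAL_VARIABLES, so it would have to be initialized separately.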

python tensorflow
1 Answer

Note:

© www.soinside.com 2019 - 2024. All rights reserved.