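From https://www.tensorflow.org/api_docs/python/tf/GraphKeys:
GLOBAL_VARIABLES: the default collection of Variable objects, shared across distributed environment (model variables are subset of these). See tf.compat.v1.global_variables for more details. Commonly, all TRAINABLE_VARIABLES variables will be in MODEL_VARIABLES, and all MODEL_VARIABLES variables will be in GLOBAL_VARIABLES.
TRAINABLE_VARIABLES: the subset of Variable objects that will be trained by an optimizer. See tf.compat.v1.trainable_variables for more details.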
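A minimal sketch of how `tf.compat.v1.get_variable` fills these collections, assuming a TensorFlow build that ships the `tf.compat.v1` API (1.15 or 2.x). With the default arguments a variable goes into GLOBAL_VARIABLES, and `trainable=True` (the default) also puts it into TRAINABLE_VARIABLES. `get_variable` does not touch MODEL_VARIABLES, so here that collection is filled by hand with `add_to_collection`. The scope `demo` and the variables `w` and `counter` are hypothetical names chosen for illustration.

```python
import tensorflow as tf

v1 = tf.compat.v1
keys = v1.GraphKeys

with tf.Graph().as_default() as graph:
    with v1.variable_scope('demo'):
        # Default trainable=True: lands in GLOBAL_VARIABLES and TRAINABLE_VARIABLES.
        w = v1.get_variable('w', shape=[3, 4])
        # trainable=False: lands in GLOBAL_VARIABLES only (e.g. a counter).
        counter = v1.get_variable('counter', shape=[], trainable=False,
                                  initializer=v1.zeros_initializer())
    # get_variable never adds to MODEL_VARIABLES; some code must do it explicitly.
    v1.add_to_collection(keys.MODEL_VARIABLES, w)

    def names(key):
        """Sorted variable names stored in the given graph collection."""
        return sorted(v.op.name for v in graph.get_collection(key))

    global_names = names(keys.GLOBAL_VARIABLES)
    trainable_names = names(keys.TRAINABLE_VARIABLES)
    model_names = names(keys.MODEL_VARIABLES)

print('global:   ', global_names)
print('trainable:', trainable_names)
print('model:    ', model_names)
```

Under these assumptions `counter` shows up only in the global collection, which is one kind of variable GLOBAL_VARIABLES can hold beyond the trainable ones.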
As far as I understand, TRAINABLE_VARIABLES is a subset of GLOBAL_VARIABLES, so what else does GLOBAL_VARIABLES contain?
Also, even for this simple example the statement "Commonly, all TRAINABLE_VARIABLES variables will be in MODEL_VARIABLES, and all MODEL_VARIABLES variables will be in GLOBAL_VARIABLES" does not hold:
import numpy as np
import tensorflow as tf  # TF 1.x API (tf.placeholder, tf.layers, tf.Session)

IMAGE_HEIGHT = 5
IMAGE_WIDTH = 5

with tf.Graph().as_default():
    with tf.variable_scope('my_scope', reuse=tf.AUTO_REUSE):
        x_ph = tf.placeholder(
            dtype=tf.float32,
            shape=[None, IMAGE_HEIGHT, IMAGE_WIDTH, 3],
            name='input'
        )
        # 1x1 convolution with 32 filters: creates a kernel and a bias variable
        x_tf = tf.layers.conv2d(x_ph, 32, 1, 1, padding='valid')

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        x_np = np.random.rand(1, IMAGE_HEIGHT, IMAGE_WIDTH, 3)
        out_np = sess.run(x_tf, {x_ph: x_np})
        print('out_np.shape', out_np.shape)

    # Inspect the three collections of the graph built above
    print('-'*60)
    global_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    print('len(global_vars)', len(global_vars))
    print('global_vars params:', sum([np.prod(var.shape.as_list()) for var in global_vars]))
    print(global_vars)
    print('-'*60)
    model_vars = tf.get_collection(tf.GraphKeys.MODEL_VARIABLES)
    print('len(model_vars)', len(model_vars))
    print('model_vars params:', sum([np.prod(var.shape.as_list()) for var in model_vars]))
    print(model_vars)
    print('-'*60)
    trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    print('len(trainable_vars)', len(trainable_vars))
    print('trainable_vars params:', sum([np.prod(var.shape.as_list()) for var in trainable_vars]))
    print(trainable_vars)
Output:
out_np.shape (1, 5, 5, 32)
------------------------------------------------------------
len(global_vars) 2
global_vars params: 128
[<tf.Variable 'my_scope/conv2d/kernel:0' shape=(1, 1, 3, 32) dtype=float32_ref>, <tf.Variable 'my_scope/conv2d/bias:0' shape=(32,) dtype=float32_ref>]
------------------------------------------------------------
len(model_vars) 0
model_vars params: 0
[]
------------------------------------------------------------
len(trainable_vars) 2
trainable_vars params: 128
[<tf.Variable 'my_scope/conv2d/kernel:0' shape=(1, 1, 3, 32) dtype=float32_ref>, <tf.Variable 'my_scope/conv2d/bias:0' shape=(32,) dtype=float32_ref>]
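In the graph above, GLOBAL_VARIABLES and TRAINABLE_VARIABLES happen to hold the same two variables. As a hedged sketch (assuming a TensorFlow build with `tf.compat.v1`), adding a training step shows the kind of extra entries GLOBAL_VARIABLES typically picks up: the global step from `tf.compat.v1.train.get_or_create_global_step()` and the internal state of `tf.compat.v1.train.AdamOptimizer`, all created as non-trainable. The conv layer is rebuilt with `get_variable` and `tf.nn.conv2d` instead of `tf.layers` so the sketch also runs on 2.x; the exact names and number of extra variables may vary by version, so the code only prints them.

```python
import tensorflow as tf

v1 = tf.compat.v1
keys = v1.GraphKeys

with tf.Graph().as_default() as graph:
    # Same 1x1 conv with 32 filters as in the question, built by hand.
    x = v1.placeholder(tf.float32, shape=[None, 5, 5, 3])
    kernel = v1.get_variable('kernel', shape=[1, 1, 3, 32])
    bias = v1.get_variable('bias', shape=[32], initializer=v1.zeros_initializer())
    y = tf.nn.conv2d(x, kernel, strides=[1, 1, 1, 1], padding='VALID') + bias
    loss = tf.reduce_mean(tf.square(y))

    # Creates a non-trainable global step plus the optimizer's
    # non-trainable bookkeeping variables (moment estimates etc.).
    global_step = v1.train.get_or_create_global_step()
    train_op = v1.train.AdamOptimizer(1e-3).minimize(loss, global_step=global_step)

    global_vars = graph.get_collection(keys.GLOBAL_VARIABLES)
    trainable_vars = graph.get_collection(keys.TRAINABLE_VARIABLES)

trainable_names = {v.op.name for v in trainable_vars}
global_names = {v.op.name for v in global_vars}
# Variables that are global but not trainable
extra = sorted(global_names - trainable_names)

print('trainable:', sorted(trainable_names))
print('global only:', extra)
```

The trainable collection still holds just the kernel and bias, while the global collection grows with everything the training machinery needs to save and restore.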
So the questions are:
Why does "Commonly, all TRAINABLE_VARIABLES variables will be in MODEL_VARIABLES, and all MODEL_VARIABLES variables will be in GLOBAL_VARIABLES" not hold in this example?
What other variables does GLOBAL_VARIABLES contain besides TRAINABLE_VARIABLES? Is TRAINABLE_VARIABLES always a subset of GLOBAL_VARIABLES, or can the two only partially intersect?
Note: