ValueError:形状必须在assign_add()中处于相同的等级

问题描述 投票:1回答:1

我正在读取TF2中的tf.Variable in Tensorflow r2.0

import tensorflow as tf

# Create a variable.
w = tf.constant([1, 2, 3, 4], tf.float32, shape=[2, 2])

# Use the variable in the graph like any Tensor.
y = tf.matmul(w,tf.constant([7, 8, 9, 10], tf.float32, shape=[2, 2]))
v= tf.Variable(w)
# The overloaded operators are available too.
z = tf.sigmoid(w + y)
tf.shape(z)
# Assign a new value to the variable with `assign()` or a related method.
v.assign(w + 1)
v.assign_add(tf.constant([1.0, 21]))

ValueError:形状必须等于等级,但对于2和1输入形状为'AssignAddVariableOp_4'(op:'AssignAddVariableOp'):[],2

而且以下内容如何返回假?

tf.shape(v) == tf.shape(tf.constant([1.0, 21],tf.float32))

我的另一个问题是,当我们进入TF 2时,我们不应该再使用tf.Session()了吗?似乎是we should never run session.run(),但API文档使用tf.compat.v1等进行了加密。因此,为什么他们在TF2文档中使用它?

任何帮助将不胜感激。

CS

tensorflow2.0 valueerror
1个回答
0
投票

正如在误差中明确指出的那样,期望v上具有形状[2,2]的assign_add的形状[2,2]。如果您尝试赋予除试图拉伸的张量的初始形状以外的其他任何形状assign_add,则会出现错误。

以下是修改后的代码,具有预期的操作形状。

import tensorflow as tf

# Create a variable.
w = tf.constant([1, 2, 3, 4], tf.float32, shape=[2, 2])

# Use the variable in the graph like any Tensor.
y = tf.matmul(w,tf.constant([7, 8, 9, 10], tf.float32, shape=[2, 2]))
v= tf.Variable(w)
# The overloaded operators are available too.
z = tf.sigmoid(w + y)
tf.shape(z)
# Assign a new value to the variable with `assign()` or a related method.
v.assign(w + 1)
print(v)
v.assign_add(tf.constant([1, 2, 3, 4], tf.float32, shape=[2, 2]))  

v的输出:

<tf.Variable 'UnreadVariable' shape=(2, 2) dtype=float32, numpy=
array([[3., 5.],
       [7., 9.]], dtype=float32)> 

现在以下张量比较返回True

tf.shape(v) == tf.shape(tf.constant([1.0, 21],tf.float32)) 

<tf.Tensor: shape=(2,), dtype=bool, numpy=array([ True, True])>

涉及您的tf.Session()问题,在TensorFlow 2.0中,默认情况下仍启用急切执行,如果您需要禁用急切执行并可以使用如下所示的tf.Session

import tensorflow as tf

tf.compat.v1.disable_eager_execution()

hello = tf.constant('Hello, TensorFlow!')

sess = tf.compat.v1.Session()

print(sess.run(hello)) 
© www.soinside.com 2019 - 2024. All rights reserved.