使用TensorFlow优化器优化涉及tf.keras的“model.predict()”的函数?

问题描述 投票:2回答:1

我用tf.keras构建了一个完全连接的ANN,“my_model”。然后,我试图使用来自TensorFlow的Adam优化器最小化函数f(x) = my_model.predict(x) - 0.5 + g(x)。我尝试了下面的代码:

x = tf.get_variable('x', initializer = np.array([1.5, 2.6]))
f = my_model.predict(x) - 0.5 + g(x)
optimizer = tf.train.AdamOptimizer(learning_rate=.001).minimize(f) 
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(50):
        print(sess.run([x,f]))
        sess.run(optimizer)

但是,执行my_model.predict(x)时出现以下错误:

如果您的数据采用符号张量的形式,则应指定steps参数(而不是batch_size参数)

我理解错误是什么,但我无法弄清楚如何使my_model.predict(x)在符号张量存在的情况下工作。如果从函数my_model.predict(x)中删除f(x),则代码运行时没有任何错误。

我检查了以下linklink,其中TensorFlow优化器用于最小化任意函数,但我认为我的问题是使用底层keras的model.predict()函数。我感谢任何帮助。提前致谢!

python tensorflow optimization
1个回答
0
投票

我找到了答案!

基本上,我试图优化一个涉及训练有素的人工神经网络的功能,输入变量到人工神经网络。所以,我想要的只是知道如何调用my_model并将其放入f(x)。在这里挖掘一下Keras文档:https://keras.io/getting-started/functional-api-guide/,我发现所有Keras模型都可以像模型层一样调用!从链接中引用信息,

..你可以通过在张量上调用它来将任何模型视为一个层。请注意,通过调用模型,您不仅可以重用模型的体系结构,还可以重用其权重。

同时,model.predict(x)部分期望x是numpy数组或评估张量,并且不将张量流量变量作为输入(https://www.tensorflow.org/api_docs/python/tf/keras/Model#predict)。

所以下面的代码工作:

## initializations
sess = tf.InteractiveSession()
x_init_value = np.array([1.5, 2.6])
x_placeholder =  tf.placeholder(tf.float32)
x_var = tf.Variable(x_init_value, dtype=tf.float32)

# Check calling my_model
assign_step = tf.assign(x_var, x_placeholder)
sess.run(assign_step, feed_dict={x_placeholder: x_init_value})
model_output = my_model(x_var) # This simple step is all I wanted!
sess.run(model_output) # This outputs my_model's predicted value for input x_init_value

# Now, define the objective function that has to be minimized
f = my_model(x_var) - 0.5 + g(x_var) # g(x_var) is some function of x_var

# Define the optimizer
optimizer = tf.train.AdamOptimizer(learning_rate=.001).minimize(f) 

# Run the optimization steps
for i in range(50): # for 50 steps
    _,loss = optimizer.minimize(f, var_list=[x_var])
    print("step: ", i+1, ", loss: ", loss, ", X: ", x_var.eval()))    
© www.soinside.com 2019 - 2024. All rights reserved.