我使用tf.Estimator
训练了一个简单的自动编码器。在训练期间,更新特定张量W
,其中W
是矩阵。训练结束后,我想检索W
并使用numpy
读取它的值。
如果我不使用tf.Estimator
,这是一个简单的任务,我会打电话给.eval()
并通过我的会话。但是,Estimator
是一个高级API,会话的初始化和使用都是在幕后完成的。
我也尝试使用Estimator.predict
和EstimatorSpec
返回W
,但它似乎不起作用。我收到以下错误:
TypeError:预期单个Tensor时的张量列表。
是否有可能在使用numpy
训练后直接检索张量的tf.Estimator
值。如果是这样,怎么样?
假设W
存储为模型中的变量,则可以使用get_variable_value
对象的Estimator
方法。见here。