Java Tensorflow + Keras等效于model.predict()

问题描述 投票:0回答:1

在python中,您可以简单地将一个numpy数组传递给predict(),以从模型中获取预测。 Java与SavedModelBundle一起使用的等效条件是什么?

Python

model = tf.keras.models.Sequential([
  # layers go here
])
model.compile(...)
model.fit(x_train, y_train)

predictions = model.predict(x_test_maxabs) # <= This line 

Java

SavedModelBundle model = SavedModelBundle.load(path, "serve");
model.predict() // ????? // What does it take as in input? Tensor?
java tensorflow keras tensorflow-serving
1个回答
0
投票

TensorFlow Python自动将您的NumPy数组转换为tf.Tensor。在TensorFlow Java中,您可以直接操纵张量。

现在SavedModelBundle没有predict方法。您需要使用SessionRunner获取会话并运行它,并为其输入输入张量。

例如,基于下一代TF Java(https://github.com/tensorflow/java),您的代码最终看起来像这样(请注意,我在此处对x_test_maxabs采取了许多假设,因为您的代码示例并未清楚说明在哪里它来自):

try (SavedModelBundle model = SavedModelBundle.load(path, "serve")) {
    try (Tensor<TFloat32> input = TFloat32.tensorOf(...);
        Tensor<TFloat32> output = model.session()
            .runner()
            .feed("input_name", input)
            .fetch("output_name")
            .run()
            .expect(TFloat32.class)) {

        float prediction = output.data().getFloat();
        System.out.println("prediction = " + prediction);
    }        
}

如果不确定图形中输入/输出张量的名称是什么,可以通过查看签名定义以编程方式获得:

model.metaGraphDef().getSignatureDefMap().get("serving_default")
© www.soinside.com 2019 - 2024. All rights reserved.