使用Tensorflow进行推理使用Java进行服务

问题描述 投票:1回答:1

我们正在转换现有的Java生产代码,以使用Tensorflow服务(TFS)进行推理。我们已经重新训练了我们的模型,并使用新的SavedModel格式保存它们(不再需要冻结图形!!)。 从我读过的文档中,TFS不直接支持Java。但它确实提供了gRPC接口,并且确实提供了Java接口。

我的问题是,启动Java应用程序以使用TFS涉及哪些步骤。

[编辑:将步骤移至解决方案]

tensorflow tensorflow-serving grpc-java
1个回答
1
投票

由于文档和示例仍然有限,因此需要花费四天的时间将其拼凑在一起。 我确信有更好的方法可以做到这一点,但这是我到目前为止所发现的:

  • 我在github上克隆了tensorflow/tensorflowtensorflow/servinggoogle/protobuf repos。
  • 我使用protoc protobuf compilergrpc-java plugin编译了以下protobuf文件。我讨厌有这么多分散的.proto文件要编译,但我想要包含最小集合,并且在各种目录中有很多不需要的.proto文件可以被绘制。这是我需要的最小集合编译我们的Java应用程序: serving_repo/tensorflow_serving/apis/*.proto serving_repo/tensorflow_serving/config/model_server_config.proto serving_repo/tensorflow_serving/core/logging.proto serving_repo/tensorflow_serving/core/logging_config.proto serving_repo/tensorflow_serving/util/status.proto serving_repo/tensorflow_serving/sources/storage_path/file_system_storage_path_source.proto serving_repo/tensorflow_serving/config/log_collector_config.proto tensorflow_repo/tensorflow/core/framework/tensor.proto tensorflow_repo/tensorflow/core/framework/tensor_shape.proto tensorflow_repo/tensorflow/core/framework/types.proto tensorflow_repo/tensorflow/core/framework/resource_handle.proto tensorflow_repo/tensorflow/core/example/example.proto tensorflow_repo/tensorflow/core/protobuf/tensorflow_server.proto tensorflow_repo/tensorflow/core/example/feature.proto tensorflow_repo/tensorflow/core/protobuf/named_tensor.proto tensorflow_repo/tensorflow/core/protobuf/config.proto
  • 请注意,即使没有protoc存在,grpc-java也会编译,但是大多数关键入口点都会神秘地丢失。如果缺少PredictionServiceGrpc.java那么grpc-java没有被执行。
  • 命令行示例(为了便于阅读,插入了换行符):
$ ./protoc -I=/Users/foobar/protobuf_repo/src \
   -I=/Users/foobar/tensorflow_repo \   
   -I=/Users/foobar/tfserving_repo \  
   -plugin=protoc-gen-grpc-java=/Users/foobar/protoc-gen-grpc-java-1.20.0-osx-x86_64.exe \
   --java_out=src \
   --grpc-java_out=src \
   /Users/foobar/tfserving_repo/tensorflow_serving/apis/*.proto
  • 继gRPC documentation之后,我创建了一个频道和一个存根:
ManagedChannel mChannel;
PredictionServiceGrpc.PredictionServiceBlockingStub mBlockingstub;
mChannel = ManagedChannelBuilder.forAddress(host,port).usePlaintext().build();
mBlockingstub = PredictionServiceGrpc.newBlockingStub(mChannel);
  • 我按照几个文件拼凑了以下步骤: gRPC documents讨论存根(阻塞和异步) 这个article概述了这个过程,但是使用了Python 此示例code对于NewBuilder语法的示例至关重要。
  • Maven进口是: io.grpc:grpc-all org.tensorflow:libtensorflow org.tensorflow:proto com.google.protobuf:protobuf-java
  • 这是示例代码:
// Generate features TensorProto
TensorProto.Builder featuresTensorBuilder = TensorProto.newBuilder();

TensorShapeProto.Dim featuresDim1  = TensorShapeProto.Dim.newBuilder().setSize(1).build();
TensorShapeProto     featuresShape = TensorShapeProto.newBuilder().addDim(featuresDim1).build();
featuresTensorBuilder.setDtype(org.tensorflow.framework.DataType).setTensorShape(featuresShape);
TensorProto featuresTensorProto = featuresTensorBuilder.build();


// Now prepare for the inference request over gRPC to the TF Serving server
com.google.protobuf.Int64Value version = com.google.protobuf.Int64Value.newBuilder().setValue(mGraphVersion).build();

Model.ModelSpec.Builder model = Model.ModelSpec
                                     .newBuilder()
                                     .setName(mGraphName)
                                     .setVersion(version);  // type = Int64Value
Model.ModelSpec     modelSpec = model.build();

Predict.PredictRequest request;
request = Predict.PredictRequest.newBuilder()
                                .setModelSpec(modelSpec)
                                .putInputs("image", featuresTensorProto)
                                .build();

Predict.PredictResponse response;

try {
    response = mBlockingstub.predict(request);
    // Refer to https://github.com/thammegowda/tensorflow-grpc-java/blob/master/src/main/java/edu/usc/irds/tensorflow/grpc/TensorflowObjectRecogniser.java

    java.util.Map<java.lang.String, org.tensorflow.framework.TensorProto> outputs = response.getOutputsOrDefault();
    for (java.util.Map.Entry<java.lang.String, org.tensorflow.framework.TensorProto> entry : outputs.entrySet()) {
        System.out.println("Response with the key: " + entry.getKey() + ", value: " + entry.getValue());
    }
} catch (StatusRuntimeException e) {
    logger.log(Level.WARNING, "RPC failed: {0}", e.getStatus());
    success = false;
}

© www.soinside.com 2019 - 2024. All rights reserved.