C++中仅使用gRPC的TensorFlow Serving客户端

问题描述 投票:0回答:1

网上已有一些资料介绍如何在Python(甚至其他一些语言)中创建仅使用gRPC的客户端。我也已经成功地用Python实现了一个只依赖gRPC的可用客户端,它满足我们的需求。

但我找不到任何人用C++成功编写此类客户端的例子。

任务的约束如下:

  1. 构建系统不能是Bazel,因为最终的应用程序已经有自己的构建系统。
  2. 客户端不能依赖TensorFlow本身(在C++中构建TensorFlow需要Bazel)。
  3. 应用程序应使用gRPC而不是HTTP(REST)调用,以获得更高的速度。
  4. 理想情况下,应用程序不调用Python,也不执行任何shell命令。

在上述约束下,假设我已经提取了相关的.proto文件并生成了gRPC存根,这能实现吗?如果可以,能否提供一个示例?

c++ tensorflow client grpc tensorflow-serving
1个回答
0
投票

事实证明,如果您已经用Python实现过,那么在C++中实现并没有什么新东西。假设模型名为`predict`,模型的输入名为`inputs`,Python代码如下:

import logging
import grpc
from grpc import RpcError

from types_pb2 import DT_FLOAT
from tensor_pb2 import TensorProto
from tensor_shape_pb2 import TensorShapeProto
from predict_pb2 import PredictRequest
from prediction_service_pb2_grpc import PredictionServiceStub


class ModelClient:
    """Client facade for a TensorFlow Serving gRPC ``PredictionService``.

    The channel and stub are created by :meth:`connect`; :meth:`predict`
    calls it automatically when no connection has been set up yet.
    """
    # Connection state, populated by connect().
    host = None
    port = None
    chan = None
    stub = None

    logger = logging.getLogger(__name__)

    def __init__(self, name, dims, dtype=DT_FLOAT, version=1):
        """Configure the target model and the input tensor description.

        :param name: model name sent in ``PredictRequest.model_spec.name``.
        :param dims: iterable of dimension sizes for the input tensor shape.
        :param dtype: TensorFlow ``DataType`` enum value of the input tensor.
        :param version: model version to request; ``None`` or a value
            ``<= 0`` leaves ``model_spec.version`` unset.
        """
        self.model = name
        self.dims = [TensorShapeProto.Dim(size=dim) for dim in dims]
        self.dtype = dtype
        self.version = version

    @property
    def hostport(self):
        """A ``host:port`` string representation."""
        return f"{self.host}:{self.port}"

    def connect(self, host='localhost', port=8500):
        """Open an insecure gRPC channel and create the prediction stub.

        :param host: server host name.
        :param port: server port; converted with ``int()``.
        """
        self.host = host
        self.port = int(port)

        self.logger.info(f"Connecting to {self.hostport}...")
        self.chan = grpc.insecure_channel(self.hostport)

        self.logger.info("Initializing prediction gRPC stub.")
        self.stub = PredictionServiceStub(self.chan)

    def tensor_proto_from_measurement(self, measurement):
        """Wrap ``measurement`` in a ``TensorProto`` with the configured shape.

        For ``DT_STRING`` tensors the value goes into ``string_val`` (one
        element).

        For every other dtype, ``bytes(measurement)`` is sent as raw
        ``tensor_content``. Setting only ``string_val`` on a numeric tensor
        does not populate its numeric value fields. This branch assumes
        ``measurement`` exposes a raw buffer whose element type and byte order
        match ``self.dtype`` (e.g. a little-endian float32 numpy array for
        ``DT_FLOAT``). TODO confirm against callers.
        """
        from types_pb2 import DT_STRING

        self.logger.info("Assembling measurement tensor.")
        shape = TensorShapeProto(dim=self.dims)
        payload = bytes(measurement)
        if self.dtype == DT_STRING:
            return TensorProto(dtype=self.dtype, tensor_shape=shape,
                               string_val=[payload])
        return TensorProto(dtype=self.dtype, tensor_shape=shape,
                           tensor_content=payload)

    def predict(self, measurement, timeout=10, input_name='inputs'):
        """Execute a prediction against the TF Serving service.

        :param measurement: input data, see :meth:`tensor_proto_from_measurement`.
        :param timeout: RPC timeout in seconds.
        :param input_name: key of the model input in ``PredictRequest.inputs``.
        :return: the ``PredictResponse``, or ``None`` if the RPC failed
            (the error is logged).
        """
        if self.host is None or self.port is None \
                or self.chan is None or self.stub is None:
            # Keep any host/port configured earlier instead of silently
            # falling back to the connect() defaults.
            host = 'localhost' if self.host is None else self.host
            port = 8500 if self.port is None else self.port
            self.connect(host, port)

        self.logger.info("Creating request.")
        request = PredictRequest()
        request.model_spec.name = self.model

        if self.version is not None and self.version > 0:
            request.model_spec.version.value = self.version

        request.inputs[input_name].CopyFrom(
            self.tensor_proto_from_measurement(measurement))

        self.logger.info("Attempting to predict against TF Serving API.")
        try:
            return self.stub.Predict(request, timeout=timeout)
        except RpcError as err:
            self.logger.error(err)
            self.logger.error('Predict failed.')
            return None

下面是一个可运行的(较粗糙的)C++翻译版本:

#include <chrono>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include <grpcpp/grpcpp.h>

#include "grpcpp/create_channel.h"
#include "grpcpp/security/credentials.h"
#include "google/protobuf/map.h"

#include "types.grpc.pb.h"
#include "tensor.grpc.pb.h"
#include "tensor_shape.grpc.pb.h"
#include "predict.grpc.pb.h"
#include "prediction_service.grpc.pb.h"

// gRPC client runtime types.
using grpc::Channel;
using grpc::ClientContext;
using grpc::Status;

// TensorFlow / TensorFlow Serving protobuf message and service types.
using tensorflow::TensorProto;
using tensorflow::TensorShapeProto;
using tensorflow::serving::PredictionService;
using tensorflow::serving::PredictRequest;
using tensorflow::serving::PredictResponse;

// Map type used for PredictResponse outputs (output name -> tensor).
using OutMap = google::protobuf::Map<std::string, tensorflow::TensorProto>;

class ServingClient {
 public:
  ServingClient(std::shared_ptr<Channel> channel)
      : stub_(PredictionService::NewStub(channel)) {}

  // Assembles the client's payload, sends it and presents the response back
  // from the server.
  std::string callPredict(const std::string& model_name,
                          const float& measurement) {

    // Data we are sending to the server.
    PredictRequest request;
    request.mutable_model_spec()->set_name(model_name);

    // Container for the data we expect from the server.
    PredictResponse response;

    // Context for the client. It could be used to convey extra information to
    // the server and/or tweak certain RPC behaviors.
    ClientContext context;

    google::protobuf::Map<std::string, tensorflow::TensorProto>& inputs =
      *request.mutable_inputs();

    tensorflow::TensorProto proto;
    proto.set_dtype(tensorflow::DataType::DT_FLOAT);
    proto.add_float_val(measurement);

    proto.mutable_tensor_shape()->add_dim()->set_size(5);
    proto.mutable_tensor_shape()->add_dim()->set_size(8);
    proto.mutable_tensor_shape()->add_dim()->set_size(105);

    inputs["inputs"] = proto;

    // The actual RPC.
    Status status = stub_->Predict(&context, request, &response);

    // Act upon its status.
    if (status.ok()) {
      std::cout << "call predict ok" << std::endl;
      std::cout << "outputs size is " << response.outputs_size() << std::endl;

      OutMap& map_outputs = *response.mutable_outputs();
      OutMap::iterator iter;
      int output_index = 0;

      for (iter = map_outputs.begin(); iter != map_outputs.end(); ++iter) {
        tensorflow::TensorProto& result_tensor_proto = iter->second;
        std::string section = iter->first;
        std::cout << std::endl << section << ":" << std::endl;

        if ("classes" == section) {
          int titer;
          for (titer = 0; titer != result_tensor_proto.int64_val_size(); ++titer) {
            std::cout << result_tensor_proto.int64_val(titer) << ", ";
          }
        } else if ("scores" == section) {
          int titer;
          for (titer = 0; titer != result_tensor_proto.float_val_size(); ++titer) {
            std::cout << result_tensor_proto.float_val(titer) << ", ";
          }
        }
        std::cout << std::endl;
        ++output_index;
      }
      return "Done.";
    } else {
      std::cout << "gRPC call return code: " << status.error_code() << ": "
                << status.error_message() << std::endl;
      return "RPC failed";
    }
  }

 private:
  std::unique_ptr<PredictionService::Stub> stub_;
};

请注意,示例中的张量维度(5×8×105)是在代码中指定的。

有了上述类,就可以像下面这样调用:

// Example driver. It sends one Predict request for the model named "predict"
// to localhost:8500 over an insecure (plaintext) gRPC channel, then prints the
// string returned by callPredict().
int main(int argc, char** argv) {
  // Placeholder: `{ ... data ... }` must be replaced with the 5*8*105 input
  // values before this compiles.
  float measurement[5*8*105] = { ... data ... };

  ServingClient sclient(grpc::CreateChannel(
      "localhost:8500", grpc::InsecureChannelCredentials()));
  std::string model("predict");
  // NOTE(review): `*measurement` is measurement[0] only. callPredict(const
  // float&) therefore sends a single value, not the whole array. The full
  // buffer has to be sent if the model needs real input data.
  std::string reply = sclient.callPredict(model, *measurement);
  std::cout << "Predict received: " << reply << std::endl;

  return 0;
}

所用的Makefile借用自gRPC的C++示例,其中PROTOS_PATH变量被设置为相对于Makefile的路径,并添加了以下构建目标(假设C++源文件名为predict.cc):

# Link the prediction client from the protoc-generated message/gRPC objects
# plus predict.o. The recipe line MUST begin with a TAB character; a
# space-indented recipe fails with "missing separator".
predict: types.pb.o types.grpc.pb.o \
         tensor_shape.pb.o tensor_shape.grpc.pb.o \
         resource_handle.pb.o resource_handle.grpc.pb.o \
         model.pb.o model.grpc.pb.o \
         tensor.pb.o tensor.grpc.pb.o \
         predict.pb.o predict.grpc.pb.o \
         prediction_service.pb.o prediction_service.grpc.pb.o \
         predict.o
	$(CXX) $^ $(LDFLAGS) -o $@
© www.soinside.com 2019 - 2024. All rights reserved.