How to run inference on a T5 TensorRT model deployed on NVIDIA Triton?


I have deployed a T5 TensorRT model on an NVIDIA Triton server; the config.pbtxt file is shown below. However, I am running into problems when inferring the model with the Triton client.

According to the config.pbtxt file, the TensorRT model expects four inputs, including the decoder IDs. But how do we send the decoder inputs to the model? I thought the decoder IDs were generated from the model's output.

name: "tensorrt_model"
platform: "tensorrt_plan"
max_batch_size: 0
input [
 {
    name: "input_ids"
    data_type: TYPE_INT32
    dims: [ -1, -1  ]
  },

{
    name: "attention_mask"
    data_type: TYPE_INT32
    dims: [-1, -1 ]
},

{
    name: "decoder_input_ids"
    data_type: TYPE_INT32
    dims: [ -1, -1]
},

{
   name: "decoder_attention_mask"
   data_type: TYPE_INT32
   dims: [ -1, -1 ]
}

]
output [
{
    name: "last_hidden_state"
    data_type: TYPE_FP32
    dims: [ -1, -1, 768 ]
  },

{
    name: "input.151"
    data_type: TYPE_FP32
    dims: [ -1, -1, -1 ]
  }

]

instance_group [
    {
        count: 1
        kind: KIND_GPU
    }
]
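
For context on the decoder tensors: with a plan exported this way, the engine runs a single forward pass, so decoder_input_ids and decoder_attention_mask have to be supplied by the caller. For T5 the decoder is seeded with the decoder start token (the pad token, id 0 in the standard tokenizer), and during generation you append each newly produced token and call the model again. A minimal sketch of preparing the four INT32 arrays, assuming the Hugging Face t5-base tokenizer:

import numpy as np
from transformers import T5Tokenizer  # assumption: the plan was exported from t5-base

tokenizer = T5Tokenizer.from_pretrained("t5-base")

enc = tokenizer("translate English to German: Hello", return_tensors="np")
input_ids = enc["input_ids"].astype(np.int32)            # shape [1, seq_len]
attention_mask = enc["attention_mask"].astype(np.int32)  # shape [1, seq_len]

# The engine does not run the decoding loop itself: seed the decoder with
# T5's start token (the pad token, id 0) and, when generating, append each
# sampled token to decoder_input_ids before the next call.
decoder_input_ids = np.array([[tokenizer.pad_token_id]], dtype=np.int32)
decoder_attention_mask = np.ones_like(decoder_input_ids)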
inference tensorrt triton huggingface
1 Answer

There are several examples in the NVIDIA Triton Client repository. However, if your use case is too complex, you may need the Python backend rather than the Torch backend.
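
For reference, a Triton Python-backend model is a model.py that exposes a TritonPythonModel class. A rough skeleton, using the tensor names from the config above (the actual encoder/decoder logic is omitted):

import numpy as np
import triton_python_backend_utils as pb_utils  # available inside the Triton container

class TritonPythonModel:
    def initialize(self, args):
        # Load the tokenizer / engine state here.
        pass

    def execute(self, requests):
        responses = []
        for request in requests:
            input_ids = pb_utils.get_input_tensor_by_name(
                request, "input_ids").as_numpy()
            # ... run the model / decoding loop here ...
            out = pb_utils.Tensor(
                "last_hidden_state",
                np.zeros((1, input_ids.shape[1], 768), dtype=np.float32))
            responses.append(
                pb_utils.InferenceResponse(output_tensors=[out]))
        return responses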

On the client side, you can initialize the client as follows:

import tritonclient.http as httpclient

triton_url = None  # your Triton server URL, e.g. "localhost:8000"
triton_client = httpclient.InferenceServerClient(url=triton_url)

With the client initialized, you will need a function in Python that generates the request, like the one below.

inputs_dtype = []  # list of input dtypes, e.g. ["INT32", ...]
inputs_name = []   # list of input names
outputs_name = []  # list of output names

def request_generator(data):
    client = httpclient

    # Wrap each numpy array in an InferInput carrying its name, shape and dtype.
    inputs = [
        client.InferInput(input_name, data[i].shape, inputs_dtype[i])
        for i, input_name in enumerate(inputs_name)
    ]

    for i, _input in enumerate(inputs):
        _input.set_data_from_numpy(data[i])

    # Ask the server to return every output listed in outputs_name.
    outputs = [
        client.InferRequestedOutput(output_name) for output_name in outputs_name
    ]

    yield inputs, outputs
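
For the config above, those placeholder lists would be filled in like this (the dtype strings are Triton's wire names, matching TYPE_INT32 in the config):

inputs_name = ["input_ids", "attention_mask",
               "decoder_input_ids", "decoder_attention_mask"]
inputs_dtype = ["INT32"] * 4
outputs_name = ["last_hidden_state", "input.151"]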

You can then call this request_generator in a loop to run inference:

from tritonclient.utils import InferenceServerException

# assuming your preprocessed data comes in a variable named data
# and your client is triton_client

data = preprocess(data)  # your preprocess function

model_name = None  # your model name, e.g. "tensorrt_model"
model_version = None  # your model version, e.g. "1"

responses = []
sent_count = 0

try:
    for inputs, outputs in request_generator(data):
        sent_count += 1

        responses.append(
            triton_client.infer(model_name,
                                inputs,
                                request_id=str(sent_count),
                                model_version=model_version,
                                outputs=outputs))

except InferenceServerException as exception:
    print("Caught an exception:", exception)

As I said, this is only a brief illustration and leaves out many implementation details; the NVIDIA Triton Client repository has plenty of complete examples.
