使用BERT(TF 1.x)保存的模型进行推理

问题描述 投票:1回答:1

我被困在一行代码上,结果整个周末都停滞在一个项目上。

我正在一个使用BERT进行句子分类的项目。我已经成功地训练了模型,并且可以使用run_classifier.py中的示例代码来测试结果。

我可以使用此示例代码来导出模型(该代码已重复发布,因此我认为它适合该模型):

def export(self):
  def serving_input_fn():
    label_ids = tf.placeholder(tf.int32, [None], name='label_ids')
    input_ids = tf.placeholder(tf.int32, [None, self.max_seq_length], name='input_ids')
    input_mask = tf.placeholder(tf.int32, [None, self.max_seq_length], name='input_mask')
    segment_ids = tf.placeholder(tf.int32, [None, self.max_seq_length], name='segment_ids')
    input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
        'label_ids': label_ids, 'input_ids': input_ids,
        'input_mask': input_mask, 'segment_ids': segment_ids})()
    return input_fn
  self.estimator._export_to_tpu = False
  self.estimator.export_savedmodel(self.output_dir, serving_input_fn)

我还可以加载导出的估计量(导出功能将导出的模型保存到带有时间戳的子目录中:]

predict_fn = predictor.from_saved_model(self.output_dir + timestamp_number)

但是,对于我的一生,我无法弄清为预测输入提供的预测值。这是我目前最好的代码:

def predict(self):
  input = 'Test input'
  guid = 'predict-0'
  text_a = tokenization.convert_to_unicode(input)
  label = self.label_list[0]
  examples = [InputExample(guid=guid, text_a=text_a, text_b=None, label=label)]
  features = convert_examples_to_features(examples, self.label_list,
    self.max_seq_length, self.tokenizer)
  predict_input_fn = input_fn_builder(features, self.max_seq_length, False)
  predict_fn = predictor.from_saved_model(self.output_dir + timestamp_number)
  result = predict_fn(predict_input_fn)       # this generates an error
  print(result)

我提供给predict_fn似乎无关紧要:示例数组,功能数组,predict_input_fn函数。显然,predict_fn需要某种类型的字典-但是我尝试过的每件事都会由于张量不匹配或其他通常表示错误的输入而产生异常:]输入错误。

我假设from_saved_model函数需要与模型测试函数相同的输入-显然不是这种情况。

似乎[[很多]]人已经问了这个问题-“如何使用导出的BERT TensorFlow模型进行推理?” -并且没有答案:Thread #1

Thread #2

Thread #3

Thread #4

有帮助吗?预先感谢。

我被困在一行代码上,结果整个周末都被一个项目拖延了。我正在一个使用BERT进行句子分类的项目。我已经成功训练了模型,...

tensorflow tensorflow-serving tensorflow-estimator
1个回答
0
投票
谢谢你的这篇文章。您的serving_input_fn是我所缺少的!您需要更改predict函数以直接提供功能dict,而不是使用predict_input_fn:
© www.soinside.com 2019 - 2024. All rights reserved.