I'm looking for an en-zh translation model that I can run with TFLite, and I found one on Hugging Face: https://huggingface.co/Helsinki-NLP/opus-mt-en-zh
I converted the model to .tflite with the following code:
import tensorflow as tf
from transformers import TFMarianMTModel, AutoTokenizer
print("loading model...")
model_name = 'Helsinki-NLP/opus-mt-en-zh'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = TFMarianMTModel.from_pretrained(model_name, from_pt=True)
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
with open("./out/tf_model.tflite", 'wb') as o_:
    o_.write(tflite_model)
But when I try to run inference with it, I run into a problem:
from transformers import AutoTokenizer
import tensorflow as tf
import numpy as np
model_name = 'Helsinki-NLP/opus-mt-en-zh'
tokenizer = AutoTokenizer.from_pretrained(model_name)
result = tokenizer(">>cmn_Hans<< hello world", return_tensors="tf", padding=True)
print("tokenize result: ", result, '\n')
interpreter = tf.lite.Interpreter(model_path="./out/tf_model.tflite")
interpreter.allocate_tensors()
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
input_ids = result["input_ids"]
interpreter.set_tensor(input_details[2]['index'], input_ids)
interpreter.invoke()
# The function `get_tensor()` returns a copy of the tensor data.
# Use `tensor()` in order to get a pointer to the tensor.
output_data = interpreter.get_tensor(output_details[0]['index'])
print("output_data:", output_data)
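The script above picks the input positionally (`input_details[2]`, and later `input_details[3]`), which makes it easy to write token IDs into the wrong input of a multi-input seq2seq graph. A minimal sketch of looking the input up by name instead, assuming the converted model's input names follow a pattern like `serving_default_input_ids:0` (that pattern is an assumption; print `input_details` to see the real names). `find_input` is a hypothetical helper, not part of TensorFlow.

```python
import re


def find_input(details, wanted):
    """Return the entry of `details` whose normalized name equals `wanted`.

    `details` is expected to be the list of dicts returned by
    interpreter.get_input_details(); only the "name" key is read here.
    ASSUMPTION: names look like "serving_default_input_ids:0", so we strip a
    trailing ":<n>" and a leading "serving_default_" before comparing.
    Exact matching (not substring) keeps "input_ids" from also matching
    "decoder_input_ids".
    """
    def normalize(name):
        name = re.sub(r":\d+$", "", name)                # drop a trailing ":0"
        return re.sub(r"^serving_default_", "", name)    # drop the signature prefix

    for detail in details:
        if normalize(detail["name"]) == wanted:
            return detail
    available = [d["name"] for d in details]
    raise KeyError(f"no input named {wanted!r}; available: {available}")
```

With the interpreter from the question, usage would look like `detail = find_input(input_details, "input_ids")` followed by `interpreter.set_tensor(detail["index"], input_ids)`.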
The error is:
Traceback (most recent call last):
  File "lite.py", line 48, in <module>
    interpreter.set_tensor(input_details[2]['index'], input_ids)
  File "/Users/xuanyue/venv/lib/python3.8/site-packages/tensorflow/lite/python/interpreter.py", line 607, in set_tensor
    self._interpreter.SetTensor(tensor_index, value)
ValueError: Cannot set tensor: Dimension mismatch. Got 1 but expected 3 for dimension 0 of input 2.
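The message says input 2 has a fixed size of 3 in dimension 0 (the batch dimension), while the tokenized batch has size 1. Presumably the shape was fixed when the Keras model was traced during conversion; that is an assumption, so check `input_details[i]['shape']` and `input_details[i]['shape_signature']`. One option is `interpreter.resize_tensor_input(index, new_shape)` followed by `interpreter.allocate_tensors()`, which only works if the graph tolerates the new shape. Another is to pad the input up to the fixed shape and ignore the extra rows in the output. A minimal sketch of the padding route; `pad_to_shape` is a hypothetical helper, and `pad_id` should come from `tokenizer.pad_token_id`.

```python
def pad_to_shape(batch, shape, pad_id):
    """Pad a batch of token-id lists to a fixed [batch_size, seq_len] shape.

    `shape` is the static shape the TFLite input expects, e.g.
    list(input_details[i]["shape"]). Missing rows are filled entirely with
    pad_id. Returns (ids, attention_mask), where the mask is 1 for real
    tokens and 0 for padding. Raises ValueError if the batch does not fit.
    """
    batch_size, seq_len = shape
    if len(batch) > batch_size:
        raise ValueError(f"batch of {len(batch)} exceeds fixed size {batch_size}")
    ids, mask = [], []
    for row in batch:
        row = list(row)
        if len(row) > seq_len:
            raise ValueError(f"{len(row)} tokens exceed fixed length {seq_len}")
        n_pad = seq_len - len(row)
        ids.append(row + [pad_id] * n_pad)
        mask.append([1] * len(row) + [0] * n_pad)
    # Fill the remaining batch slots with all-padding rows.
    for _ in range(batch_size - len(batch)):
        ids.append([pad_id] * seq_len)
        mask.append([0] * seq_len)
    return ids, mask
```

The lists can then be converted with `np.array(ids, dtype=detail["dtype"])` before calling `set_tensor`, and only row 0 of the output is meaningful.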
I tried reshaping the input, but I don't know the right way to do it. Any suggestions? This is what I tried:
first_token = input_ids[0, 0]
input_ids_reshaped = np.array([[first_token]])
interpreter.set_tensor(input_details[3]['index'], input_ids_reshaped)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
print("output_data:", output_data)
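Feeding only the first token will not produce a translation even once the shapes line up. MarianMT is an encoder-decoder model: the full source `input_ids` (plus attention mask) go to the encoder inputs, and the output is vocabulary logits for each position of `decoder_input_ids`, so text has to be generated one token at a time. For Marian models the decoder start token is, as far as I know, the pad token; read it from `model.config.decoder_start_token_id` rather than hard-coding it. A minimal greedy-decoding sketch, where `next_token_logits` is a hypothetical caller-supplied function wrapping the interpreter:

```python
def greedy_decode(next_token_logits, encoder_ids, start_id, eos_id, max_len=64):
    """Greedy decoding loop for an encoder-decoder model such as MarianMT.

    `next_token_logits(encoder_ids, decoder_ids)` runs the model once and
    returns the logits for the LAST decoder position as a sequence of floats
    over the vocabulary. Returns the generated ids without the start token,
    including eos_id if it was produced.
    """
    decoder_ids = [start_id]
    for _ in range(max_len):
        logits = next_token_logits(encoder_ids, decoder_ids)
        # Argmax over the vocabulary picks the most likely next token.
        next_id = max(range(len(logits)), key=logits.__getitem__)
        decoder_ids.append(next_id)
        if next_id == eos_id:
            break
    return decoder_ids[1:]  # drop the start token
```

With the TFLite interpreter, `next_token_logits` would set the encoder and decoder inputs (padded to any fixed shape), call `invoke()`, read `output_details[0]`, and return row `[0, len(decoder_ids) - 1]`, assuming the logits have shape `[batch, decoder_len, vocab]`. The resulting ids can be turned into text with `tokenizer.decode(ids, skip_special_tokens=True)`.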