Converting a PyTorch BERT model to TFLite

Problem description

I have code for a semantic search engine built with a pre-trained BERT model. I want to convert this model to TFLite so I can deploy it to Google ML Kit, and I would like to know how to do that, or whether it is even possible. It should be, since conversion is described on the official TensorFlow site: https://www.tensorflow.org/lite/convert. But I don't know where to start.

Code:


import scipy.spatial
from sentence_transformers import SentenceTransformer

# Load the BERT model. Various pre-trained models are available, trained on
# Natural Language Inference (NLI): https://github.com/UKPLab/sentence-transformers/blob/master/docs/pretrained-models/nli-models.md
# and on Semantic Textual Similarity (STS): https://github.com/UKPLab/sentence-transformers/blob/master/docs/pretrained-models/sts-models.md

model = SentenceTransformer('bert-base-nli-mean-tokens')

# The corpus is a list of documents split into sentences.

sentences = ['Absence of sanity', 
             'Lack of saneness',
             'A man is eating food.',
             'A man is eating a piece of bread.',
             'The girl is carrying a baby.',
             'A man is riding a horse.',
             'A woman is playing violin.',
             'Two men pushed carts through the woods.',
             'A man is riding a white horse on an enclosed ground.',
             'A monkey is playing drums.',
             'A cheetah is running behind its prey.']

# Each sentence is encoded as a 1-D vector with 768 columns
sentence_embeddings = model.encode(sentences)

print('Sample BERT embedding vector - length', len(sentence_embeddings[0]))

print('Sample BERT embedding vector - note includes negative values', sentence_embeddings[0])

#@title Semantic Search Form

# code adapted from https://github.com/UKPLab/sentence-transformers/blob/master/examples/application_semantic_search.py

query = 'Nobody has sane thoughts' #@param {type: 'string'}

queries = [query]
query_embeddings = model.encode(queries)

# Find the closest 3 sentences of the corpus for each query sentence based on cosine similarity
number_top_matches = 3 #@param {type: "number"}

print("Semantic Search Results")

for query, query_embedding in zip(queries, query_embeddings):
    distances = scipy.spatial.distance.cdist([query_embedding], sentence_embeddings, "cosine")[0]

    results = zip(range(len(distances)), distances)
    results = sorted(results, key=lambda x: x[1])

    print("\n\n======================\n\n")
    print("Query:", query)
    print("\nTop 5 most similar sentences in corpus:")

    for idx, distance in results[0:number_top_matches]:
        print(sentences[idx].strip(), "(Cosine Score: %.4f)" % (1-distance))
1 Answer

First, you need the model in TensorFlow; the package you are using is written in PyTorch. Huggingface's Transformers library provides TensorFlow models that you can start from. They also have TFLite-ready models for Android.
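
For example, here is a minimal sketch, assuming the Transformers package is installed, of loading the TensorFlow version of BERT and mean-pooling its token embeddings the way 'bert-base-nli-mean-tokens' does; the 'bert-base-uncased' checkpoint and the manual pooling are illustrative assumptions rather than the exact sentence-transformers model:

from transformers import AutoTokenizer, TFAutoModel
import tensorflow as tf

# Load the TensorFlow weights of BERT (illustrative checkpoint, not the NLI-tuned one)
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
bert = TFAutoModel.from_pretrained('bert-base-uncased')

# Tokenize one sentence and run it through the model
inputs = tokenizer('Absence of sanity', return_tensors='tf')
outputs = bert(inputs)

# Mean-pool the token embeddings over non-padding tokens to get one sentence vector
token_embeddings = outputs.last_hidden_state             # shape (1, seq_len, 768)
mask = tf.cast(inputs['attention_mask'], tf.float32)     # shape (1, seq_len)
sentence_embedding = tf.reduce_sum(
    token_embeddings * tf.expand_dims(mask, -1), axis=1
) / tf.reduce_sum(mask, axis=1, keepdims=True)           # shape (1, 768)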

Generally, you start with a TensorFlow model and save it in the SavedModel format:

import tensorflow as tf

tf.saved_model.save(pretrained_model, "/tmp/pretrained-bert/1/")

You can then run the TFLite converter on that SavedModel.
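
Here is a minimal sketch of that step, assuming the SavedModel was written to the '/tmp/pretrained-bert/1/' directory used above; enabling SELECT_TF_OPS is a common workaround because BERT graphs often contain ops that are outside the builtin TFLite op set, and the quantization line is optional:

import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model('/tmp/pretrained-bert/1/')

# Allow TensorFlow ops that have no builtin TFLite kernel (common for BERT)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]

# Optional: dynamic-range quantization to shrink the model for mobile
converter.optimizations = [tf.lite.Optimize.DEFAULT]

tflite_model = converter.convert()
with open('bert.tflite', 'wb') as f:
    f.write(tflite_model)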
