我只想使用 SentenceTransformer 中预先训练的开源嵌入模型来编码纯文本。
目标是使用 swagger 作为 GUI - 放入句子并获得嵌入。
from fastapi import Depends, FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
embedding_model = SentenceTransformer("./assets/BAAI/bge-small-en")
app = FastAPI()
class EmbeddingRequest(BaseModel):
text: str
class EmbeddingResponse(BaseModel):
embeddings: float
@app.post("/embeddings", response_model=EmbeddingResponse)
async def get_embeddings(request: EmbeddingRequest, model: embedding_model):
embeddings_result = model.encode(request.text)
return EmbeddingResponse(embeddings=embeddings_result)
请注意,我无法访问您的
"./assets/BAAI/bge-small-en
模型,因此我使用 all-mpnet-base-v2
代替。
也就是说,您的实施存在两个问题:
model
作为输入参数,这是不必要的。直接使用全局embedding_model
就可以了。embeddings
的返回类型是错误的。 (除非您的模型确实输出单个float
)。 embedding_model.encode
的输出是 np.ndarray
,您可以使用 embeddings_result.tolist()
将其转换为列表。以下对我有用:
from typing import List
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
embedding_model = SentenceTransformer("all-mpnet-base-v2")
app = FastAPI()
class EmbeddingRequest(BaseModel):
text: str
class EmbeddingResponse(BaseModel):
embeddings: List[float]
@app.post("/embeddings", response_model=EmbeddingResponse)
async def get_embeddings(request: EmbeddingRequest):
embeddings_result = embedding_model.encode(request.text)
return EmbeddingResponse(embeddings=embeddings_result.tolist())