Getting source documents and a score with ConversationalRetrievalChain, stuff, and the Chainlit UI


I'm having trouble extracting the source documents and scores from this code. I've tried many things, but I can't retrieve them. The best I've managed is to ask for them in the prompt so the LLM returns them itself, but sometimes it just ignores me or hallucinates (e.g., it gives me source links that appear inside the text). If I change chain_type to map_rerank and make a few adjustments, I can get the source documents from the JSON output (based on the metadata originally saved in my vectordb), but I want to use stuff instead. Is there a way to get source_documents and a score from the answer? Or any of the metadata I originally stored in the vector database?

import chainlit as cl
from chainlit import on_chat_start, on_message
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.chat_models import AzureChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain.vectorstores import FAISS

# Project-local modules (prompt templates and Azure/FAISS settings)
import custom_prompts
import saci_constants


@on_chat_start
def init():
    llm = AzureChatOpenAI(
        deployment_name=saci_constants.AZURE_OPENAI_DEPLOYMENT_NAME,
        model_name=saci_constants.AZURE_OPENAI_MODEL_NAME,
        openai_api_base=saci_constants.AZURE_OPENAI_DEPLOYMENT_ENDPOINT,
        openai_api_version=saci_constants.AZURE_OPENAI_DEPLOYMENT_VERSION,
        openai_api_key=saci_constants.AZURE_OPENAI_API_KEY,
        openai_api_type=saci_constants.AZURE_OPEN_API_TYPE,
        temperature=saci_constants.TEMPERATURE,
        streaming=True,
        callbacks=[StreamingStdOutCallbackHandler()],
    )

    embeddings = OpenAIEmbeddings(
        deployment=saci_constants.AZURE_OPENAI_ADA_EMBEDDING_DEPLOYMENT_NAME,
        model=saci_constants.AZURE_OPENAI_ADA_EMBEDDING_MODEL_NAME,
        openai_api_base=saci_constants.AZURE_OPENAI_DEPLOYMENT_ENDPOINT,
        openai_api_key=saci_constants.AZURE_OPENAI_API_KEY,
        openai_api_type=saci_constants.AZURE_OPEN_API_TYPE,
        chunk_size=saci_constants.AZURE_CHUNK_SIZE,
    )

    faiss_db = FAISS.load_local(
        saci_constants.FAISS_DATABASE_PATH,
        embeddings,
    )

    retriever = faiss_db.as_retriever()

    messages = [SystemMessagePromptTemplate.from_template(custom_prompts.SPARK)]
    messages.append(HumanMessagePromptTemplate.from_template("{question}"))
    spark_prompt = ChatPromptTemplate.from_messages(messages)

    question_generator = LLMChain(
        llm=llm,
        prompt=CONDENSE_QUESTION_PROMPT,
        verbose=True,
    )

    doc_chain = load_qa_with_sources_chain(
        llm,
        chain_type="stuff",
        prompt=spark_prompt,
        verbose=True,
    )

    memory = ConversationBufferMemory(
        llm=llm,
        memory_key="chat_history",
        return_messages=True,
        input_key="question",
        # output_key="answer",
        max_token_limit=1000,
        # k=1,
    )

    conversational_chain = ConversationalRetrievalChain(
        retriever=retriever,
        question_generator=question_generator,
        combine_docs_chain=doc_chain,
        memory=memory,
        rephrase_question=False,
        verbose=True,
        # output_key="answer",
    )

    # Set chain as a user session variable
    cl.user_session.set("conversation_chain", conversational_chain)


@on_message
async def main(message: str):
    chat_history = []

    # Read chain from user session variable
    chain = cl.user_session.get("conversation_chain")

    # Run the chain asynchronously with an async callback handler
    res = await chain.acall(
        {"question": message, "chat_history": chat_history},
        callbacks=[cl.AsyncLangchainCallbackHandler()],
    )

    print("aaaaaaaa", res)

    # Send the answer and the text elements to the UI
    await cl.Message(content=f"ANSWER: {res['answer']}").send()

Right now, the print I get from res looks like this:

aaaaaaaa {'question': 'Do I have to pay OpenAI when doing RAG?', 'chat_history': [HumanMessage(content='Do I have to pay OpenAI when doing RAG?', additional_kwargs={}, example=False), AIMessage(content="I'm not sure, but according to an article on Towards Data Science, setting up RAG can be a significant upfront investment, covering integration, database access, and possibly licensing fees. However, paying OpenAI specifically is not mentioned. Would you like me to look up more information?", additional_kwargs={}, example=False)], 'answer': "I'm not sure, but according to an article on Towards Data Science, setting up RAG can be a significant upfront investment, covering integration, database access, and possibly licensing fees. However, paying OpenAI specifically is not mentioned. Would you like me to look up more information?"}

Tags: python, openai-api, information-retrieval, langchain, llm
1 Answer

I figured it out. I just had to add return_source_documents=True to the ConversationalRetrievalChain:

    conversational_chain = ConversationalRetrievalChain(
        retriever=retriever,
        question_generator=question_generator,
        combine_docs_chain=doc_chain,
        memory=memory,
        rephrase_question=False,
        return_source_documents=True,
        verbose=True,
    )
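
One caveat: with return_source_documents=True the chain returns two output keys (answer and source_documents), and on some LangChain versions the attached ConversationBufferMemory raises "One output key expected" because it no longer knows which output to save. If you hit that, set output_key="answer" on the memory (it is present but commented out in the code above):

    memory = ConversationBufferMemory(
        llm=llm,
        memory_key="chat_history",
        return_messages=True,
        input_key="question",
        output_key="answer",  # tell the memory which of the two outputs to store
        max_token_limit=1000,
    )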

Then I can retrieve all the metadata:

@on_message
async def main(message: str):
    chat_history = []

    # Read chain from user session variable
    chain = cl.user_session.get("conversation_chain")

    # Run the chain asynchronously
    res = await chain.acall({"question": message, "chat_history": chat_history})
    print(res) # res has all the metadata if you need (chunk_id, document_id, content, source, etc.)

    sources = [doc.metadata.get("source") for doc in res["source_documents"]]
    chunk = [doc.metadata.get("chunk_id") for doc in res["source_documents"]]

    # Send the answer and the text elements to the UI
    await cl.Message(
        content=f"ANSWER: {res['answer']}, SOURCES: {set(sources)}, CHUNK: {set(chunk)}"
    ).send()
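
If you also want the sources to show up in the Chainlit UI as separate elements rather than inlined in the answer string, one option is to attach them as text elements. A minimal sketch, assuming Chainlit's cl.Text element API; the element names and display mode here are just illustrative choices:

    # Sketch: attach each source document as an inline Chainlit text element.
    elements = [
        cl.Text(
            name=str(doc.metadata.get("source", f"source {i}")),
            content=doc.page_content,
            display="inline",
        )
        for i, doc in enumerate(res["source_documents"])
    ]
    await cl.Message(content=f"ANSWER: {res['answer']}", elements=elements).send()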

The score is another story. I guess that with map_rerank it's something the LLM itself produces, its own judgment, so sometimes it simply doesn't answer; it's not an objective measure. I thought it was based on the similarity score from the retrieval step, but apparently it isn't.
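
If what you are after is the retrieval-time similarity score, you can get it from the vector store directly rather than from the chain. A minimal sketch, assuming faiss_db is accessible in the message handler (e.g., stored in the user session like the chain); note that LangChain's FAISS wrapper returns a raw distance, so lower means more similar:

    # Sketch: query the FAISS store directly for (document, score) pairs.
    # The score here is a FAISS distance (lower = more similar), not an
    # LLM-generated rating like map_rerank's.
    docs_and_scores = faiss_db.similarity_search_with_score(message, k=4)
    for doc, score in docs_and_scores:
        print(doc.metadata.get("source"), score)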
