从API为Langchain中的QA链定制LLM

问题描述 投票:0回答:1

目前,我想构建用于生产的 RAG 聊天机器人。 我已经有了 LLM API,我想创建一个自定义 LLM,然后在 RetrievalQA.from_chain_type 函数中使用它。 我不知道Langchain是否支持我的情况。

我在 reddit 上读到了这个主题:https://www.reddit.com/r/LangChain/comments/17v1rhv/integrating_llm_rest_api_into_a_langchain/ 在langchain文档中:https://python.langchain.com/docs/modules/model_io/llms/custom_llm

但是当我将自定义LLM应用到qa_chain时,这仍然不起作用。 以下是我的代码,希望得到您的支持,对不起我的语言,英语不是我的母语。

from pydantic import Extra
import requests
from typing import Any, List, Mapping, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM

class LlamaLLM(LLM):
    llm_url = 'https:/myhost/llama/api'

    class Config:
        extra = Extra.forbid

    @property
    def _llm_type(self) -> str:
        return "Llama2 7B"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")

        payload = {
            "inputs": prompt,
            "parameters": {"max_new_tokens": 100},
            "token": "abcdfejkwehr"
        }

        headers = {"Content-Type": "application/json"}

        response = requests.post(self.llm_url, json=payload, headers=headers, verify=False)
        response.raise_for_status()

        # print("API Response:", response.json())

        return response.json()['generated_text']  # get the response from the API

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"llmUrl": self.llm_url}

llm = LlamaLLM()
#Testing
prompt = "[INST] Question: Who is Albert Einstein? \n Answer: [/INST]"
result = llm._call(prompt)
print(result)

Albert Einstein (1879-1955) was a German-born theoretical physicist who is widely regarded as one of the most influential scientists of the 20th century. He is best known for his theory of relativity, which revolutionized our understanding of space and time, and his famous equation E=mc².
# Build prompt
from langchain.prompts import PromptTemplate
template = """[INST] <<SYS>>

Answer the question base on the context below.

<</SYS>>

Context: {context}
Question: {question}
Answer:
[/INST]"""
QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context", "question"],template=template,)

# Run chain
from langchain.chains import RetrievalQA

qa_chain = RetrievalQA.from_chain_type(llm,
                                       verbose=True,
                                       # retriever=vectordb.as_retriever(),
                                       retriever=custom_retriever,
                                       return_source_documents=True,
                                       chain_type_kwargs={"prompt": QA_CHAIN_PROMPT})
question = "Is probability a class topic?"
result = qa_chain({"query": question})
result["result"]

Encountered some errors. Please recheck your request!

这是来自 dosubot github 的反馈,但我不太明白

首先,在 LlamaLLM 类中,_llm_type 属性应返回一个与 get_type_to_cls_dict 函数中自定义 LLM 名称匹配的字符串。在你的情况下,它应该是“LlamaLLM”而不是“Llama2 7B”。这是因为 get_type_to_cls_dict 函数使用此字符串导入正确的 LLM。

其次,您需要将自定义 LLM 添加到 get_type_to_cls_dict 函数中。 RetrievalQA.from_chain_type 函数使用此函数导入正确的 LLM。具体方法如下:

def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
    return {
        ...
        "LlamaLLM": _import_llama_llm,
        ...
    }

在上面的代码中,_import_llama_llm 是一个导入 LlamaLLM 类并返回它的实例的函数。您需要创建此函数并将其添加到定义 get_type_to_cls_dict 函数的同一文件中。

最后,当您调用 RetrievalQA.from_chain_type 函数时,您应该传递“LlamaLLM”作为 chain_type 参数。这是因为 from_chain_type 函数使用此字符串通过 get_type_to_cls_dict 函数导入正确的 LLM。

qa_chain = RetrievalQA.from_chain_type("LlamaLLM",
                                       verbose=True,
                                       # retriever=vectordb.as_retriever(),
                                       retriever=custom_retriever,
                                       return_source_documents=True,
                                       chain_type_kwargs={"prompt": QA_CHAIN_PROMPT})
python python-requests chatbot langchain large-language-model
1个回答
0
投票

我使用你的代码,没有遇到任何问题。也许 custom_retriever 失败了。如果您可以提供更多相关信息吗?

© www.soinside.com 2019 - 2024. All rights reserved.