This is my first attempt at a RAG application. I'm trying to do question answering with an LLM, and I'll paste my working code below. My problem is that the embedding-generation code runs every time I run the Python script. Is there a way to run it only once, or to check whether the embeddings folder is empty before running that code?
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import Chroma
from langchain_community import embeddings
from langchain_community.chat_models import ChatOllama
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.embeddings import OllamaEmbeddings
model_local = ChatOllama(model="codellama:7b")
loader = TextLoader("remedy.txt")
raw_doc = loader.load()
# Split the text file content into chunks
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
splitted_docs = text_splitter.split_documents(raw_doc)
# Use embedding function to store them in vector db
ollamaEmbeddings = OllamaEmbeddings(model="nomic-embed-text")
# used chroma vector db to store the data
vectorstore = Chroma.from_documents(
    documents=splitted_docs,
    embedding=ollamaEmbeddings,
    persist_directory="./vector/my_data",
)
# This will write the data to local
retriever = vectorstore.as_retriever()
# 4. After RAG
print("After RAG\n")
after_rag_template = """
Answer the question based only on the following context:
{context}
Question: {question}
"""
after_rag_prompt = ChatPromptTemplate.from_template(after_rag_template)
after_rag_chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | after_rag_prompt
    | model_local
    | StrOutputParser()
)
print(after_rag_chain.invoke("What are Home Remedy for Common Cold?"))
Every time you run this Python script you supply a persist directory, which stores the embeddings in the specified directory on disk. You then pass the same chunked documents and define the same embedding model, so the vector database effectively repeats the same work on every run.
When you want to load the persisted database from disk instead of rebuilding it, instantiate the Chroma object with the persist directory and the embedding model, like this:
# load from disk
db3 = Chroma(persist_directory="./vector/my_data", embedding_function=ollamaEmbeddings)
query = "What are Home Remedy for Common Cold?"
docs = db3.similarity_search(query)
print(docs[0].page_content)
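To run the indexing step only once, you can guard it with a check on the persist directory before deciding whether to build or load. A minimal sketch — the helper name `needs_indexing` is mine, and the commented branches assume the variables (`splitted_docs`, `ollamaEmbeddings`) from the script above:

```python
import os


def needs_indexing(persist_directory: str) -> bool:
    # True when the directory is missing or contains no files,
    # i.e. the embeddings have not been persisted to disk yet.
    return not os.path.isdir(persist_directory) or not os.listdir(persist_directory)


# Hypothetical usage with the variables from the question's script:
#
# if needs_indexing("./vector/my_data"):
#     # First run: embed the chunks and persist them.
#     vectorstore = Chroma.from_documents(
#         documents=splitted_docs,
#         embedding=ollamaEmbeddings,
#         persist_directory="./vector/my_data",
#     )
# else:
#     # Later runs: load the persisted embeddings without re-embedding.
#     vectorstore = Chroma(
#         persist_directory="./vector/my_data",
#         embedding_function=ollamaEmbeddings,
#     )
```

Note this only checks that the directory is non-empty; if you change the source text file or the chunking parameters, you would need to delete the directory (or version it) so the index is rebuilt.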