问题陈述
我希望为我的 LLM 创建一个具有隔离用户会话的 FastAPI 端点,该端点使用 ConversationBufferMemory。该内存将作为人工智能和用户之间对话的上下文。目前,它已与AI和所有用户共享。我希望隔离每个用户的内存。
我有下面
Langchain
核心库的基本实现。
样板代码
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.chains import ConversationChain
memory = ConversationBufferMemory(memory_key="chat_history", k=12)
async def interview_function(input_text):
prompt = PromptTemplate(
input_variables=["chat_history", "input"], template=interview_template)
chat_model = ChatOpenAI(model_name="gpt-4-1106-preview", temperature = 0,
openai_api_key = OPENAI_API_KEY, max_tokens=1000)
llm_chain = ConversationChain(
llm=chat_model,
prompt=prompt,
verbose=True,
memory=memory,
)
return llm_chain.predict(input=input_text)
我通过子类化 ConversationChain 取得了进展,目的是从单独的数据存储(例如 SQL 表)传递与用户的唯一 id 相关的自定义内存键,我用它来引用与我的 LLM 交互的各个用户。
子类化进度
def create_extended_conversation_chain(keys: List[str]):
class ExtendedConversationChain(ConversationChain):
input_key: List[str] = Field(keys)
@property
def input_keys(self) -> List[str]:
"""Override the input_keys property to return the new input_key list."""
return self.input_key
@root_validator(allow_reuse=True)
def validate_prompt_input_variables(cls, values: Dict) -> Dict:
"""Validate that prompt input variables are consistent."""
memory_keys = values["memory"].memory_variables
input_key = values["input_key"]
prompt_variables = values["prompt"].input_variables
expected_keys = memory_keys + input_key
if set(expected_keys) != set(prompt_variables):
raise ValueError(
"Got unexpected prompt input variables. The prompt expects "
f"{prompt_variables}, but got {memory_keys} as inputs from "
f"memory, and {input_key} as the normal input keys."
)
return values
return ExtendedConversationChain
但是,我一直在创建这个自定义记忆键。 我的内存键在实例化时定义后似乎无法访问,就像我在样板代码部分中所做的那样。
有 Langchain 特定的解决方案吗?还是我需要创建自己的缓存并让我的 LLM 与其交互?
从技术上讲,字典暂时解决了我注意到的问题。
user_memory_dict = {}
async def interview_function(input_text: str, user_id: int):
if user_id not in user_memory_dict:
user_memory_dict[user_id] = ConversationBufferMemory(
memory_key="chat_history", k=12
)
memory = user_memory_dict[user_id]
prompt = PromptTemplate(
input_variables=["chat_history", "input"], template=interview_template
)
chat_model = ChatOpenAI(model_name="gpt-4-1106-preview", temperature=0,
openai_api_key=OPENAI_API_KEY, max_tokens=1000)
llm_chain = ConversationChain(
llm=chat_model,
prompt=prompt,
verbose=True,
memory=memory,
)
return llm_chain.predict(input=input_text)
出于可扩展性的目的,如果不存储对话,则此“字典”将被外部缓存存储取代,例如 Redis 或基于会话的存储。