QA over a large SQL database using langchain and LLaMA2

Question (votes: 0, answers: 1)

I am working with a sqlite database created from a large dataset (about 150k rows).

Code snippet:

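# Import paths vary across LangChain versions; these match recent releases.
from langchain_community.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
from langchain_core.prompts import PromptTemplate

# local_llm (the LLaMA 2 model wrapper) is assumed to be defined elsewhere.
db = SQLDatabase.from_uri("sqlite:///MLdata.sqlite")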
SQLITE_PROMPT_TEXT = '''You are a SQLite expert. Given an input question, first create a 
syntactically correct SQLite query to run, then look at the results of the query and return 
the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for 
at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to 
 return the most informative data in the database.
 Never query for all columns from a table. You must query only the columns that are needed to 
 answer the question. Wrap each column name in double quotes (") to denote them as delimited 
  identifiers.
 Pay attention to use only the column names you can see in the tables below. Be careful to not 
 query for columns that do not exist. Also, pay attention to which column is in which table.

 Use the following format:

  Question: Question here
  SQLQuery: SQL Query to run
  SQLResult: Result of the SQLQuery
  Answer: Final answer here

 Only use the following tables:
 {table_info}

 Question: {input}'''

SQLITE_PROMPT = PromptTemplate(input_variables=['input', 'table_info', 'top_k'], template=SQLITE_PROMPT_TEXT)
sql_chain = SQLDatabaseChain(llm=local_llm, database=db, prompt=SQLITE_PROMPT, return_direct=False, return_intermediate_steps=False, verbose=False)

res=sql_chain("How many rows is in this db?")

Response: "There are 142321 rows in the input_table of this database."

Second query:

res=sql_chain("Count rows with 'Abdominal pain', VAX_TYPE='COVID19', SEX= 'F' and HOSPITAL= 'Y' is in the input_table of this db")

Response: "There are 115 rows in the input_table with abdominal pain, where VAX_TYPE is COVID19, sex is female, and hospital is yes."

In the third query I tried to retrieve just the patient IDs rather than a count, but I could not get the patient IDs back.

res=sql_chain("What is the VAERS_ID with 'Abdominal pain', VAX_TYPE='COVID19', SEX= 'F' and HOSPITAL= 'Y' in this db. ")

Counting seems to work fine, but nothing beyond that does. Can anyone help me get table-like output from SQLDatabaseChain with langchain and llama2?
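For comparison, the two query shapes can be run against the database directly, bypassing the LLM. This is a minimal sketch using only the standard-library sqlite3 module and an in-memory stand-in for input_table. The "SYMPTOM" column name, the demo rows, and the TOP_K value are assumptions for illustration and are not taken from the real schema. It shows that a COUNT returns one number regardless of size, while a listing query is capped by the LIMIT that the prompt above asks the model to apply via {top_k}.

```python
import sqlite3

TOP_K = 5  # illustrative cap standing in for the prompt's {top_k}

# Filter mirroring the question; "SYMPTOM" is a hypothetical column name.
FILTER = '"SYMPTOM" = ? AND "VAX_TYPE" = ? AND "SEX" = ? AND "HOSPITAL" = ?'
ARGS = ("Abdominal pain", "COVID19", "F", "Y")


def build_demo_db():
    """Create an in-memory table with made-up rows (not real VAERS data)."""
    conn = sqlite3.connect(":memory:")
    conn.execute(
        'CREATE TABLE input_table ("VAERS_ID" INTEGER, "SYMPTOM" TEXT, '
        '"VAX_TYPE" TEXT, "SEX" TEXT, "HOSPITAL" TEXT)'
    )
    rows = [(1000 + i, "Abdominal pain", "COVID19", "F", "Y") for i in range(8)]
    rows.append((2000, "Headache", "COVID19", "M", "N"))
    conn.executemany("INSERT INTO input_table VALUES (?, ?, ?, ?, ?)", rows)
    return conn


def count_matches(conn):
    """Aggregate query: always a single value, however many rows match."""
    sql = f"SELECT COUNT(*) FROM input_table WHERE {FILTER}"
    return conn.execute(sql, ARGS).fetchone()[0]


def list_ids(conn, limit=None):
    """Listing query: returns matching IDs, optionally capped by LIMIT."""
    sql = f'SELECT "VAERS_ID" FROM input_table WHERE {FILTER} ORDER BY "VAERS_ID"'
    params = ARGS
    if limit is not None:
        sql += " LIMIT ?"
        params = ARGS + (limit,)
    return [row[0] for row in conn.execute(sql, params)]


conn = build_demo_db()
print("count:", count_matches(conn))
print("ids with LIMIT:", list_ids(conn, TOP_K))
print("ids without LIMIT:", list_ids(conn))
```

If the direct query returns the IDs you expect, the gap is likely in how the model writes or summarizes the query rather than in the data itself.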

nlp langchain large-language-model llama
1 answer
0 votes

You can simply ask it to produce that output.

For example:

sql_chain("Display all the users table in markdown")

# output
# {'query': 'Display all the users table in markdown',
# 'result': '| id | name |\n|----|------|\n| 1  | 2    |\n| 2  | 3    |'}
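If the model's markdown turns out to be unreliable for larger results, an alternative is to format the rows yourself. The following is a rough sketch, not part of LangChain: rows_to_markdown is a hypothetical helper, and the users table and its two rows are made-up demo data. It renders any sqlite3 cursor as a markdown table using the column names the cursor reports.

```python
import sqlite3


def rows_to_markdown(cursor):
    """Render the rows of an executed SELECT cursor as a markdown table."""
    headers = [col[0] for col in cursor.description]
    lines = [
        "| " + " | ".join(headers) + " |",
        "| " + " | ".join("---" for _ in headers) + " |",
    ]
    for row in cursor:
        lines.append("| " + " | ".join(str(value) for value in row) + " |")
    return "\n".join(lines)


# Demo data only; swap in a connection to your own database file.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER, name TEXT)")
conn.executemany("INSERT INTO users VALUES (?, ?)", [(1, "alice"), (2, "bob")])

table = rows_to_markdown(
    conn.execute('SELECT "id", "name" FROM users ORDER BY "id"')
)
print(table)
```

Because the formatting happens in plain Python, every row returned by the query appears in the table, without depending on the model to copy it faithfully.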