问题描述
在使用 LangChain 和 Llama 模型生成 SQL 查询时,遇到了 sqlite3.OperationalError
错误。错误信息如下:
OperationalError: (sqlite3.OperationalError) near "```sql
SELECT Name
FROM MediaType
LIMIT 5;
```": syntax error
[SQL: ```sql
SELECT Name
FROM MediaType
LIMIT 5;
```]
错误发生的原因是生成的 SQL 查询包含了不必要的 Markdown 代码块标记 ```,也就是在生成SQL语句的过程中,产生了其他的不干净文本,导致 SQL 语法错误。
最终解决方案
通过修改 PromptTemplate 来生成干净的 SQL 查询,确保生成的查询不包含任何 Markdown 代码块标记或附加评论。以下是解决方案的详细步骤和代码实现:
1. 初始化环境
首先,初始化所需的环境变量和模型:
import getpass
import os
from langchain.chat_models import init_chat_model
from langchain_core.prompts import PromptTemplate
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
# 如果没有设置 GROQ_API_KEY,则提示用户输入
if not os.environ.get("GROQ_API_KEY"):
os.environ["GROQ_API_KEY"] = getpass.getpass("Enter API key for Groq: ")
# 初始化 Llama 模型,使用 Groq 后端
llm = init_chat_model("llama-3.3-70b-versatile", model_provider="groq", temperature=0)
2. 定义自定义提示模板
定义一个自定义的 PromptTemplate,用于生成干净的 SQL 查询:
custom_prompt = PromptTemplate(
input_variables=["dialect", "input", "table_info", "top_k"],
template="""You are a SQL expert using {dialect}.
Given the following table schema:
{table_info}
Generate a syntactically correct SQL query to answer the question: "{input}".
Limit the results to at most {top_k} rows.
Return only the SQL query without any additional commentary or Markdown formatting.
"""
)
3. 创建 SQL 查询链
创建一个 SQL 查询链,并使用自定义提示模板:
write_query = create_sql_query_chain(llm, db, prompt=custom_prompt)
4. 构造输入数据字典
构造输入数据字典,其中包含方言、表结构、问题和行数限制:
input_data = {
"dialect": db.dialect, # 数据库方言,如 "sqlite"
"table_info": db.get_table_info(), # 表结构信息
"input": "What name of MediaType is?", # 问题
"top_k": 5 # 行数限制
}
5. 调用链生成并执行 SQL 查询
调用链生成 SQL 查询,确保生成的查询不包含 Markdown 代码块标记,然后执行查询并打印结果:
response = write_query.invoke(input_data)
query = response["query"]
# 执行 SQL 查询并打印结果
execute_query = QuerySQLDataBaseTool(db=db)
result = execute_query.invoke({"query": query})
print(result)
总结
通过修改 PromptTemplate 来生成 SQL 查询时,明确要求返回的 SQL 查询不包含任何附加评论或 Markdown 格式,确保生成的 SQL 查询是干净的、可执行的。这样可以避免由多余的标记导致的 SQL 语法错误。
最后提供完整代码:
import getpass
import os
from langchain.chat_models import init_chat_model
from langchain_core.prompts import PromptTemplate
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from dotenv import load_dotenv
from pyprojroot import here
from langchain.chains import create_sql_query_chain
from langchain_community.agent_toolkits import create_sql_agent
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
load_dotenv()
# 如果没有设置 GROQ_API_KEY,则提示用户输入
if not os.environ.get("GROQ_API_KEY"):
os.environ["GROQ_API_KEY"] = getpass.getpass("Enter API key for Groq: ")
sqldb_directory = here("data/Chinook.db")
db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
table_info = db.get_table_info(["Album"]) # 注意需要传递列表
print(f"\n Original table info: {table_info}")
# 初始化 Llama 模型,使用 Groq 后端
llm = init_chat_model("llama-3.3-70b-specdec", model_provider="groq", temperature=0)
# 定义自定义提示模板,用于生成 SQL 查询
custom_prompt = PromptTemplate(
input_variables=["dialect", "input", "table_info", "top_k"],
template="""You are a SQL expert using {dialect}.
Given the following table schema:
{table_info}
Generate a syntactically correct SQL query to answer the question: "{input}".
Limit the results to at most {top_k} rows.
Return only the SQL query without any additional commentary or Markdown formatting.
"""
)
write_query = create_sql_query_chain(llm, db,prompt=custom_prompt)
# 构造输入数据字典,其中包含方言、表结构、问题和行数限制
input_data = {
"dialect": db.dialect, # 数据库方言,如 "sqlite"
"table_info": db.get_table_info(), # 表结构信息
"question": "What name of MediaType is?",
"top_k": 5
}
# 调用链生成 SQL 查询,返回结果为一个字典,包含键 "query"
write_query_response = write_query.invoke(input_data)
print('\n write_query result:',write_query_response)
#执行SQL语句
execute_query = QuerySQLDataBaseTool(db=db)
execute_response = execute_query.invoke(write_query_response)
print('\n execute_response result:',execute_response)
#两个动作合起来搞成链
chain = write_query | execute_query
result_chain = chain.invoke(input_data)
print('\n result_chain==',result_chain)
输出: