1. 开启调试模式
from langchain import debug
debug = True # 启用调试模式
说明:
这里从 langchain
库中导入了一个名为 debug
的变量(或模块),然后将它设置为 True
。这通常用来启用调试模式,方便开发者在程序运行时看到更多内部日志和详细信息,便于排查问题。
举例:
如果你的程序出现异常或输出结果不符合预期,启用调试模式可以在控制台中打印出详细的调试日志,帮助你定位问题所在。
2. 导入其他必需模块
import getpass
import os
from langchain.chat_models import init_chat_model
# from langchain_core.prompts import PromptTemplate
# from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from dotenv import load_dotenv
from pyprojroot import here
from langchain.chains import create_sql_query_chain
from langchain_community.utilities import SQLDatabase
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
说明:
getpass
:用于安全地从用户处获取密码或敏感信息,不会在屏幕上回显输入内容。os
:用于操作系统相关的功能,例如环境变量读取和设置。init_chat_model
:从langchain.chat_models
导入,用来初始化聊天模型。load_dotenv
:用于加载项目根目录下的.env
文件,自动设置环境变量。here
:来自pyprojroot
库,用来确定项目根目录的路径。create_sql_query_chain
:用于根据自然语言问题生成SQL查询的链(Chain),注意这个工具用不同的模型来生成SQL语句时可能会返回不干净的SQL语句。SQLDatabase
:用于创建数据库对象,方便后续查询。StrOutputParser
:将模型输出解析成字符串。ChatPromptTemplate
:用于创建聊天风格的提示模板。
举例:
例如,通过 load_dotenv()
加载环境变量,你可以将 API 密钥放在 .env
文件中,避免硬编码在代码中。
3. 加载环境变量
load_dotenv()
说明:
这行代码会从当前目录(或项目根目录)下加载一个名为 .env
的文件,并将里面的变量设置为环境变量。这样我们就可以在代码中通过 os.environ
来访问这些变量。
举例:
如果你有一个 .env
文件,其中包含 GROQ_API_KEY=your_key_here
,调用 load_dotenv()
后,你可以直接通过 os.environ.get("GROQ_API_KEY")
获取这个值。
4. 获取 Groq API Key
# 如果没有设置 GROQ_API_KEY,则提示用户输入
if not os.environ.get("GROQ_API_KEY"):
os.environ["GROQ_API_KEY"] = getpass.getpass("Enter API key for Groq: ")
说明:
这段代码判断环境变量中是否存在 GROQ_API_KEY
。如果没有找到,就调用 getpass.getpass()
来提示用户输入API密钥,并将输入的值设置到环境变量中。
举例:
当你第一次运行程序时,如果未在 .env
中设置 GROQ_API_KEY
,程序会暂停并提示“Enter API key for Groq:”,输入后就会将该值存入 os.environ
中,后续代码就能使用这个密钥了。
5. 定位并加载 SQLite 数据库
sqldb_directory = here("data/Chinook.db")
db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
table_info = db.get_table_info(["Album"]) # 注意需要传递列表
print(f"\n Original table info: {table_info}")
输出:
说明:
here("data/Chinook.db")
:利用pyprojroot.here
方法定位项目中data
目录下的Chinook.db
文件,返回其绝对路径。SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
:通过传入 SQLite 的 URI 构建一个 SQLDatabase 对象,用于后续的查询操作。db.get_table_info(["Album"])
:获取数据库中Album
表的详细信息,这里传入的是列表格式,因为该方法可能支持多个表一次查询。print(...)
:将获取到的表信息打印出来,便于调试和确认数据加载成功。
举例:
假设 Chinook.db
是一个包含音乐相关数据的数据库,Album
表记录了专辑信息。打印出来的信息可能包含专辑名称、发行年份等字段,方便后续构造查询问题。
6. 初始化 Llama 模型(使用 Groq 后端)
llm = init_chat_model("llama-3.3-70b-specdec", model_provider="groq", temperature=0)
说明:
这行代码调用 init_chat_model
初始化一个聊天模型。
- 参数
"llama-3.3-70b-specdec"
指定了模型名称,代表一种 Llama 模型的特定版本。 model_provider="groq"
表明该模型运行在 Groq 的后端上。temperature=0
表示模型的输出是确定性的(温度为0时,生成的内容不会有随机性),适合需要精确答案的场景。
举例:
当你向该模型提出问题时,由于温度为0,模型每次都会给出相同的回答,非常适合生成SQL查询这类需要准确答案的任务。
7. 创建生成SQL查询的链(Chain)
write_chain = create_sql_query_chain(llm, db)
response = write_chain.invoke({"question": "What name of MediaType is?"})
print(response,'\n')
说明:
create_sql_query_chain(llm, db)
:利用前面初始化的 LLM 和 数据库对象 创建一个“SQL查询链”。这个链负责将自然语言问题转换为 SQL 查询语句。write_chain.invoke({"question": "What name of MediaType is?"})
:调用该链,将问题(自然语言形式)传递进去,链内部会调用 LLM 根据数据库结构生成对应的 SQL 查询。- 打印返回的
response
,查看生成的SQL查询语句。
举例:
假设 MediaType
是数据库中的一个表或字段,该问题 可能 转换成类似 SELECT Name FROM MediaType
的 SQL 查询语句。程序将输出这个生成的查询。注意:这里用的是可能,为啥?因为有些模型会输出其他解释和重复问题的内容,这样的内容是不能直接执行的,必须得是干净的SQL语句,不然就报语法错误。
比如可能生成这种:红色圈出来的就是我们不想要的内容
8. 构建验证链:检查和修正SQL查询的错误
system = """Double check the user's {dialect} query for common mistakes, including:
- Only return SQL Query not anything else like ```sql ... ```
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates\
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins
If there are any of the above mistakes, rewrite the query.
If there are no mistakes, just reproduce the original query with no further commentary.
Output the final SQL query only."""
说明:这里是重点,没有这个提示后面执行 create_sql_query_chain(llm, db)
输出结果会掉进坑里,本人已经掉过,还好,我爬起来了。
这段字符串定义了一个系统级别的提示(system prompt),用于检查和验证生成的 SQL 查询是否存在常见错误。提示内容中列出了几种常见的SQL错误(例如:使用 NOT IN
处理 NULL
、错误地使用 UNION
等),并要求如果发现错误则重写查询,否则直接输出原查询。
举例:
如果生成的SQL语句包含了多余的标记或格式错误,这个提示会要求模型对查询进行调整,确保最终输出的是纯 SQL 查询文本。
9. 构造聊天提示模板
prompt = ChatPromptTemplate.from_messages(
[("system", system), ("human", "{query}")]
).partial(dialect=db.dialect)
说明:
- 这里使用
ChatPromptTemplate.from_messages
创建了一个聊天式的提示模板,其中包含两部分消息:- 系统消息(system):上面定义的用于验证SQL查询的提示。
- 人类消息(human):包含占位符
{query}
,将由用户的问题替换。
.partial(dialect=db.dialect)
将数据库的方言信息(如 SQLite、MySQL 等)传递进去,使得提示中{dialect}
占位符可以被正确替换。
举例:
如果数据库是 SQLite,db.dialect
可能返回 "sqlite"
,那么在提示模板中 {dialect}
就会被替换为 "sqlite"
,确保模型知道当前使用的是哪种SQL语法。
10. 构建验证链
validation_chain = prompt | llm | StrOutputParser()
说明:
这行代码使用管道(|
)运算符将几个组件组合成一个验证链:
- 先通过
prompt
生成提示, - 再由 LLM 生成回复,
- 最后使用
StrOutputParser()
将回复解析为纯文本字符串。
整个链的作用就是对原生成的SQL查询进行验证和可能的重写。
举例:
例如,如果原SQL查询为 SELECT * FROM MediaType
,但存在语法问题,这个验证链可能会返回修改后的正确查询,如 SELECT Name FROM MediaType
。
11. 将生成SQL查询链和验证链组合成完整的链
full_chain = {"query": write_chain} | validation_chain
query = full_chain.invoke(
{
"question": "What name of MediaType is?"
}
)
print(query)
输出结果:
说明:
{"query": write_chain}
:这里将之前的生成SQL查询链(write_chain)包装成一个字典,其中键为"query"
,表示该链负责处理查询。- 然后使用管道运算符
|
将它与validation_chain
组合成一个完整的链。 - 最后,调用
full_chain.invoke(...)
,传入问题,整个链首先生成SQL查询,然后对生成的查询进行验证(如有错误则重写),最终输出经过验证的SQL语句。 - 打印最终的SQL查询。
举例:
整个过程:用户提问“MediaType的名称是什么?”
→ write_chain
根据问题和数据库信息生成SQL查询语句
→ validation_chain
检查查询语句是否存在语法或逻辑错误,并进行修正
→ 最终输出一个正确的SQL查询字符串,如 SELECT Name FROM MediaType;
。
12. 执行生成的SQL查询
db.run(query)
输出结果:
说明:
使用数据库对象 db
执行最终生成的SQL查询。
举例:
如果查询返回 SELECT Name FROM MediaType;
,则 db.run(query)
会连接到 Chinook.db
数据库,并执行该查询,返回结果。
总结
整段代码实现了以下工作流程:
- 初始化和配置:启用调试模式,加载环境变量和API密钥,定位数据库文件。
- 数据库连接:使用 SQLite 数据库
Chinook.db
,并获取表信息进行验证。 - 模型初始化:初始化一个基于 Llama 模型的聊天模型,并指定使用 Groq 后端。
- SQL查询链生成:利用自然语言问题生成SQL查询语句(例如“MediaType的名称是什么?”)。
- 验证与修正:对生成的SQL查询进行错误检查,修正常见问题,确保最终输出正确的SQL查询。
- 执行查询:将验证后的SQL查询发送给数据库执行,并返回查询结果。
实际应用场景举例:
如果你正在开发一个数据库管理工具,用户可以通过自然语言提问(如“列出所有专辑的名称”),系统将自动生成相应的SQL查询并返回结果。通过这种方式,非专业人士也可以通过简单的语言操作数据库,而不必掌握SQL语法。
最后给出完整代码:
import getpass
import os
from langchain.chat_models import init_chat_model
from langchain_core.prompts import PromptTemplate
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from dotenv import load_dotenv
from pyprojroot import here
from langchain.chains import create_sql_query_chain
from langchain_community.agent_toolkits import create_sql_agent
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
load_dotenv()
# 如果没有设置 GROQ_API_KEY,则提示用户输入
if not os.environ.get("GROQ_API_KEY"):
os.environ["GROQ_API_KEY"] = getpass.getpass("Enter API key for Groq: ")
sqldb_directory = here("data/Chinook.db")
db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
table_info = db.get_table_info(["Album"]) # 注意需要传递列表
print(f"\n Original table info: {table_info}")
# 初始化 Llama 模型,使用 Groq 后端
llm = init_chat_model("llama-3.3-70b-specdec", model_provider="groq", temperature=0)
# 定义自定义提示模板,用于生成 SQL 查询
custom_prompt = PromptTemplate(
input_variables=["dialect", "input", "table_info", "top_k"],
template="""You are a SQL expert using {dialect}.
Given the following table schema:
{table_info}
Generate a syntactically correct SQL query to answer the question: "{input}".
Limit the results to at most {top_k} rows.
Return only the SQL query without any additional commentary or Markdown formatting.
"""
)
write_query = create_sql_query_chain(llm, db,prompt=custom_prompt)
# 构造输入数据字典,其中包含方言、表结构、问题和行数限制
input_data = {
"dialect": db.dialect, # 数据库方言,如 "sqlite"
"table_info": db.get_table_info(), # 表结构信息
"question": "What name of MediaType is?",
"top_k": 5
}
# 调用链生成 SQL 查询,返回结果为一个字典,包含键 "query"
write_query_response = write_query.invoke(input_data)
print('\n write_query result:',write_query_response)
#执行SQL语句
execute_query = QuerySQLDataBaseTool(db=db)
execute_response = execute_query.invoke(write_query_response)
print('\n execute_response result:',execute_response)
#两个动作合起来搞成链
chain = write_query | execute_query
result_chain = chain.invoke(input_data)
print('\n result_chain==',result_chain)