LangChain基础知识大全
- 一、部署ChatGLM-6B
- 1.拉取源码
- 2.安装环境
- 3.下载模型
- 4.修改api.py配置
- 5.运行api.py
- 二、Models组件
- 1.LLM(大语言模型)
- 2.Chat Model(聊天模型)
- 3.Embedding Model(嵌入模型)
- 3.1 下载中文文本向量模型
- 3.2 安装包
- 三、Prompts组件
- 1.zore-shot提示
- 2.few-shot提示
- 四、Chains组件
- 1.单chain
- 2.多chain
- 五、Agents组件
- 六、Memory组件
- 1.非持久化存储
- 2.持久化存储
- 3.ChatGLM+Memory
- 七、indexes组件
- 1.文档加载器
- 2.文档分割器
- 2.1 CharacterTextSplitter
- 2.2 其他分割器
- 3.VectorStores
- 3.1 VectorStore汇总
- 3.2 Chroma
- 3.3 FAISS
- 八、ChatGLM+Langchain
- 1.后端逻辑
- 2.添加前端交互
- 3.添加上传文档功能
一、部署ChatGLM-6B
- 后续会使用到
1.拉取源码
git clone https://github.com/THUDM/ChatGLM-6B
cd ChatGLM-6B
2.安装环境
conda create --name chatglm python=3.10
conda activate chatglm
pip install -r requirements.txt
# 安装cuda版本的pytorch(需要查看自己cuda的版本)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
3.下载模型
# 安装modelscope社区的包
pip install modelscope
- 在ChatGLM-6B的git目录运行以下python代码
from modelscope import snapshot_download
model_dir = snapshot_download('ZhipuAI/ChatGLM-6B', cache_dir='.', revision='master')
下载完后,需要将ChatGLM-6B改成chatglm-6b,否则会报错
4.修改api.py配置
- 指定模型的路径
注意:如果CPU/GPU内存不够,那么可以在实例化model的时候,添加
.quantize(4)
支持4和8的量化
5.运行api.py
- 这个是git源码的文件
- python api.py
二、Models组件
- LangChain目前支持三种模型类型:LLMs、ChatModels(聊天模型)、EmbeddingsModels(嵌入模型)。
- LLMs:大语言模型接收文本字符作为输入,返回的也是文本字符。
- 聊天模型:基于LLMs,不同的是它接收聊天消息(一种特定格式的数据)作为输入,返回的也是聊天消息。
- 文本嵌入模型:文本嵌入模型接收文本作为输入,返回的是浮点数列表。
1.LLM(大语言模型)
from langchain_community.llms import ChatGLM
llm = ChatGLM(
endpoint_url="http://127.0.0.1:8000",
max_token=80000,
top_p=0.9
)
response = llm("生成一个关于创业励志的故事")
print(response)
2.Chat Model(聊天模型)
3.Embedding Model(嵌入模型)
3.1 下载中文文本向量模型
- modelscope地址:https://modelscope.cn/models/AI-ModelScope/bge-large-zh
from modelscope import snapshot_download
model_dir = snapshot_download('AI-ModelScope/bge-large-zh', cache_dir='.', revision='master')
3.2 安装包
pip install sentence_transformers
pip install langchain==0.2.15
pip install langchain-community==0.2.15```
### 3.3 示例代码
```python
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
model_name = r"C:\practice\python\langchain\ChatGLM-6B\AI-ModelScope\bge-large-zh"
model_kwargs = {'device': 'cuda:0'}
encode_kwargs = {'normalize_embeddings': True}
model = HuggingFaceBgeEmbeddings(
model_name=model_name,
model_kwargs=model_kwargs,
encode_kwargs=encode_kwargs
)
# 将多段文本向量化,
print(model.embed_documents(["一只乌鸦在喝水"]))
# 返回:[[0.008352028205990791, 0.01422210969030857, -0.010817999951541424, .....], .....]
# 将一个文本向量化,
print(model.embed_query("一只乌鸦在喝水"))
# 返回:[0.015122170560061932, 0.0075872535817325115, ....]
三、Prompts组件
- Prompt是指当用户输入信息给模型时加入的提示,这个提示的形式可以是zero-shot或者few-shot等方式,目的是让模型理解更为复杂的业务场景以便更好的解决问题。
- 提示模板:如果你有了一个起作用的提示,你可能想把它作为一个模板用于解决其他问题,LangChain就提供了PromptTemplates组件,它可以更方便的构建提示。
1.zore-shot提示
- 零样本学习(zore-shot),是指模型在没有任何相关任务示例的情况下,仅根据任务的描述来生成响应。这种学习方式要求模型具备强大的泛化能力和对任务描述的深刻理解。
from langchain_community.llms import ChatGLM
from langchain_core.prompts import PromptTemplate
# 定义模板
template = "给我出一道关于{subject}的题目"
prompt = PromptTemplate(input_variables=["subject"], template=template)
prompt_text = prompt.format(subject="数学")
# prompt_text = "给我出一道关于数学的题目"
llm = ChatGLM(
endpoint_url="http://127.0.0.1:8000",
max_token=80000,
top_p=0.9
)
response = llm(prompt_text)
print(response)
2.few-shot提示
- 少样本学习(few-shot),是指通过给模型提供少量(通常是1到几个)的示例来帮助模型理解任务,并生成正确的响应。这种方式可以显著提高模型在新任务上的表现,同时减少对数据量的需求。
from langchain_community.llms import ChatGLM
from langchain_core.prompts import PromptTemplate, FewShotPromptTemplate
examples = [
{"word": "大", "antonym": "小"},
{"word": "上", "antonym": "下"},
{"word": "左", "antonym": "右"},
]
example_template = """
单词:{word} 反义词:{antonym}\\n
"""
# 定义模板
example_prompt = PromptTemplate(input_variables=["word", "antonym"], template=example_template)
few_shot_prompt = FewShotPromptTemplate(
examples=examples, # 模型训练的案例
example_prompt=example_prompt, # 样例的模板
prefix="给出每个单词的反义词", # 提示的前缀
suffix="单词:{input}\\n反义词:", # 提示的后缀
input_variables=["input"], # 在few-shot当中定义的变量
example_separator="\\n", # 样例之间都使用换行进行隔开
)
# 格式化文本
prompt_text = few_shot_prompt.format(input="粗")
# 打印的结果如下:
# 给出每个单词的反义词\n
# 单词:大 反义词:小\n
# \n
# 单词:上 反义词:下\n
# \n
# 单词:左 反义词:右\n
# \n单词:粗\n反义词:
llm = ChatGLM(
endpoint_url="http://127.0.0.1:8000",
max_token=80000,
top_p=0.9
)
response = llm(prompt_text)
print(response)
四、Chains组件
1.单chain
from langchain.chains.llm import LLMChain
from langchain_community.llms import ChatGLM
from langchain_core.prompts import PromptTemplate
# 1.定义模板
template = "给我出一道关于{subject}题目"
prompt = PromptTemplate(input_variables=["subject"], template=template)
# 2.链条
llm = ChatGLM(
endpoint_url="http://127.0.0.1:8000",
max_token=80000,
top_p=0.9
)
chain = LLMChain(llm=llm, prompt=prompt)
# 3.执行chain(将传入的参数,作为最开始的提示词,传入到llm第一个节点)
result = chain.run("数学加减乘除根号的")
print("result: ", result)
2.多chain
from langchain.chains.llm import LLMChain
from langchain.chains.sequential import SimpleSequentialChain
from langchain_community.llms import ChatGLM
from langchain_core.prompts import PromptTemplate
llm = ChatGLM(
endpoint_url="http://127.0.0.1:8000",
max_token=80000,
top_p=0.9
)
# 创建第一个链条
first_template = "给我出一道关于{subject}题目"
first_prompt = PromptTemplate(input_variables=["subject"], template=first_template)
first_chain = LLMChain(llm=llm, prompt=first_prompt)
# 创建第二个链条
second_template = "请解答题目:{subject2}"
second_prompt = PromptTemplate(input_variables=["subject2"], template=second_template)
second_chain = LLMChain(llm=llm, prompt=second_prompt)
# 连接两条链,verbose=True 可以显示推理过程
overall_chain = SimpleSequentialChain(chains=[first_chain, second_chain], verbose=True)
# 执行链,只需要传入第一个参数
result = overall_chain.run("数学加减乘除根号的")
print("result: ", result)
五、Agents组件
- 在LangChain 中 Agents的作用就是根据用户的需求,来访问一些第三方工具(比如:搜索引擎或者数据库),进而来解决相关需求问题。
- 为什么要借助第三方库?
- 因为大模型虽然非常强大,但是也具备一定的局限性,比如不能回答实时信息、处理数学逻辑问题仍然非常的初级等等。因此,可以借助第三方工具来辅助大模型的应用。
- Agent代理
- 制定计划和思考下一步需要采取的行动
- 负责控制整段代码的逻辑和执行,代理暴露了一个接口,用来接收用户输入,并返回AgentAction或AgentFinish。
- Toolkit工具包一些集成好了代理包,比如create_csv_agent 可以使用模型解读csv文件。
- Tool工具
- 解决问题的工具
- 第三方服务的集成,比如计算器、网络搜索(谷歌、bing)等等。
- AgentExecutor代理执行器:它将代理和工具列表包装在一起,负责迭代运行代理的循环,直到满足停止的标准。
from langchain.agents import load_tools, initialize_agent, AgentType
from langchain_community.llms import ChatGLM
from langchain_core.prompts import PromptTemplate
# 现在我们实现一个使用代理的例子:假设我们想查询一下中国目前有多少人口?我们可以使用多个代理工具,让Agents选择执行。
# 需要安装依赖库:
# pip install wikipedia
# pip install numexpr
# 加载内置工具 llm-math和 wikipedia
llm = ChatGLM(
endpoint_url="http://127.0.0.1:8000",
max_token=80000,
top_p=0.9
)
tools = load_tools(["llm-math", "wikipedia"], llm=llm)
agent = initialize_agent(tools=tools,
llm=llm,
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
verbose=True)
prompt_template = "哪个国家的人最多?"
prompt = PromptTemplate.from_template(prompt_template)
result = agent.run(prompt)
六、Memory组件
- 大模型本身不具备上下文的概念,它并不保存上次交互的内容,ChatGPT之所以能够和人正常沟通对话,因为它进行了一层封装,将历史记录回传给了模型。
- 因此LangChain也提供了Memory组件,Memory分为两种类型:
短期记忆
和长期记忆
。- 短期记忆:一般指单一会话时传递数据
- 长期记忆:则是处理多个会话时获取和更新信息。
- 目前的Memory组件只需要考虑ChatMessageHistory。举例分析:
1.非持久化存储
from langchain_community.chat_message_histories import ChatMessageHistory
history = ChatMessageHistory()
# 用户提问:吃了吗? 存储到历史记录对象
history.add_user_message("吃了吗?")
# AI回答:吃了... 存储到历史记录对象
history.add_ai_message("吃了...")
print(history.messages)
# [HumanMessage(content='吃了吗?'), AIMessage(content='吃了...')]
2.持久化存储
- 核心函数:
- message_to_dict:将历史记录转化为字典
- message_from_dict:将字典转化为历史记录
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.messages import messages_to_dict, messages_from_dict
history = ChatMessageHistory()
# 用户提问:吃了吗? 存储到历史记录对象
history.add_user_message("吃了吗?")
# AI回答:吃了... 存储到历史记录对象
history.add_ai_message("吃了...")
print(history.messages)
# [HumanMessage(content='吃了吗?'), AIMessage(content='吃了...')]
# 转化为字典
dicts = messages_to_dict(history.messages)
# 字典转化为历史聊天记录
history = messages_from_dict(dicts)
3.ChatGLM+Memory
from langchain.chains.conversation.base import ConversationChain
from langchain_community.llms.chatglm import ChatGLM
llm = ChatGLM(
endpoint_url="http://127.0.0.1:8000",
max_token=80000,
top_p=0.9
)
conversation = ConversationChain(llm=llm)
result1 = conversation.predict(input="我的语文是82")
print("result1: ", result1)
result2 = conversation.predict(input="我的数学88")
print("result2: ", result2)
result3 = conversation.predict(input="我的数学和语文成绩一共多少分?")
print("result3: ", result3)
# 最终效果:
# result1: 你的语文是82分,很高!是阅读、写作和口语方面的表现吗?
# result2: 你的数学是88分,也很不错!是选择题、填空题和解答题的表现吗?
# result3: 你的语文和数学成绩一共是170分。
七、indexes组件
- Indexes组件的目的是让LangChain具备处理文档处理的能力,包括:文档加载、检索等。这里的文档不局限于txt、pdf等文本类内容,还涵盖email、区块链、视频等内容。
- 文档加载器。
- 文本分割器
- VectorStores
- 检索器
1.文档加载器
- 文档加载器可以基于TextLoader包
- 文档加载器使用起来很简单,只需要引入相应的loader工具:
from langchain_community.document_loaders import TextLoader
loader = TextLoader("./test.txt", encoding="utf8")
doc = loader.load()
print("doc: ", doc)
print(len(doc))
print(doc[0].page_content[:10])
2.文档分割器
- 由于模型对输入的字符长度有限制,我们在碰到很长的文本时,需要把文本分割成多个小的文本片段。
- 文本分割最简单的方式是按照字符长度进行分割,但是这会带来很多问题,比如说如果文本是一段代码,一个函数被分割到两段之后就成了没有意义的字符,所以整体的原则是把语义相关的文本片段尽可能的放在一起。
- LangChain中最基本的文本分割器是CharacterTextSplitter,它按照指定的分隔符(默认"n\n")进行分割,并且考虑文本片段的最大长度。
2.1 CharacterTextSplitter
from langchain.text_splitter import CharacterTextSplitter
# 实例化一个文本分割对象
text_splitter = CharacterTextSplitter(
separator=" ", # 分隔符,案例空格分割(注:如果没有空格,那么会不会分割,直到找到空格为止)
chunk_size=5, # 指每个分割文本块的大小,案例为5个
chunk_overlap=2, # 每块之间重叠的字符。有重复的内容 才更好的衔接上下文
)
# 一句分割
text = text_splitter.split_text("a b c d e f g")
print(text)
# 多句话分割
texts = text_splitter.create_documents(["a b c d e f g", "h i j k l m n"])
2.2 其他分割器
文档分割器 | 描述 |
---|---|
LatexTextSplitter | 沿着Latex标题、标题、枚举等分割文本 |
MarkdownTextSplitter | 沿着Markdown的标题、代码块或水平规则来分割文本 |
TokenTextSplitter | 根据openAl的token数进行分割 |
PythonCodeTextSplitter | 沿着Python类和方法的定义分割文本 |
3.VectorStores
- Vectorstores是一种特殊类型的数据库,它的作用是存储由嵌入创建的向量,提供相似查询等功能。
3.1 VectorStore汇总
VectorStore | 描述 |
---|---|
Chroma | 一个开源嵌入式数据库 |
ElasticSearch | ElasticSearch |
Milvus | 用于存储、索引和管理由深度神经网络和其他机器学习(ML)模型产生的大量嵌入向量的数据库 |
Redis | 基于Redis的检索器 |
FAISS | Facebook AI相似性搜索服务 |
Pinecone | 一个具有广泛功能的向量数据库 |
3.2 Chroma
- 安装:pip install chromadb
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import Chroma
def get_embedding_model():
model_name = r"C:\practice\python\langchain\ChatGLM-6B\AI-ModelScope\bge-large-zh"
model_kwargs = {'device': 'cuda:0'}
encode_kwargs = {'normalize_embeddings': True}
return HuggingFaceBgeEmbeddings(
model_name=model_name,
model_kwargs=model_kwargs,
encode_kwargs=encode_kwargs
)
# 1.读取文档里面的内容
with open("./test.txt", encoding="utf8") as f:
doc_content = f.read()
# 2.切分文档
# 实例化一个文本分割对象
text_splitter = CharacterTextSplitter(
chunk_size=100,
chunk_overlap=20,
)
texts = text_splitter.split_text(doc_content)
# 3.将切分后的文档向量化并保存
embedding_model = get_embedding_model()
doc_search = Chroma.from_texts(texts=texts, embedding=embedding_model)
query = "我今年三十七岁。现在,我正坐在哪里?"
result = doc_search.similarity_search(query)
# result就是和query相似度最近的一段文本
print(result)
3.3 FAISS
- 安装包
pip install faiss-gpu # For CUDA 7.5+ Supported GPU's.
# OR
pip install faiss-cpu # For CPU Installation
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import Chroma, FAISS
def get_embedding_model():
model_name = r"C:\practice\python\langchain\ChatGLM-6B\AI-ModelScope\bge-large-zh"
model_kwargs = {'device': 'cuda:0'}
encode_kwargs = {'normalize_embeddings': True}
return HuggingFaceBgeEmbeddings(
model_name=model_name,
model_kwargs=model_kwargs,
encode_kwargs=encode_kwargs
)
# 1.读取文档里面的内容
with open("./test.txt", encoding="utf8") as f:
doc_content = f.read()
# 2.切分文档
# 实例化一个文本分割对象
text_splitter = CharacterTextSplitter(
chunk_size=100,
chunk_overlap=20,
)
texts = text_splitter.split_text(doc_content)
# 3.将切分后的文档向量化并保存
embedding_model = get_embedding_model()
doc_search = FAISS.from_texts(texts=texts, embedding=embedding_model)
# 每次检索返回最相关的两个文档。默认是1个
doc_search = doc_search.as_retriever(search_kwargs={"k": 2})
query = "我今年三十七岁。现在,我正坐在哪里?"
result = doc_search.get_relevant_documents(query)
# result就是和query相似度最近的一段文本
print(result)
八、ChatGLM+Langchain
1.后端逻辑
import os
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.llms.chatglm import ChatGLM
from langchain_community.vectorstores import Chroma
def load_documents(directory="book"):
# 1.读取文档里面的内容
# 需要安装:pip install unstructured
# 在windows下面报错:failed to find libmagic
# 解决办法:pip uninstall python-magic
# pip install python-magic-bin==0.4.14
loader = DirectoryLoader(directory)
documents = loader.load()
# 2.切分文档
# 实例化一个文本分割对象
text_splitter = CharacterTextSplitter(
chunk_size=256,
chunk_overlap=20,
)
return text_splitter.split_documents(documents)
def get_embedding_model():
model_name = r"C:\practice\python\langchain\ChatGLM-6B\AI-ModelScope\bge-large-zh"
model_kwargs = {'device': 'cuda:0'}
encode_kwargs = {'normalize_embeddings': True}
return HuggingFaceBgeEmbeddings(
model_name=model_name,
model_kwargs=model_kwargs,
encode_kwargs=encode_kwargs
)
def store_chroma(docs, embedding_model, persis_directory="VectorStore"):
if not os.path.exists(persis_directory):
# 实际测试,这里需要6秒
return Chroma.from_documents(docs, embedding_model, persist_directory=persis_directory).persist()
# 实际测试,这里仅需要0.4秒
return Chroma(persist_directory=persis_directory, embedding_function=embedding_model)
def load_chatglm_model():
return ChatGLM(
endpoint_url="http://127.0.0.1:8000",
max_token=80000,
top_p=0.9
)
persis_directory = "VectorStore"
documents = load_documents()
embedding_model = get_embedding_model()
db = store_chroma(documents, embedding_model)
llm = load_chatglm_model()
# 创建QA
qa = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff",
retriever=db.as_retriever()
)
# 提问
result = qa.run("我今年三十七岁。现在,我正坐在哪里?")
print(result)
# 回答:You are currently sitting in a波音747机舱.
# 问题:中英文混合回答,有问题(后面会解决)
2.添加前端交互
- 安装:pip install gradio
import os
import gradio as gr
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.llms.chatglm import ChatGLM
from langchain_community.vectorstores import Chroma
def load_documents(directory="book"):
# 1.读取文档里面的内容
# 需要安装:pip install unstructured
# 在windows下面报错:failed to find libmagic
# 解决办法:pip uninstall python-magic
# pip install python-magic-bin==0.4.14
loader = DirectoryLoader(directory)
documents = loader.load()
# 2.切分文档
# 实例化一个文本分割对象
text_splitter = CharacterTextSplitter(
chunk_size=256,
chunk_overlap=20,
)
return text_splitter.split_documents(documents)
def get_embedding_model():
model_name = r"C:\practice\python\langchain\ChatGLM-6B\AI-ModelScope\bge-large-zh"
model_kwargs = {'device': 'cuda:0'}
encode_kwargs = {'normalize_embeddings': True}
return HuggingFaceBgeEmbeddings(
model_name=model_name,
model_kwargs=model_kwargs,
encode_kwargs=encode_kwargs
)
def store_chroma(docs, embedding_model, persis_directory="VectorStore"):
if not os.path.exists(persis_directory):
# 实际测试,这里需要6秒
return Chroma.from_documents(docs, embedding_model, persist_directory=persis_directory).persist()
# 实际测试,这里仅需要0.4秒
return Chroma(persist_directory=persis_directory, embedding_function=embedding_model)
def load_chatglm_model():
return ChatGLM(
endpoint_url="http://127.0.0.1:8000",
max_token=80000,
top_p=0.9
)
persis_directory = "VectorStore"
documents = load_documents()
embedding_model = get_embedding_model()
db = store_chroma(documents, embedding_model)
llm = load_chatglm_model()
# 创建QA
qa = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff",
retriever=db.as_retriever()
)
def chat(question, history):
return qa.run(question)
demo = gr.ChatInterface(chat)
demo.launch(inbrowser=True)
3.添加上传文档功能
import os
import time
import gradio as gr
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.llms.chatglm import ChatGLM
from langchain_community.vectorstores import Chroma
def load_documents(directory="book"):
# 1.读取文档里面的内容
# 需要安装:pip install unstructured
# 在windows下面报错:failed to find libmagic
# 解决办法:pip uninstall python-magic
# pip install python-magic-bin==0.4.14
loader = DirectoryLoader(directory)
documents = loader.load()
# 2.切分文档
# 实例化一个文本分割对象
text_splitter = CharacterTextSplitter(
chunk_size=256,
chunk_overlap=20,
)
return text_splitter.split_documents(documents)
def get_embedding_model():
model_name = r"C:\practice\python\langchain\ChatGLM-6B\AI-ModelScope\bge-large-zh"
model_kwargs = {'device': 'cuda:0'}
encode_kwargs = {'normalize_embeddings': True}
return HuggingFaceBgeEmbeddings(
model_name=model_name,
model_kwargs=model_kwargs,
encode_kwargs=encode_kwargs
)
def get_chroma_db(embedding_model, persis_directory="VectorStore"):
return Chroma(persist_directory=persis_directory, embedding_function=embedding_model)
def store_chroma(docs, embedding_model, persis_directory="VectorStore"):
db = Chroma.from_documents(docs, embedding_model, persist_directory=persis_directory)
db.persist()
def load_chatglm_model():
return ChatGLM(
endpoint_url="http://127.0.0.1:8000",
max_token=80000,
top_p=0.9
)
persis_directory = "VectorStore"
documents = load_documents()
embedding_model = get_embedding_model()
# 加载chroma数据库
db = get_chroma_db(embedding_model=embedding_model)
# 存储数据
store_chroma(documents, embedding_model)
# 加载chatglm大模型
llm = load_chatglm_model()
# 创建QA
qa = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff",
retriever=db.as_retriever()
)
def add_text(history, text):
history = history + [(text, None)]
return history, gr.update(value="", interactive=False)
def add_file(history, file):
# 拿到上传的文件夹
directory = os.path.dirname(file.name)
documents = load_documents(directory)
store_chroma(documents, embedding_model)
history = history + [(file.name, None)]
return history
def bot(history):
message = history[-1][0]
if os.path.sep in message and os.path.exists(message):
# 是上传文件
response = "文件上传成功"
else:
response = qa.run(message)
history[-1][1] = ""
for character in response:
history[-1][1] += character
time.sleep(0.05)
yield history
with gr.Blocks() as demo:
chat_bot = gr.Chatbot(
[],
elem_id="chatbot",
bubble_full_width=False,
avatar_images=(None, (os.path.dirname(__file__), r"C:\practice\python\langchain\ChatGLM-6B\icon.jpg"))
)
with gr.Row():
txt = gr.Textbox(
scale=4,
show_label=False,
placeholder="Enter text and press enter, or upload an image",
container=False
)
btn = gr.UploadButton("上传文档", file_types=['txt'])
txt_msg = txt.submit(add_text, [chat_bot, txt], [chat_bot, txt], queue=False).then(
bot, chat_bot, chat_bot
)
txt_msg.then(lambda: gr.update(interactive=True), None, [txt], queue=False)
file_msg = btn.upload(add_file, [chat_bot, btn], [chat_bot], queue=False).then(
bot, chat_bot, chat_bot
)
demo.queue()
if __name__ == '__main__':
demo.launch()