1. Install Milvus
Download milvus-standalone-docker-compose.yml and save it as docker-compose.yml:
wget https://github.com/milvus-io/milvus/releases/download/v2.3.2/milvus-standalone-docker-compose.yml -O docker-compose.yml
Start Milvus:
sudo docker-compose up -d
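To verify that Milvus is reachable before moving on, a minimal pymilvus check (assuming the default standalone port 19530) looks like this:

from pymilvus import connections, utility

# Connect to the standalone instance started by docker-compose
connections.connect(alias="default", host="localhost", port="19530")
print(utility.get_server_version())  # prints the server version if Milvus is up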
2. Document preprocessing
Connect to Milvus, embed each knowledge file with bert-base-chinese, and import everything into a collection:
import os
import re
import jieba
import torch
import pandas as pd
from pymilvus import connections, utility, CollectionSchema, FieldSchema, Collection, DataType
from transformers import AutoTokenizer, AutoModel

connections.connect(
    alias="default",
    host='localhost',
    port='19530'
)
# Collection name, embedding dimension, and the folder holding the knowledge files
collection_name = "document"
dimension = 768
docs_folder = "./knowledge/"
tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
model = AutoModel.from_pretrained("bert-base-chinese")

# Compute a text embedding: the [CLS] token's last hidden state from bert-base-chinese.
# Note the tokenizer truncates input to the model's 512-token limit.
def get_vector(text):
    input_ids = tokenizer(text, padding=True, truncation=True, return_tensors="pt")["input_ids"]
    with torch.no_grad():
        output = model(input_ids)[0][:, 0, :].numpy()
    return output.tolist()[0]
def create_collection():
    # Define the collection fields
    fields = [
        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True, description="primary id"),
        FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=50),
        FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=10000),
        FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
    ]
    # Define the collection schema
    schema = CollectionSchema(fields=fields, description="collection schema")
    # If the collection already exists, drop and recreate it. To keep appending new
    # documents to the existing collection instead, return here.
    if utility.has_collection(collection_name):
        # return
        utility.drop_collection(collection_name)
    # Create the collection and its vector index
    collection = Collection(name=collection_name, schema=schema, using='default', shards_num=2)
    default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 2048}, "metric_type": "IP"}
    collection.create_index(field_name="vector", index_params=default_index)
    print(f"Collection {collection_name} created successfully")
def init_knowledge():
    collection = Collection(collection_name)
    # Walk the docs folder and import every text file into the Milvus collection
    docs = []
    for root, dirs, files in os.walk(docs_folder):
        for file in files:
            # Only process .txt files
            if file.endswith(".txt"):
                file_path = os.path.join(root, file)
                with open(file_path, "r", encoding="utf-8") as f:
                    content = f.read()
                # Collapse whitespace
                content = re.sub(r"\s+", " ", content)
                title = os.path.splitext(file)[0]
                # Tokenize with jieba, then rejoin the tokens into a string
                words = jieba.lcut(content)
                content = " ".join(words)
                # Embed title + content
                vector = get_vector(title + content)
                docs.append({"title": title, "content": content, "vector": vector})
    # Insert the documents and their vectors into the collection via a DataFrame
    df = pd.DataFrame(docs)
    collection.insert(df)
    print("Documents inserted successfully")

if __name__ == "__main__":
    create_collection()
    init_knowledge()
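After running the script, a quick sanity check, sketched here under the assumption that the code above is saved as document_preprocess.py (the module name imported in step 3), confirms the documents landed and the vector dimension matches:

from pymilvus import connections, Collection
from document_preprocess import get_vector

connections.connect(alias="default", host='localhost', port='19530')
collection = Collection("document")
print(collection.num_entities)   # number of inserted documents
print(len(get_vector("测试")))   # should be 768, matching the collection's dim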
3. Knowledge-base matching
Use the vector index to find the documents most similar to the question.
import torch
from document_preprocess import get_vector
from pymilvus import Collection
collection = Collection("document") # Get an existing collection.
collection.load()
DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
# Query function: embed the input and retrieve the best-matching documents
def search_similar_text(input_text):
    # Convert the input text to a vector
    input_vector = get_vector(input_text)
    # Retrieve the IDs of the three closest vectors
    similarity = collection.search(
        data=[input_vector],
        anns_field="vector",
        param={"metric_type": "IP", "params": {"nprobe": 10}, "offset": 0},
        limit=3,
        expr=None,
        consistency_level="Strong"
    )
    ids = similarity[0].ids
    # Fetch the corresponding knowledge-base documents by ID
    res = collection.query(
        expr=f"id in {ids}",
        offset=0,
        limit=3,
        output_fields=["id", "content", "title"],
        consistency_level="Strong"
    )
    print(res)
    return res

if __name__ == "__main__":
    question = input('Please enter your question: ')
    search_similar_text(question)
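One caveat: the index ranks by inner product (IP), but the BERT [CLS] vectors above are not L2-normalized, so documents with larger vector norms can dominate the results. A common remedy, sketched here as an optional addition rather than part of the original code, is to normalize every vector before insert and search, which makes IP equivalent to cosine similarity:

import numpy as np

def normalize(vec):
    # L2-normalize so that inner product (IP) ranks like cosine similarity
    v = np.asarray(vec, dtype=np.float32)
    norm = np.linalg.norm(v)
    return (v / norm).tolist() if norm > 0 else v.tolist()

If adopted, it should be applied to the vectors in both init_knowledge and search_similar_text.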
4. Generating the answer
Feed the retrieved documents and the user's question to ChatGLM-6B as a prompt:
from transformers import AutoModel, AutoTokenizer
from knowledge_query import search_similar_text
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).half().cuda()
model = model.eval()
def predict(input, max_length=2048, top_p=0.7, temperature=0.95, history=[]):
    # Retrieve the documents most similar to the question
    res = search_similar_text(input)
    # The Chinese prompt tells the model to answer concisely and professionally from
    # the known content below, to reply with a fixed fallback message if no answer
    # can be derived, not to fabricate anything, and to answer in Chinese.
    prompt_template = f"""基于以下已知信息,简洁和专业的来回答用户的问题。
如果无法从中得到答案,请说 "当前会话仅支持解决一个类型的问题,请清空历史信息重试",不允许在答案中添加编造成分,答案请使用中文。
已知内容:
{res}
问题:
{input}
"""
    query = prompt_template
    for response, history in model.stream_chat(tokenizer, query, history, max_length=max_length,
                                               top_p=top_p, temperature=temperature):
        yield response, history
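A minimal command-line driver for this generator, a sketch rather than part of the original script, consumes the stream and prints the final answer:

if __name__ == "__main__":
    question = input('Please enter your question: ')
    answer, history = "", []
    for answer, history in predict(question, history=history):
        pass  # the last iteration holds the complete answer
    print(answer)

The full Gradio web demo below wires the same retrieval-augmented predict into a chat UI.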
from transformers import AutoModel, AutoTokenizer
import gradio as gr
import mdtex2html
from knowledge_query import search_similar_text
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).half().cuda()
model = model.eval()
# Use the knowledge base only for the first question of a session
is_knowledge = True
"""Override Chatbot.postprocess"""
def postprocess(self, y):
if y is None:
return []
for i, (message, response) in enumerate(y):
y[i] = (
None if message is None else mdtex2html.convert((message)),
None if response is None else mdtex2html.convert(response),
)
return y
gr.Chatbot.postprocess = postprocess
def parse_text(text):
    """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split('`')
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = f'<br></code></pre>'
        else:
            if i > 0:
                if count % 2 == 1:
                    line = line.replace("`", "\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>" + line
    text = "".join(lines)
    return text
def predict(input, chatbot, max_length, top_p, temperature, history):
    global is_knowledge
    chatbot.append((parse_text(input), ""))
    query = input
    if is_knowledge:
        # On the first turn, retrieve matching documents and wrap them into the same
        # Chinese prompt used above; later turns go to the model unmodified.
        res = search_similar_text(input)
        prompt_template = f"""基于以下已知信息,简洁和专业的来回答用户的问题。
如果无法从中得到答案,请说 "当前会话仅支持解决一个类型的问题,请清空历史信息重试",不允许在答案中添加编造成分,答案请使用中文。
已知内容:
{res}
问题:
{input}
"""
        query = prompt_template
        is_knowledge = False
    for response, history in model.stream_chat(tokenizer, query, history, max_length=max_length, top_p=top_p,
                                               temperature=temperature):
        chatbot[-1] = (parse_text(input), parse_text(response))
        yield chatbot, history
def reset_user_input():
    return gr.update(value='')

def reset_state():
    global is_knowledge
    # Re-enable knowledge-base retrieval for the new session
    is_knowledge = True
    return [], []
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">ChatGLM</h1>""")
    chatbot = gr.Chatbot()
    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
                    container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit", variant="primary")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
    history = gr.State([])
    submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history],
                    show_progress=True)
    submitBtn.click(reset_user_input, [], [user_input])
    emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)

demo.queue().launch(share=False, inbrowser=True)