1. Upload a file and vectorize its text
import os
from django.conf import settings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader, PyPDFLoader, Docx2txtLoader
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import DashScopeEmbeddings
from pymilvus import MilvusClient, DataType
def get_embedding(file_name):
    """Load a file from media/, split it into chunks, embed the chunks, and store them in Milvus."""
    file_path = os.path.join(settings.BASE_DIR, 'media', file_name)
    # Pick a loader based on the file extension
    if file_path.endswith('.txt'):
        loader = TextLoader(file_path, encoding='utf-8')
    elif file_path.endswith('.pdf'):
        loader = PyPDFLoader(file_path)
    elif file_path.endswith(('.docx', '.doc')):
        # note: docx2txt targets .docx; legacy binary .doc support is limited
        loader = Docx2txtLoader(file_path)
    elif file_path.endswith('.md'):
        loader = TextLoader(file_path, encoding='utf-8')
    else:
        raise ValueError(f'Unsupported file type: {file_path}')
    docs = loader.load()
    # Split the documents into overlapping chunks
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=150, chunk_overlap=50)
    chunks = text_splitter.split_documents(docs)
    # Alternative: store the vectors in a local Chroma database instead of Milvus
    # embeddings = DashScopeEmbeddings(model='text-embedding-v1')
    # db = Chroma.from_documents(chunks, embeddings, persist_directory='./chroma_db')
    # db.persist()
    client = MilvusClient(uri="http://127.0.0.1:19530")
    # Define the collection schema: auto-generated primary key, vector field, raw text field
    schema = MilvusClient.create_schema(
        auto_id=True,
        enable_dynamic_field=True,
    )
    schema.add_field(field_name="f_id", datatype=DataType.INT64, is_primary=True)
    # DashScope text-embedding-v1 produces 1536-dimensional vectors
    schema.add_field(field_name="f_vector", datatype=DataType.FLOAT_VECTOR, dim=1536)
    schema.add_field(field_name="f_content", datatype=DataType.VARCHAR, max_length=5000)
    # Prepare the index parameters object that defines indexes on the collection's fields
    index_params = client.prepare_index_params()
    # Index the primary key f_id with the STL_SORT index type
    index_params.add_index(
        field_name="f_id",
        index_type="STL_SORT"
    )
    # Index f_vector with IVF_FLAT, inner-product (IP) distance, and 128 clusters
    index_params.add_index(
        field_name="f_vector",
        index_type="IVF_FLAT",
        metric_type="IP",
        params={"nlist": 128}
    )
    # Create the collection only on the first run, so repeated uploads don't fail
    if not client.has_collection("app02_file"):
        client.create_collection(
            collection_name="app02_file",
            schema=schema,
            index_params=index_params
        )
    # Embed every chunk with DashScope and insert (vector, content) rows into Milvus
    chunk_list = [chunk.page_content for chunk in chunks]
    embedding = DashScopeEmbeddings(model='text-embedding-v1')
    chunk_embeds = embedding.embed_documents(chunk_list)
    data = [
        {"f_vector": chunk_embeds[i], "f_content": chunk_list[i]}
        for i in range(len(chunk_embeds))
    ]
    client.insert(
        collection_name="app02_file",
        data=data
    )
    return docs
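
get_embedding expects the file to already sit under media/. For context, here is a minimal sketch of the upload side, assuming a Django REST Framework view and a form field named "file"; the view name UploadView is hypothetical and not part of the original code:

import os
from django.conf import settings
from rest_framework.views import APIView
from rest_framework.response import Response

class UploadView(APIView):
    # Hypothetical upload endpoint: save the file under media/, then vectorize it
    def post(self, request):
        f = request.FILES['file']
        file_path = os.path.join(settings.BASE_DIR, 'media', f.name)
        with open(file_path, 'wb') as out:
            for block in f.chunks():
                out.write(block)
        get_embedding(f.name)  # embed and store the file that was just saved
        return Response({'status': 'ok', 'file': f.name})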
2. Vectorize the question and search, assemble the prompt, and use an Agent
from langchain.agents import tool, AgentType, initialize_agent
from langchain.prompts import PromptTemplate
from langchain_community.llms import Tongyi
from langchain_community.embeddings import DashScopeEmbeddings
from pymilvus import MilvusClient
from rest_framework.views import APIView
from rest_framework.response import Response

class testView(APIView):
    def get(self, request):
        query = request.query_params.get('query', None)
        llm = Tongyi()

        @tool('t1')
        def t1(query: str):
            """Use this tool only when you need to query the database."""
            client = MilvusClient(uri="http://127.0.0.1:19530", db_name="default")
            # Embed the question with the same model used at ingestion time
            embedding = DashScopeEmbeddings(model="text-embedding-v1")
            question = embedding.embed_query(query)
            ret = client.search(
                collection_name='app02_file',
                data=[question],
                limit=3,  # max number of search results to return
                output_fields=["f_content"]
            )
            prompt_template = PromptTemplate.from_template(
                "Answer by summarizing the content below\n{text}\nQuestion: {question}"
            )
            # Concatenate the top-3 retrieved chunks into the prompt context
            text = ""
            for content in ret[0]:
                text += content["entity"]["f_content"]
            prompt = prompt_template.format(text=text, question=query)
            return prompt

        tools = [t1]
        agent = initialize_agent(
            tools,
            llm,
            agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
            verbose=True,
        )
        ret = agent.run(query)
        return Response({'response': ret})
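
To call the endpoint, the view still needs a route. A minimal sketch, assuming the app is named app02 and the path test/ (both hypothetical); note that DashScopeEmbeddings and Tongyi read the DASHSCOPE_API_KEY environment variable, which must be set before starting the server:

# urls.py — hypothetical routing for the view above
from django.urls import path
from app02.views import testView

urlpatterns = [
    path('test/', testView.as_view()),
]

# Example request (the query string carries the user's question):
# GET http://127.0.0.1:8000/test/?query=What does the uploaded document say about pricing?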