Reference:
https://milvus.io/docs/hybrid_search_with_milvus.md
Milvus pain points:
1) Computing the embeddings offline with BGEM3EmbeddingFunction is slow (a possible GPU speed-up is sketched after the embedding step below).
2) Milvus has quite a few environment-support issues; Windows, CentOS, etc. are not supported.
Online demo:
https://colab.research.google.com/drive/1OGvOyJH2NUQB1Ft3rqAFLQ_5Dzi8aHk_?usp=sharing
For now this only runs on Linux (Ubuntu) or macOS, because milvus_lite does not yet support Windows or CentOS: https://github.com/milvus-io/milvus/issues/34854
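Workaround for Windows/CentOS: run a Milvus standalone server (for example via the official Docker image) and connect to it by URI instead of the local ./milvus.db file used below. A minimal sketch, assuming a standalone server listening on the default port:
from pymilvus import connections
# Point pymilvus at the standalone server instead of Milvus Lite's local file
connections.connect(uri="http://localhost:19530")  # assumes a local standalone Milvus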
Installation:
pip install --upgrade pymilvus "pymilvus[model]" -i https://pypi.tuna.tsinghua.edu.cn/simple
Installed versions: environs-9.5.0 milvus-lite-2.4.10 milvus-model-0.2.5 onnxruntime-1.16.3 pymilvus-2.4.6
Data:
wget http://qim.fs.quoracdn.net/quora_duplicate_questions.tsv
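The Quora duplicate-questions file is a TSV with columns id, qid1, qid2, question1, question2, is_duplicate; a quick peek before loading everything:
import pandas as pd
# Read only the first few rows to inspect the layout
print(pd.read_csv("quora_duplicate_questions.tsv", sep="\t", nrows=3))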
Code:
## Load the data
import pandas as pd

file_path = "quora_duplicate_questions.tsv"
df = pd.read_csv(file_path, sep="\t")

# Collect unique questions, truncated to the 512-char limit of the text field below
questions = set()
for _, row in df.iterrows():
    obj = row.to_dict()
    questions.add(obj["question1"][:512])
    questions.add(obj["question2"][:512])
    if len(questions) > 500:  # Skip this if you want to use the full dataset
        break

docs = list(questions)
print(docs[0])
# BGE-M3: one model produces both the dense and the sparse embeddings
from milvus_model.hybrid import BGEM3EmbeddingFunction

ef = BGEM3EmbeddingFunction(use_fp16=False, device="cpu")
dense_dim = ef.dim["dense"]
docs_embeddings = ef(docs)
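A quick sanity check on what BGEM3EmbeddingFunction returns: in milvus_model, "dense" is a list of numpy vectors and "sparse" is a scipy sparse matrix with one row per document. If the CPU embedding step is too slow (pain point 1 above), running the same model on GPU with fp16 is the usual fix; the batch_size argument is an assumption and may differ across milvus_model versions.
print("dense dim:", dense_dim)                        # 1024 for BGE-M3
print("num dense vectors:", len(docs_embeddings["dense"]))
print("sparse shape:", docs_embeddings["sparse"].shape)
# Possible speed-up (sketch, assumes a CUDA GPU; batch_size is hypothetical):
# ef = BGEM3EmbeddingFunction(use_fp16=True, device="cuda:0", batch_size=32)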
## Create the collection
from pymilvus import (
    connections,
    utility,
    FieldSchema,
    CollectionSchema,
    DataType,
    Collection,
)
connections.connect(uri="./milvus.db")
fields = [
    # Use an auto-generated id as the primary key
    FieldSchema(
        name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=True, max_length=100
    ),
    # Store the original text so it can be retrieved by semantic distance
    FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=512),
    # Milvus supports both sparse and dense vectors; storing each in its own
    # field lets us run hybrid search over both
    FieldSchema(name="sparse_vector", dtype=DataType.SPARSE_FLOAT_VECTOR),
    FieldSchema(name="dense_vector", dtype=DataType.FLOAT_VECTOR, dim=dense_dim),
]
schema = CollectionSchema(fields)
col_name = "hybrid_demo"
if utility.has_collection(col_name):
    Collection(col_name).drop()
col = Collection(col_name, schema, consistency_level="Strong")
sparse_index = {"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "IP"}
col.create_index("sparse_vector", sparse_index)
dense_index = {"index_type": "AUTOINDEX", "metric_type": "IP"}
col.create_index("dense_vector", dense_index)
col.load()
# Insert the data
for i in range(0, len(docs), 50):
    batched_entities = [
        docs[i : i + 50],
        docs_embeddings["sparse"][i : i + 50],
        docs_embeddings["dense"][i : i + 50],
    ]
    col.insert(batched_entities)
print("Number of entities inserted:", col.num_entities)
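Note: depending on the pymilvus version, num_entities may only reflect sealed segments; calling flush() first makes the count deterministic (optional sketch):
col.flush()  # seal the freshly inserted rows before counting
print("Number of entities after flush:", col.num_entities)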
## Search the data
from pymilvus import (
    AnnSearchRequest,
    WeightedRanker,
)
def dense_search(col, query_dense_embedding, limit=10):
    # Search on the dense vector field only
    search_params = {"metric_type": "IP", "params": {}}
    res = col.search(
        [query_dense_embedding],
        anns_field="dense_vector",
        limit=limit,
        output_fields=["text"],
        param=search_params,
    )[0]
    return [hit.get("text") for hit in res]

def sparse_search(col, query_sparse_embedding, limit=10):
    # Search on the sparse vector field only
    search_params = {
        "metric_type": "IP",
        "params": {},
    }
    res = col.search(
        [query_sparse_embedding],
        anns_field="sparse_vector",
        limit=limit,
        output_fields=["text"],
        param=search_params,
    )[0]
    return [hit.get("text") for hit in res]

def hybrid_search(
    col,
    query_dense_embedding,
    query_sparse_embedding,
    sparse_weight=1.0,
    dense_weight=1.0,
    limit=10,
):
    # Run both ANN requests and fuse the results with a weighted reranker
    dense_search_params = {"metric_type": "IP", "params": {}}
    dense_req = AnnSearchRequest(
        [query_dense_embedding], "dense_vector", dense_search_params, limit=limit
    )
    sparse_search_params = {"metric_type": "IP", "params": {}}
    sparse_req = AnnSearchRequest(
        [query_sparse_embedding], "sparse_vector", sparse_search_params, limit=limit
    )
    rerank = WeightedRanker(sparse_weight, dense_weight)
    res = col.hybrid_search(
        [sparse_req, dense_req], rerank=rerank, limit=limit, output_fields=["text"]
    )[0]
    return [hit.get("text") for hit in res]
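WeightedRanker blends the two score lists with the given weights. pymilvus 2.4 also ships RRFRanker (reciprocal rank fusion), which needs no weight tuning; a minimal variant of the hybrid function using it, as a sketch under the same assumptions as above:
from pymilvus import RRFRanker

def hybrid_search_rrf(col, query_dense_embedding, query_sparse_embedding, limit=10):
    # Same two ANN requests as above, fused by reciprocal rank instead of weights
    dense_req = AnnSearchRequest(
        [query_dense_embedding], "dense_vector", {"metric_type": "IP", "params": {}}, limit=limit
    )
    sparse_req = AnnSearchRequest(
        [query_sparse_embedding], "sparse_vector", {"metric_type": "IP", "params": {}}, limit=limit
    )
    res = col.hybrid_search(
        [sparse_req, dense_req], rerank=RRFRanker(), limit=limit, output_fields=["text"]
    )[0]
    return [hit.get("text") for hit in res]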
query = input("Enter your search query: ")
print(query)
query_embeddings = ef([query])

dense_results = dense_search(col, query_embeddings["dense"][0])
# "sparse" is a scipy sparse matrix, so slice with [[0]] to keep a 2-D single-row input
sparse_results = sparse_search(col, query_embeddings["sparse"][[0]])
hybrid_results = hybrid_search(
    col,
    query_embeddings["dense"][0],
    query_embeddings["sparse"][[0]],
    sparse_weight=0.7,
    dense_weight=1.0,
)
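Print the three result lists to compare them (the original tutorial renders them with the query terms highlighted; plain prints are enough here):
for name, results in [
    ("Dense", dense_results),
    ("Sparse", sparse_results),
    ("Hybrid", hybrid_results),
]:
    print(f"\n{name} search results:")
    for hit_text in results:
        print(" -", hit_text)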