SimBERT is a good model for similar-sentence retrieval, but large-scale retrieval needs fast lookup, and that is where a vector search engine such as Milvus comes in. Below we walk through a SimBERT + Milvus application with actual code.
import numpy as np
from bert4keras.backend import keras, K
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
import tensorflow as tf
from openapi_server.models.sentence_schema import SentenceSchema
from openapi_server.models.QaVecSchema import QaVecSchema
import connexion
from mysql_tool.connection import DBHelper
from config.loadconfig import get_logger
from milvus import Milvus, IndexType, MetricType, Status
import random
from bert4keras.snippets import sequence_padding
from apscheduler.schedulers.background import BackgroundScheduler
import datetime
import os
logger = get_logger(__name__)
global graph
graph = tf.get_default_graph()
sess = keras.backend.get_session()
# Absolute path of the directory two levels above the current working directory
upper2path = os.path.abspath(os.path.join(os.getcwd(), "..", ".."))
# BERT configuration
config_path = "/Users/Downloads/data/model/chinese_simbert_L-6_H-384_A-12/bert_config.json"
checkpoint_path = "/Users/Downloads/data/model/chinese_simbert_L-6_H-384_A-12/bert_model.ckpt"
dict_path = "/Users/Downloads/data/model/chinese_simbert_L-6_H-384_A-12/vocab.txt"
# Build the tokenizer
tokenizer = Tokenizer(dict_path, do_lower_case=True)
# Build and load the model
bert = build_transformer_model(
config_path,
checkpoint_path,
with_pool='linear',
application='unilm',
return_keras_model=False,
)
# Extract the encoder: the pooled sentence-vector output
encoder = keras.models.Model(bert.model.inputs, bert.model.outputs[0])
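With the encoder in place, a single sentence can be turned into a 384-dimensional SimBERT vector as follows. This is a minimal sketch for sanity-checking the setup; sent2vec is an illustrative helper, not part of the original service code:

def sent2vec(text):
    # Encode one sentence into a 384-dim pooled SimBERT vector
    token_ids, segment_ids = tokenizer.encode(text)
    with sess.as_default():
        with graph.as_default():
            vec = encoder.predict([np.array([token_ids]), np.array([segment_ids])])[0]
    # L2-normalize so that inner-product search behaves like cosine similarity
    return vec / np.linalg.norm(vec)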
Vector ingestion:
def qa2vecs():
    # Rebuild the collection, then pull all Q&A records (id, text) from MySQL
    collection_reconstruct()
    data = qa_query()
    milvus, collection_name = MilvusHelper().connection()
    param = {
        'collection_name': collection_name,
        'dimension': 384,             # hidden size of chinese_simbert_L-6_H-384_A-12
        'index_file_size': 256,       # optional
        'metric_type': MetricType.IP  # optional; IP on normalized vectors equals cosine similarity
    }
    milvus.create_collection(param)
    vecs = []
    ids = []
    progress_idx = 0
    with sess.as_default():
        with graph.as_default():
            for record in data:
                progress_idx += 1
                token_ids, segment_ids = tokenizer.encode(record["text"])
                vec = encoder.predict([[token_ids], [segment_ids]])[0]
                vecs.append(vec)
                ids.append(record["id"])
                # Flush to Milvus every 5000 records, and once more for the final partial batch
                if (len(ids) % 5000 == 0 or progress_idx == len(data)) and len(ids) > 0:
                    logger.info("data sync :{:.2f}%".format(progress_idx * 100.0 / len(data)))
                    milvus.insert(collection_name=collection_name, records=vecs_normalize(vecs), ids=ids, params=param)
                    vecs = []
                    ids = []
    milvus.close()
    return progress_idx
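The vecs_normalize helper called above is not shown in the snippet. Because the collection is created with MetricType.IP, the vectors are L2-normalized before insertion so that inner product is equivalent to cosine similarity. A minimal sketch, assuming it simply normalizes a batch of vectors:

def vecs_normalize(vecs):
    # L2-normalize each row so that IP search in Milvus equals cosine similarity
    arr = np.array(vecs, dtype=np.float32)
    norms = np.linalg.norm(arr, axis=1, keepdims=True)
    return (arr / norms).tolist()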
During the ingestion above, each record's id is inserted into Milvus together with its vector, so Milvus holds the id -> vector mapping. A search therefore returns both the matched vectors and their ids, and the ids are then used to fetch the original text from MySQL.
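For completeness, the search side can look roughly like this. It is a hedged sketch: the search parameters (top_k, nprobe) are assumptions, and db_query_by_ids is a hypothetical MySQL lookup helper that returns the original texts for the returned ids:

def search_similar(text, topk=5):
    # Encode the query with SimBERT, search Milvus by inner product, then resolve ids in MySQL
    milvus, collection_name = MilvusHelper().connection()
    token_ids, segment_ids = tokenizer.encode(text)
    with sess.as_default():
        with graph.as_default():
            vec = encoder.predict([[token_ids], [segment_ids]])[0]
    vec = vec / np.linalg.norm(vec)
    status, results = milvus.search(collection_name=collection_name,
                                    query_records=[vec.tolist()],
                                    top_k=topk,
                                    params={"nprobe": 16})
    milvus.close()
    hits = results[0] if results else []
    # id -> original text lives in MySQL; db_query_by_ids is a hypothetical helper
    texts = db_query_by_ids([hit.id for hit in hits])
    return list(zip([hit.id for hit in hits], texts, [hit.distance for hit in hits]))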