前言
- 训练文本相似度数据集并进行评估:sentence-transformers(SBert)
- 预训练模型:chinese-roberta-wwm-ext
- 数据集:蚂蚁金融文本相似度数据集
- 前端:Vue2+elementui+axios
- 后端:flask
训练模型
- 创建网络:使用Sbert官方给出的预训练模型sentence_hfl_chinese-roberta-wwm-ext,先载入embedding层进行分词,再载入池化层并传入嵌入后的维度,对模型进行降维压缩,最后载入密集层,选择Than激活函数,输出维度大小为256维。
- 获取训练数据:构建出新模型后使用InputExample类存储训练数据,它接受文本对字符串列表和用于指示语义相似性的标签,用标准的Pytorch Dataloader包装train_examples,作用是打乱数据并生成特定大小的批次。
- 计算损失函数:对于每个句子对,通过网络传递句子A和句子B,从而产生嵌入u和v,使用余弦相似度计算相似性,并将结果与标准相似度得分进行比较。这样网络就能够进行微调,更好地识别句子的相似性。
- 模型调优:通过调用model.fit()来调优模型。向model.fit()中传递train_objective列表(由元组(dataloader, loss_function))组成。也可以传递多个元组,以便在具有不同损失函数的多个数据集上执行多任务学习。在训练过程需要使用sentence_transformers.evaluation评估表现是否有所改善,它包含各种可以传递给fit方法的evaluators。Evaluators会在训练期间定期运行,并且会返回分数,只有得分最高的模型才会存储在磁盘上。
首先运行preprocess.py获取数据,并划分训练集和测试集,之后运行train_sentence_bert.py,使用预训练模型, sbert将数据集用sbert训练相似度任务,得到训练好的模型,最后运行evaluate.py评估训练好的模型,将结果保存在predict.txt中,并输出预测结果。
这部分在详细代码里注释得很全。
后端部分
使用flask编写post接口,接收的数据格式为application/json,将前端传来的两个句子使用训练好的模型对其进行相似度预测,将得到的相似度类型从无法序列化存入json的tensor转成list,并将状态码,信息,数据返回给前端。
from sentence_transformers import SentenceTransformer, util
# 后端接口
from flask import Flask, jsonify, request
import re
# 用当前脚本名称实例化Flask对象,方便flask从该脚本文件中获取需要的内容
app = Flask(__name__)
# 使通过jsonify返回的中文显示正常,否则显示为ASCII码
app.config["JSON_AS_ASCII"] = False
model_path = 'D:/xxx模型路径/'
model = SentenceTransformer(model_path)
@app.route("/evaluate",methods=['POST'])
def evalute_sentence():
s1 = request.json.get("s1")
s2 = request.json.get("s2")
if s1 and s2:
embedding1 = model.encode(s1, convert_to_tensor=True)
embedding2 = model.encode(s2, convert_to_tensor=True)
similarity = util.cos_sim(embedding1, embedding2).tolist()
return jsonify({"code": 200, "msg": "预测成功", "data": similarity})
else:
return jsonify({"code": 400, "msg": "缺少字段"})
if __name__ == '__main__':
app.run(debug=True)
前端部分
框架使用Vue2,UI框架使用elementui。组件校验用户输入的表单(内容为中文,字数限制32个字,两个句子不为空),只有符合规则的字段才能提交表单。将数据通过Axios调用接口传递给后端,再根据后端接口响应状态进行相应的处理,如果返回状态码200,说明接口调用成功,展示返回的预测值,否则调用失败,页面弹出失败消息提示。
<template>
<div class="recommend">
<el-card class="box">
<h2 class="title">中文文本相似度预测</h2>
<el-form :model="evaluateForm" :rules="evaluateRules" ref="evaluateForm" class="form">
<el-form-item prop="s1">
<el-input
placeholder="请输入句子一"
maxlength="32"
show-word-limit
v-model="evaluateForm.s1"
autocomplete="false"
prefix-icon="el-icon-edit-outline"
></el-input>
</el-form-item>
<el-form-item prop="s2">
<el-input
maxlength="32"
placeholder="请输入句子二"
v-model="evaluateForm.s2"
show-word-limit
autocomplete="false"
prefix-icon="el-icon-edit-outline"
></el-input>
</el-form-item>
<el-form-item class="btn-container">
<el-button
type="primary"
@click="submitForm('evaluateForm')"
class="btn"
id="queryButton"
>开始预测</el-button>
</el-form-item>
</el-form>
<div v-show="result" style="margin-top: 20px">
<el-progress
:text-inside="true"
:stroke-width="26"
:percentage="result*100 ? result*100 : 0"
class="el-bg-inner-running"
></el-progress>
<p>预测结果:{{result}}</p>
</div>
</el-card>
</div>
</template>
<script>
import api from "@/api/index"
export default {
data () {
return {
evaluateForm: {
s1: "",
s2: ""
},
evaluateRules: { // 评估表单校验规则
s1: [
{ required: true, message: '请输入中文句子', trigger: 'blur', pattern: /^[\u4E00-\u9FA5]+$/ },
],
s2: [
{ required: true, message: '请输入中文句子', trigger: 'blur', pattern: /^[\u4E00-\u9FA5]+$/ },
],
},
result: undefined,
}
},
methods: {
postEvaluate () { // 调用接口
api.postEvaluate(this.evaluateForm)
.then((res) => {
if (!res) {
return
}
console.log("res", res)
if (res.data.code !== 200) {
this.$message({
message: "请求失败",
type: "error"
})
return
}
let data = res.data.data[0]
this.result = data[0]
console.log("this.result", this.result)
this.$message({
message: "预测成功!",
type: "success"
})
})
.catch((error) => {
this.$message.error('资源获取错误!')
})
},
submitForm (formName) { // 提交表单
this.$refs[formName].validate((valid) => {
if (valid) {
this.postEvaluate()
} else {
this.$message({
message: "请按要求填写",
type: "warning"
})
console.log('error in submit form')
return false
}
})
document.getElementById("queryButton").blur()
},
}
}
</script>
<style lang="scss" scoped>
.recommend {
width: 100%;
height: 100%;
text-align: center;
display: flex;
text-align: center;
flex-direction: column;
align-items: center;
justify-content: center;
overflow: hidden;
background: #00416a 0 / cover fixed; /* fallback for old browsers */
background: -webkit-linear-gradient(
to right,
#00416a,
#e4e5e6
); /* Chrome 10-25, Safari 5.1-6 */
background: linear-gradient(
to right,
#00416a,
#e4e5e6
); /* W3C, IE 10+/ Edge, Firefox 16+, Chrome 26+, Opera 12+, Safari 7+ */
.box {
width: 48%;
height: 60%;
position: relative;
background: hsla(0, 0%, 100%, 0.3);
z-index: 5;
padding: 10px 20px;
// display: flex;
// flex-direction: column;
// justify-content: center;
box-sizing: border-box;
&::before {
content: '';
position: absolute;
top: 0;
right: 0;
bottom: 0;
left: 0;
filter: blur(20px);
}
.title {
color: #143b54;
}
.btn-container {
margin: 10px auto;
.btn {
width: 100%;
border-radius: 20px;
}
}
}
}
::v-deep .el-card {
border: 0;
box-shadow: 0 5px 16px 0 rgb(0 0 0 / 30%);
}
::v-deep .el-progress-bar__outer {
border: 0;
background-color: transparent;
// background-color: #abcbe0;
}
::v-deep .el-bg-inner-running .el-progress-bar__inner {
background: #9cecfb; /* fallback for old browsers */
background: -webkit-linear-gradient(
to left,
#0052d4,
#65c7f7,
#9cecfb
); /* Chrome 10-25, Safari 5.1-6 */
background: linear-gradient(
to left,
#0052d4,
#65c7f7,
#9cecfb
); /* W3C, IE 10+/ Edge, Firefox 16+, Chrome 26+, Opera 12+, Safari 7+ */
}
</style>
预训练模型比较
paraphrase-multilingual-MiniLM-L12-v2
参数设置:epochs=1,batch_size=16
特点:作为sbert官方多语言预训练模型,已带有BERT层和池化层,可直接用数据评估,但未经纯中文文本训练,准确率较低
chinese-electra-180g-small-discriminator
参数设置:epochs=1, batch_size=16
特点:运行时间快,准确率尚可
chinese-electra-180g-small-discriminator
参数设置:epochs=20, batch_size=16
特点:20次迭代比1次迭代有效果,但差别不大
chinese-electra-180g-small-discriminator
参数设置:epochs=1,batch_size=8
特点:比batch_size=16时效果更好
chinese-roberta-wwm-ext
参数设置:epochs=1,batch_size=8
特点:迭代1次和20次准确率无差别,稳定且效果在所有模型中最好,缺点是体积大运行速度慢
最后
代码已上传至sbert中文文本相似度预测,欢迎star!