sentence-transformers(SBert)中文文本相似度预测(附代码)

news2024/11/21 0:38:30

在这里插入图片描述

前言

  • 训练文本相似度数据集并进行评估:sentence-transformers(SBert)
  • 预训练模型:chinese-roberta-wwm-ext
  • 数据集:蚂蚁金融文本相似度数据集
  • 前端:Vue2+elementui+axios
  • 后端:flask

训练模型

  1. 创建网络:使用Sbert官方给出的预训练模型sentence_hfl_chinese-roberta-wwm-ext,先载入embedding层进行分词,再载入池化层并传入嵌入后的维度,对模型进行降维压缩,最后载入密集层,选择Than激活函数,输出维度大小为256维。
  2. 获取训练数据:构建出新模型后使用InputExample类存储训练数据,它接受文本对字符串列表和用于指示语义相似性的标签,用标准的Pytorch Dataloader包装train_examples,作用是打乱数据并生成特定大小的批次。
  3. 计算损失函数:对于每个句子对,通过网络传递句子A和句子B,从而产生嵌入u和v,使用余弦相似度计算相似性,并将结果与标准相似度得分进行比较。这样网络就能够进行微调,更好地识别句子的相似性。
  4. 模型调优:通过调用model.fit()来调优模型。向model.fit()中传递train_objective列表(由元组(dataloader, loss_function))组成。也可以传递多个元组,以便在具有不同损失函数的多个数据集上执行多任务学习。在训练过程需要使用sentence_transformers.evaluation评估表现是否有所改善,它包含各种可以传递给fit方法的evaluators。Evaluators会在训练期间定期运行,并且会返回分数,只有得分最高的模型才会存储在磁盘上。

首先运行preprocess.py获取数据,并划分训练集和测试集,之后运行train_sentence_bert.py,使用预训练模型, sbert将数据集用sbert训练相似度任务,得到训练好的模型,最后运行evaluate.py评估训练好的模型,将结果保存在predict.txt中,并输出预测结果。

这部分在详细代码里注释得很全。

后端部分

使用flask编写post接口,接收的数据格式为application/json,将前端传来的两个句子使用训练好的模型对其进行相似度预测,将得到的相似度类型从无法序列化存入json的tensor转成list,并将状态码,信息,数据返回给前端。

from sentence_transformers import SentenceTransformer, util
# 后端接口
from flask import Flask, jsonify, request
import re
# 用当前脚本名称实例化Flask对象,方便flask从该脚本文件中获取需要的内容
app = Flask(__name__)
# 使通过jsonify返回的中文显示正常,否则显示为ASCII码
app.config["JSON_AS_ASCII"] = False
model_path = 'D:/xxx模型路径/'
model = SentenceTransformer(model_path)
@app.route("/evaluate",methods=['POST'])
def evalute_sentence():
    s1 = request.json.get("s1")
    s2 = request.json.get("s2")
    if s1 and s2:
        embedding1 = model.encode(s1, convert_to_tensor=True)
        embedding2 = model.encode(s2, convert_to_tensor=True)
        similarity = util.cos_sim(embedding1, embedding2).tolist()
        return jsonify({"code": 200, "msg": "预测成功", "data": similarity})
    else:
        return jsonify({"code": 400, "msg": "缺少字段"})
if __name__ == '__main__':
    app.run(debug=True)

前端部分

框架使用Vue2,UI框架使用elementui。组件校验用户输入的表单(内容为中文,字数限制32个字,两个句子不为空),只有符合规则的字段才能提交表单。将数据通过Axios调用接口传递给后端,再根据后端接口响应状态进行相应的处理,如果返回状态码200,说明接口调用成功,展示返回的预测值,否则调用失败,页面弹出失败消息提示。

<template>
  <div class="recommend">
    <el-card class="box">
      <h2 class="title">中文文本相似度预测</h2>
      <el-form :model="evaluateForm" :rules="evaluateRules" ref="evaluateForm" class="form">
        <el-form-item prop="s1">
          <el-input
            placeholder="请输入句子一"
            maxlength="32"
            show-word-limit
            v-model="evaluateForm.s1"
            autocomplete="false"
            prefix-icon="el-icon-edit-outline"
          ></el-input>
        </el-form-item>
        <el-form-item prop="s2">
          <el-input
            maxlength="32"
            placeholder="请输入句子二"
            v-model="evaluateForm.s2"
            show-word-limit
            autocomplete="false"
            prefix-icon="el-icon-edit-outline"
          ></el-input>
        </el-form-item>
        <el-form-item class="btn-container">
          <el-button
            type="primary"
            @click="submitForm('evaluateForm')"
            class="btn"
            id="queryButton"
          >开始预测</el-button>
        </el-form-item>
      </el-form>
      <div v-show="result" style="margin-top: 20px">
        <el-progress
          :text-inside="true"
          :stroke-width="26"
          :percentage="result*100 ? result*100 : 0"
          class="el-bg-inner-running"
        ></el-progress>
        <p>预测结果:{{result}}</p>
      </div>
    </el-card>
  </div>
</template>

<script>
import api from "@/api/index"
export default {
  data () {
    return {
      evaluateForm: {
        s1: "",
        s2: ""
      },
      evaluateRules: { // 评估表单校验规则
        s1: [
          { required: true, message: '请输入中文句子', trigger: 'blur', pattern: /^[\u4E00-\u9FA5]+$/ },
        ],
        s2: [
          { required: true, message: '请输入中文句子', trigger: 'blur', pattern: /^[\u4E00-\u9FA5]+$/ },
        ],
      },
      result: undefined,
    }
  },
  methods: {
    postEvaluate () { // 调用接口
      api.postEvaluate(this.evaluateForm)
        .then((res) => {
          if (!res) {
            return
          }
          console.log("res", res)
          if (res.data.code !== 200) {
            this.$message({
              message: "请求失败",
              type: "error"
            })
            return
          }
          let data = res.data.data[0]
          this.result = data[0]
          console.log("this.result", this.result)
          this.$message({
            message: "预测成功!",
            type: "success"
          })

        })
        .catch((error) => {
          this.$message.error('资源获取错误!')
        })
    },
    submitForm (formName) { // 提交表单
      this.$refs[formName].validate((valid) => {
        if (valid) {
          this.postEvaluate()
        } else {
          this.$message({
            message: "请按要求填写",
            type: "warning"
          })
          console.log('error in submit form')
          return false
        }
      })
      document.getElementById("queryButton").blur()
    },
  }

}
</script>

<style lang="scss" scoped>
.recommend {
  width: 100%;
  height: 100%;
  text-align: center;
  display: flex;
  text-align: center;
  flex-direction: column;
  align-items: center;
  justify-content: center;
  overflow: hidden;
  background: #00416a 0 / cover fixed; /* fallback for old browsers */
  background: -webkit-linear-gradient(
    to right,
    #00416a,
    #e4e5e6
  ); /* Chrome 10-25, Safari 5.1-6 */
  background: linear-gradient(
    to right,
    #00416a,
    #e4e5e6
  ); /* W3C, IE 10+/ Edge, Firefox 16+, Chrome 26+, Opera 12+, Safari 7+ */
  .box {
    width: 48%;
    height: 60%;
    position: relative;
    background: hsla(0, 0%, 100%, 0.3);
    z-index: 5;
    padding: 10px 20px;
    // display: flex;
    // flex-direction: column;
    // justify-content: center;
    box-sizing: border-box;
    &::before {
      content: '';
      position: absolute;
      top: 0;
      right: 0;
      bottom: 0;
      left: 0;
      filter: blur(20px);
    }
    .title {
      color: #143b54;
    }
    .btn-container {
      margin: 10px auto;
      .btn {
        width: 100%;
        border-radius: 20px;
      }
    }
  }
}
::v-deep .el-card {
  border: 0;
  box-shadow: 0 5px 16px 0 rgb(0 0 0 / 30%);
}
::v-deep .el-progress-bar__outer {
  border: 0;
  background-color: transparent;
  // background-color: #abcbe0;
}
::v-deep .el-bg-inner-running .el-progress-bar__inner {
  background: #9cecfb; /* fallback for old browsers */
  background: -webkit-linear-gradient(
    to left,
    #0052d4,
    #65c7f7,
    #9cecfb
  ); /* Chrome 10-25, Safari 5.1-6 */
  background: linear-gradient(
    to left,
    #0052d4,
    #65c7f7,
    #9cecfb
  ); /* W3C, IE 10+/ Edge, Firefox 16+, Chrome 26+, Opera 12+, Safari 7+ */
}
</style>

预训练模型比较

paraphrase-multilingual-MiniLM-L12-v2
参数设置:epochs=1,batch_size=16
特点:作为sbert官方多语言预训练模型,已带有BERT层和池化层,可直接用数据评估,但未经纯中文文本训练,准确率较低

在这里插入图片描述

chinese-electra-180g-small-discriminator
参数设置:epochs=1, batch_size=16
特点:运行时间快,准确率尚可

在这里插入图片描述

chinese-electra-180g-small-discriminator
参数设置:epochs=20, batch_size=16
特点:20次迭代比1次迭代有效果,但差别不大

在这里插入图片描述

chinese-electra-180g-small-discriminator
参数设置:epochs=1,batch_size=8
特点:比batch_size=16时效果更好

在这里插入图片描述

chinese-roberta-wwm-ext
参数设置:epochs=1,batch_size=8
特点:迭代1次和20次准确率无差别,稳定且效果在所有模型中最好,缺点是体积大运行速度慢

在这里插入图片描述

最后

代码已上传至sbert中文文本相似度预测,欢迎star!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/162807.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

c语言公司考勤系统

1.要求 考勤系统是公司人事管理重要环节&#xff0c;用于记录员工迟到、早退、缺席、请假等出勤情况&#xff0c;并能提供数据统计功能。系统需求如下: 认证用户&#xff0c;如密码方式; 设置上下班时间&#xff0c;并能判断是否迟到、早退; 记录出勤状况&#xff0c;能记录每日…

基础IO(2)--文件描述符以及输入输出重定向

文件描述符fd 文件操作的本质是进程和被打开文件的关系。 进程可以打开多个文件&#xff0c;这些被打开的文件由OS管理&#xff0c;所以操作系统必定要为文件创建对应的内核数据结构标识文件–struct file{}【与C语言的FILE无关】 通过如下程序 #include <stdio.h> #…

uni-app在真机调试下兼容ethers的方法

目录 一、安装ethers 二、renderjs 三、注意事项 uni-app开发跨平台应用程序&#xff0c;项目搭建主要前端框是Uni-app Vue3 TS Vite&#xff0c;项目搭建参考文章Uni-app Vue3 TS Vite 创建项目 Hbuilderx版本是3.6.17 一、安装ethers yarn add ethers 如果像ether…

【Python】用xpath爬取2022热梗保存到txt中并生成词云

本文收录于《python学习笔记》专栏&#xff0c;这个专栏主要是我学习Python中遇到的问题&#xff0c;学习的新知识&#xff0c;或总结的一些知识点&#xff0c;我也是初学者&#xff0c;可能遇到的问题和大部分新人差不多&#xff0c;在这篇专栏里&#xff0c;我尽可能的分享出…

MySQL 索引 学习

索引 主键索引&#xff08;PRIMARY KEY&#xff09; 唯一标识&#xff0c;主键不可重复&#xff0c;只能有一个主键 唯一索引&#xff08;UNIQUE KEY&#xff09; 索引列 常规索引&#xff08;KEY/INDEX&#xff09;全文索引&#xff08;FullText&#xff09; 可以快速定位数据…

excel拆分技巧:如何快速对金额数字进行分列

金额数字分列&#xff0c;相信是做财务的小伙伴们经常遇到的问题。网上关于金额数字分列的方法很多&#xff0c;但用到的公式大都比较复杂。今天我们就来分享一个最简单的公式&#xff0c;仅用LEFT、RIGHT和COLUMN三个函数&#xff0c;就能达到效果&#xff01;在财务工作中&am…

Tapdata Cloud 场景通关系列:将数据导入阿里云 Tablestore,获得毫秒级在线查询和检索能力

【前言】作为中国的 “Fivetran/Airbyte”, Tapdata Cloud 自去年发布云版公测以来&#xff0c;吸引了近万名用户的注册使用。应社区用户上生产系统的要求&#xff0c;Tapdata Cloud 3.0 将正式推出商业版服务&#xff0c;提供对生产系统的 SLA 支撑。Tapdata 目前专注在实时数…

【论文阅读 CIKM2014】Extending Faceted Search to the General Web

文章目录ForewordMotivationMethodQuery facet generation:Facet feedbackEvaluationForeword This paper is from CIKM 2014, so we only consider the insightsI have read this paper last month and today i share this blogThere are many papers that have not been sha…

Docker网络原理详解

文章目录理解Docker0Docker 是如何处理容器网络访问的&#xff1f;Docker0网络模型图容器互联--Link自定义网络网络连通理解Docker0 查看本机IP ip addr1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue state UNKNOWN group default qlen 1000link/loopback 00…

application.properties的作用

springboot这个配置文件可以配置哪些东西 官方配置过多了解原理 这个properties文件其实是可以删掉的&#xff0c;官方是不推荐使用这个文件的&#xff0c;可以将其换成安排application.yaml。名字不能变&#xff0c;因为SpringBoot使用的是一个全局的配置文件 application.…

linux系统中使用QT实现CAN通信的方法

大家好&#xff0c;今天主要和大家分享一下&#xff0c;如何使用QT中的CAN Bus的具体实现方法。 目录 第一&#xff1a;CAN Bus的基本简介 第二&#xff1a;CAN通信应用实例 第三&#xff1a;程序的运行效果 第一&#xff1a;CAN Bus的基本简介 从QT5.8开始&#xff0c;提供…

C语言-柔性数组与几道动态内存相关的经典笔试题(12.2)

目录 思维导图&#xff1a; 1.柔性数组 1.1柔性数组的特点 1.2柔性数组的使用 1.3柔性数组的优势 2.几道经典笔试题 2.1题目1 2.2题目2 2.3题目3 2.4题目4 写在最后&#xff1a; 思维导图&#xff1a; 1.柔性数组 1.1柔性数组的特点 例&#xff1a; #include <…

javaEE 初阶 — java对于的操作文件

文章目录1. File 类概述2. 代码示例2.1 示例1&#xff1a;以绝对路径为例&#xff0c;演示获取文件路径2.2 示例2&#xff1a;以相对路径为例&#xff0c;演示获取文件路径2.3 示例3&#xff1a;测试文件是否存在、测试是不是文件、测试是不是目录2.4 示例4&#xff1a;创建文件…

27.函数指针变量的定义, 调用函数的方法,函数指针数组

函数指针变量的定义 返回值类型&#xff08;*函数指针变量名&#xff09;&#xff08;形参列表&#xff09;; int( *p )( int , int );//定义了一个函数指针变量p&#xff0c;p指向的函数必须有一个整型的返回值&#xff0c;有两个整型参数。 int max(int x, int y) { } int m…

AMR-IE:一种利用抽象语义表示(AMR)辅助图编码解码的联合信息抽取模型

Abstract Meaning Representation Guided Graph Encoding and Decoding for Joint Information Extraction 论文&#xff1a;2210.05958.pdf (arxiv.org) 代码&#xff1a;zhangzx-uiuc/AMR-IE: The code repository for AMR guided joint information extraction model (NAAC…

【学习笔记】【Pytorch】七、卷积层

【学习笔记】【Pytorch】七、卷积层学习地址主要内容一、卷积操作示例二、Tensor&#xff08;张量&#xff09;是什么&#xff1f;三、functional.conv2d函数的使用1.使用说明2.代码实现四、torch.Tensor与torch.tensor区别五、nn.Conv2d类的使用1.使用说明2.代码实现六、卷积公…

C/C++ noexcept NRVO

为什么需要noexcept为了说明为什么需要noexcept&#xff0c;我们还是从一个例子出发&#xff0c;我们定义MyClass类&#xff0c;并且我们先不对MyClass类的移动构造函数使用noexceptclass MyClass { public:MyClass(){}MyClass(const MyClass& lValue){std::cout << …

使用语雀绘制 Java 中六大 UML 类图

目录 下载语雀 泛化关系&#xff08;Generalization&#xff09; 实现关系&#xff08;Realization&#xff09; 关联关系&#xff08;Association&#xff09; 依赖关系&#xff08;Dependency&#xff09; 聚合关系&#xff08;Aggregation&#xff09; 组合关系&…

【Python学习】列表和元组

前言 前四天每天更新了小白看的基础教程 今天开始就更新一下&#xff0c;深入一点的知识点吧 还是老话&#xff1a;刚接触python的宝子可以点击文章末尾名片进行交流学习的哦 什么是列表和元组 列表是动态的&#xff0c;长度大小不固定&#xff0c;可以随意地增加、删减或…

【软件测试】软件测试基础2

1. 软件测试的生命周期 软件测试的生命周期&#xff1a; 需求分析→测试计划→ 测试设计、测试开发→ 测试执行→ 测试评估 ● 需求分析&#xff1a;站在用户的角度&#xff1a;查看需求逻辑是否正确&#xff0c;是否符合用户的需求和行为习惯&#xff1b;站在开发人员的角度&…