【工程实践】使用Roformer-sim(SimBERTv2 )做数据增强

news2025/1/8 11:48:37

前言

        此文仅记录以Roformer-sim为基础模型做数据增强的过程,Roformer-sim模型细节请移步:SimBERTv2来了!融合检索和生成的RoFormer-Sim模型 - 科学空间|Scientific Spaces

https://github.com/ZhuiyiTechnology/roformer-sim

1.功能介绍

        可以用作数据增强与文本相似度计算。

2.安装与使用

2.1 安装

        代码依赖的运行环境为 tensorflow 1.14、keras 2.2.5、bert4keras 0.10.6。

        由于代码的运行依赖于tensor flow1.14,与现有base环境中的tensorflow版本冲突,所以为了避免环境中各个库的互相影响,在实践时新建了conda 环境。

2.1.1 创建

conda create --name DataAug python=3.6

2.1.2 激活

conda activate DataAug

2.1.3 下载依赖项

pip install tensorflow==1.14 -i https://pypi.tuna.tsinghua.edu.cn/simple/
pip install keras==2.2.5 -i https://pypi.tuna.tsinghua.edu.cn/simple/
pip install bert4keras==0.10.6 -i https://pypi.tuna.tsinghua.edu.cn/simple/

2.1.4 下载资源

git clone https://github.com/ZhuiyiTechnology/roformer-sim

2.1.5 下载模型

2.2 使用

2.2.1 运行generate.py

        将其中的模型文件路径替换,运行generate.py即可。

2.2.2 问题

        假设报错:TensorFlow binary was not compiled to use: AVX2 FMA,那么在开头添加:

import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' os.environ['KERAS_BACKEND']='tensorflow'

2.2.3 开启服务

        使用flask和tensorflow进行配合开启服务时,会报错。后期选择使用FastAPI开启服务。

#! -*- coding: utf-8 -*-
# RoFormer-Sim base 基本例子
# 测试环境:tensorflow 1.14 + keras 2.3.1 + bert4keras 0.10.6
#import flask
#from flask import request
import numpy as np
import json
import uvicorn
from fastapi import FastAPI

from bert4keras.backend import keras, K
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.snippets import sequence_padding, AutoRegressiveDecoder
from bert4keras.snippets import uniout
import os 
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
os.environ['KERAS_BACKEND']='tensorflow'
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
maxlen = 64

# 模型配置
config_path = '/ssd/dongzhenheng/Pretrain_Model/chinese_roformer-sim-char-ft_L-12_H-768_A-12/bert_config.json'
checkpoint_path = '/ssd/dongzhenheng/Pretrain_Model/chinese_roformer-sim-char-ft_L-12_H-768_A-12/bert_model.ckpt'
dict_path = '/ssd/dongzhenheng/Pretrain_Model/chinese_roformer-sim-char-ft_L-12_H-768_A-12/vocab.txt'
            

# 建立分词器
tokenizer = Tokenizer(dict_path, do_lower_case=True)  # 建立分词器

# 建立加载模型
roformer = build_transformer_model(
    config_path,
    checkpoint_path,
    model='roformer',
    application='unilm',
    with_pool='linear'
)

encoder = keras.models.Model(roformer.inputs, roformer.outputs[0])
seq2seq = keras.models.Model(roformer.inputs, roformer.outputs[1])


class SynonymsGenerator(AutoRegressiveDecoder):
    """seq2seq解码器
    """
    @AutoRegressiveDecoder.wraps(default_rtype='probas')
    def predict(self, inputs, output_ids, step):
        token_ids, segment_ids = inputs
        token_ids = np.concatenate([token_ids, output_ids], 1)
        segment_ids = np.concatenate([segment_ids, np.ones_like(output_ids)], 1)
        return self.last_token(seq2seq).predict([token_ids, segment_ids])

    def generate(self, text, n=1, topp=0.95, mask_idxs=[]):
        token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
        for i in mask_idxs:
            token_ids[i] = tokenizer._token_mask_id
        output_ids = self.random_sample([token_ids, segment_ids], n,
                                        topp=topp)  # 基于随机采样
        return [tokenizer.decode(ids) for ids in output_ids]


synonyms_generator = SynonymsGenerator(
    start_id=None, end_id=tokenizer._token_end_id, maxlen=maxlen
)


def gen_synonyms(text, n,k, mask_idxs):
    ''''含义: 产生sent的n个相似句,然后返回最相似的k个。
    做法:用seq2seq生成,并用encoder算相似度并排序。
    '''
    r = synonyms_generator.generate(text, n, mask_idxs=mask_idxs)
    r = [i for i in set(r) if i != text]
    r = [text] + r
    X, S = [], []
    for t in r:
        x, s = tokenizer.encode(t)
        X.append(x)
        S.append(s)
    X = sequence_padding(X)
    S = sequence_padding(S)
    Z = encoder.predict([X, S])
    Z /= (Z**2).sum(axis=1, keepdims=True)**0.5
    argsort = np.dot(Z[1:], -Z[0]).argsort()
    return [r[i + 1] for i in argsort[:k]]

app = FastAPI()
@app.get("/sentence/{sentence}")
async def get_item(sentence:str):
    resp = {}
    result = gen_synonyms(text=sentence,n=50,k=20,mask_idxs=[])
    resp['result'] = result
    return json.dumps(resp, ensure_ascii=False)

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=9516)
#发送请求
prefix_url = 'http://0.0.0.0:9516/DataAugment?'
params = {
      'sentence':text
         }
response = requests.get(prefix_url,params)
result_list = response.json()['resp']

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/853052.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[LeetCode - Python]344.反转字符串(Easy);345. 反转字符串中的元音字母(Easy);977. 有序数组的平方(Easy)

1.题目 344.反转字符串(Easy) 1.代码 class Solution:def reverseString(self, s: List[str]) -> None:"""Do not return anything, modify s in-place instead."""# 双指针left,right 0, len(s)-1while left < right:temp s[left]s[…

利用PCL实现点云配准

一、介绍 This document demonstrates using the Iterative Closest Point algorithm in your code which can determine if one PointCloud is just a rigid transformation of another by minimizing the distances between the points of two pointclouds and rigidly tran…

Ishikawa

Ishikawa 石川、鱼骨头、因果 其实我压根不知道 Ishikawa 这个日文就是石川&#xff0c;^_^&#xff0c;视乎也没啥影响

fastadmin动态获取单选框选中值修改页面

需求场景&#xff1a; 在编辑页面中&#xff0c;要求要根据某一单选框&#xff08;字段名称popup&#xff09;的选中值&#xff0c;来动态显示或者隐藏某个div&#xff08;idupload_img&#xff09;。 edit: function () {var popVal $("input[typeradio][namerow[popup]…

CNN(四):ResNet与DenseNet结合--DPN

&#x1f368; 本文为&#x1f517;365天深度学习训练营中的学习记录博客&#x1f356; 原作者&#xff1a;K同学啊|接辅导、项目定制 前面实现了ResNet和DenseNet的算法&#xff0c;了解了它们有各自的特点&#xff1a; ResNet&#xff1a;通过建立前面层与后面层之间的“短路…

U8g2 驱动oled自定义中文字库

原文&#xff1a;Arduino驱动LED128X64 - U8g2 参考&#xff1a; Arduino通过u8g2库驱动OLED适合 u8g2 的中文字体&#xff0c;采用文泉驿点阵宋体作为源本&#xff0c;提供 12x12、13x13、14x14、15x15 和 16x16 点阵字库。 本文所需工具下载 我们在项目中大概率会遇到LED显示…

淘宝商品详情(API接口)各大请求示例参考参数

主图&#xff0c;标题&#xff0c;价格&#xff0c;销量&#xff0c;库存&#xff0c;sku&#xff0c;详情信息&#xff0c;促销价&#xff0c;优惠券信息&#xff0c;券后价等 请求示例&#xff1a; <?php// 请求示例 url 默认请求参数已经URL编码处理 // 本示例代码未加…

一文走进时序数据库性能测试工具 TSBS

一、背景 在物联网、车联网等时序数据场景中&#xff0c;数据的高速写入能力至关重要&#xff0c;会对产品方案的可用性、可靠性和扩展性产生影响。 以物联网为例&#xff0c;当面临千万甚至上亿设备、平均每个设备采集几十个到几百个指标时&#xff0c;每秒生成的数据将达到…

使用JavaScript开发网页地图导航

使用JavaScript开发网页地图导航 导航是生活中的一个常见需求&#xff0c;而在互联网时代&#xff0c;网页地图导航成为了人们获取信息和帮助的重要工具。在网页中开发一个地图导航功能&#xff0c;能够提供用户位置定位、路线规划、交通情况等有用的信息&#xff0c;提供便利…

企业权限管理(五)-订单分页

订单分页查询 PageHelper介绍 PageHelper是国内非常优秀的一款开源的mybatis分页插件&#xff0c;它支持基本主流与常用的数据库&#xff0c;例如mysql、oracle、mariaDB、DB2、SQLite、Hsqldb等。 PageHelper使用 集成 引入分页插件有下面2种方式&#xff0c;推荐使用 Maven …

【Redis】Spring/SpringBoot 操作 Redis Java客户端

目录 操作 Redis Java客户端SpringBoot 操作Redis 步骤 操作 Redis Java客户端 1.Jedis 2.Lettuce(主流) <-Spring Data Redis SpringBoot 操作Redis 步骤 1.添加Redis 驱动依赖 2.设置Redis 连接信息 spring.redis.database0 spring.redis.port6379 spring.redis.host…

售后服务行业呼叫中心系统解决方案

随着社会经济的不断发展和消费者对售后服务需求的不断提高&#xff0c;在售后服务行业中&#xff0c;越来越多的企业使用呼叫中心系统来帮助企业提高售后服务质量和效率&#xff0c;提高客户满意度&#xff0c;增强企业竞争力。 一、呼叫中心系统的定义和功能 呼叫中心系统是指…

[mongo]性能机制,分析工具

性能机制 应用端 应用端-选择节点 对于复制集读操作&#xff0c;选择哪个节点是由readPreference决定的 primary/primaryPreferredsecondary/secondaryPreferrednearest 如果不希望一个远距离节点被选择 将它设置为隐藏节点通过标签&#xff08;Tag&#xff09;控制可选的节点…

分享Java技术下AutojsPro7云控代码

引言 有图有真相&#xff0c;那短视频就更是真相了。下面是三大语言的短视频。 Java源码版云控示例&#xff1a; Java源码版云控示例在线视频 核心技术&#xff1a;各个编程语言的WebSocket技术。 Java&#xff1a;Nettey、Net&#xff1a;Fleck、Python&#xff1a;Tornad…

Android Framework底层原理之WMS的启动流程

一 概述 今天&#xff0c;我们介绍 WindowManagerService&#xff08;后续简称 WMS&#xff09;的启动流程&#xff0c;WMS 是 Android 系统中&#xff0c;负责窗口显示的的服务。在 Android 中它也起着承上启下的作用。 如下图&#xff0c;就是《深入理解 Android》书籍中的…

模拟实现消息队列项目(完结) -- 基于MQ的生产者消费者模型

目录 前言 1. 生产者 2. 消费者 3. 启动消息队列服务器 4. 运行效果 结语 前言 在上一章节,我们完成了消息队列的客户端部分,至此我们整个消息队列项目就构建完成了,那我们做的这个消息队列到底有什么效果,以及如何去使用我们自己的消息队列呢?那么本文,就将我们的MQ进行实战操…

GSEA富集分析结果详解

1. GSEA富集分析原理图 2. GSEA富集分析过程 1. 计算富集分数&#xff08;ES&#xff09; 富集分数&#xff1a;S 反应基因集&#xff08;比如某个通路内的基因集&#xff09;成员 s 在排序基因集 L&#xff08;比如根据 logFC 排序的差异基因集&#xff0c;默认降序&#xf…

“为爱起航,一村一书院”在阳朔落地

2023年8月1-5 日&#xff0c;“关爱祖国下一代&#xff0c;助力乡村振兴” 之为爱起航项目在阳朔举行。 本次活动由千里思乡村振兴促进会联合中国文化交流大使组委会携同大湾区19位师生加入到首批“为爱起航&#xff0c;一村一书院”项目中&#xff0c;同时&#xff0c;本项目得…

分页查询从接口到实现,统一对日期类型进行格式化处理

编写Service实现类编写Mapper的sql&#xff0c;但复杂的sql语句需要写到mapper对应的xml文件中日期类型格式化处理 /*** 扩展springmvc框架的消息转换器* param converters*/Overrideprotected void extendMessageConverters(List<HttpMessageConverter<?>> conve…

初识Container

1. 什么是Container&#xff08;容器&#xff09; 要有Container首先要有Image&#xff0c;也就是说Container是通过image创建的。 Container是在原先的Image之上新加的一层&#xff0c;称作Container layer&#xff0c;这一层是可读可写的&#xff08;Image是只读的&#xff0…