bert-NER 转化成 onnx 模型

news2024/12/23 6:16:08

保存模型

加载模型

from transformers import AutoTokenizer, AutoModel, AutoConfig

NER_MODEL_PATH = './save_model'
ner_tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_PATH)
ner_config = AutoConfig.from_pretrained(NER_MODEL_PATH)
ner_model = AutoModelForTokenClassification.from_pretrained(NER_MODEL_PATH)
ner_model.eval()

测试ner效果

在这里插入图片描述

测试速度

在这里插入图片描述

导出到onnx

# !pip install onnx onnxruntime -i https://pypi.tuna.tsinghua.edu.cn/simple/

# 导出 onnx 模型
import onnxruntime
from itertools import chain
from transformers.onnx.features import FeaturesManager

config = ner_config
tokenizer = ner_tokenizer
model = ner_model
output_onnx_path = "bert-ner.onnx"

onnx_config = FeaturesManager._SUPPORTED_MODEL_TYPE['bert']['sequence-classification'](config)
dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework='pt')

torch.onnx.export(
    model,
    (dummy_inputs,),
    f=output_onnx_path,
    input_names=list(onnx_config.inputs.keys()),
    output_names=list(onnx_config.outputs.keys()),
    dynamic_axes={
        name: axes for name, axes in chain(onnx_config.inputs.items(), onnx_config.outputs.items())
    },
    do_constant_folding=True,
    use_external_data_format=onnx_config.use_external_data_format(model.num_parameters()),
    enable_onnx_checker=True,
    opset_version=onnx_config.default_onnx_opset,
)

加载ONNX模型

自定义pipeline

from onnxruntime import SessionOptions, GraphOptimizationLevel, InferenceSession

class PipeLineOnnx:
    def __init__(self, tokenizer, onnx_path, config):
        self.tokenizer = tokenizer
        self.config = config  # label2id, id2label
        options = SessionOptions() # initialize session options
        options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
        # 设置线程数
#         options.intra_op_num_threads = 4
        # 这里的路径传上一节保存的onnx模型地址
        self.session = InferenceSession(
            onnx_path, sess_options=options, providers=["CPUExecutionProvider"]
        )
        # disable session.run() fallback mechanism, it prevents for a reset of the execution provider
        self.session.disable_fallback() 

    def __call__(self, text):
        inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors='pt')
        ids = inputs["input_ids"]
        inputs_offset = self.tokenizer.encode_plus(text, return_offsets_mapping=True).offset_mapping
        inputs_detach = {k: v.detach().cpu().numpy() for k, v in inputs.items()}

        # 运行 ONNX 模型
        # 这里的logits要有export的时候output_names相对应

        output = self.session.run(output_names=['logits'], input_feed=inputs_detach)[0]
        logits = torch.tensor(output)

        num_labels = len(self.config.label2id)
        active_logits = logits.view(-1, num_labels) # shape (batch_size * seq_len, num_labels)
        softmax = torch.softmax(active_logits, axis=1)
        scores = torch.max(softmax, axis=1).values.cpu().detach().numpy()
        flattened_predictions = torch.argmax(active_logits, axis=1) # shape (batch_size*seq_len,) - predictions at the token level

        tokens = self.tokenizer.convert_ids_to_tokens(ids.squeeze().tolist())
        token_predictions = [self.config.id2label[i] for i in flattened_predictions.cpu().numpy()]
        wp_preds = list(zip(tokens, token_predictions)) # list of tuples. Each tuple = (wordpiece, prediction)

        ner_result = [{"index": idx, "word":i,"entity":j, "start": k[0], "end": k[1], "score": s} for idx, (i,j,k,s) in enumerate(zip(tokens, token_predictions, inputs_offset, scores)) if j != 'O']
        return post_process(ner_result)
        

def allow_merge(a, b):
    a_flag, a_type = a.split('-')
    b_flag, b_type = b.split('-')
    if b_flag == 'B' or a_flag == 'E':
        return False
    if a_type != b_type:
        return False
    if (a_flag, b_flag) in [
        ("B", "I"),
        ("B", "E"),
        ("I", "I"),
        ("I", "E")
    ]:
        return True
    return False

def divide_entities(ner_results):
    divided_entities = []
    current_entity = []

    for item in sorted(ner_results, key=lambda x: x['index']):
        if not current_entity:
            current_entity.append(item)
        elif allow_merge(current_entity[-1]['entity'], item['entity']):
            current_entity.append(item)
        else:
            divided_entities.append(current_entity)
            current_entity = [item]
    divided_entities.append(current_entity)
    return divided_entities

def merge_entities(same_entities):
    def avg(scores):
        return sum(scores)/len(scores)
    return {
        'entity': same_entities[0]['entity'].split("-")[1],
        'score': avg([e['score'] for e in same_entities]),
        'word': ''.join(e['word'].replace('##', '') for e in same_entities),
        'start': same_entities[0]['start'],
        'end': same_entities[-1]['end']
    }

def post_process(ner_results):
    return [merge_entities(i) for i in divide_entities(ner_results)]

加载模型

from transformers import AutoTokenizer, AutoConfig

NER_MODEL_PATH = './save_model'
ner_tokenizer = AutoTokenizer.from_pretrained(NER_MODEL_PATH)
ner_config = AutoConfig.from_pretrained(NER_MODEL_PATH)

pipe2 = PipeLineOnnx(ner_tokenizer, "bert-ner.onnx", config=ner_config)

测试效果

在这里插入图片描述

测试速度

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1658692.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【基于Ubuntu下Yolov5的目标识别】保姆级教程 | 虚拟机安装 - Ubuntu安装 - 环境配置(Anaconda/Pytorch/Vscode/Yolov5) |全过程图文by.Akaxi

目录 一.【YOLOV5算法原理】 1.输入端 2.Backbone 3.Neck 4.输出端 二.【系统环境】 1.虚拟机的安装与创建 2.安装Ubuntu操作系统 3.环境的配置 3.1.Ubuntu下Anacoda安装以及虚拟环境配置 3.2.Pytorch安装 3.3.Vscode安装 3.4.Yolov5源码及环境获取安装…

深入探索MySQL:成本模型解析与查询性能优化

码到三十五 : 个人主页 在数据库管理系统中,查询优化器是一个至关重要的组件,它负责将用户提交的SQL查询转换为高效的执行计划。在MySQL中,查询优化器使用了一个称为“成本模型”的机制来评估不同执行计划的优劣,并选择…

Windows 跨服务器进行 MYSQL备份脚本

Windows 服务器进行 MYSQL备份的脚本,使用该脚本前,请先测试一下 1、新建一个文本文档 2、将下面代码放入文本文档中,保存退出 echo off :: 命令窗口名 title mysql-bak:: 参数定义 set "Y%date:~,4%" set "m%date:~5,2%&qu…

深入理解网络原理5----HTTP协议

文章目录 一、HTTP协议格式二、HTTP请求2.1 URL 基本格式2.2 URL encode2.3 "方法" (method)2.4 认识请求 "报头" (header) 三、HTTP 响应3.1 "状态码" (status code) 四、HTPPS工作过程(经典面试题) 提示:以下…

数据结构——希尔排序

基本思想: 希尔排序法又称缩小增量法。希尔排序法的基本思想是:先选定一个整数,把待排序文件中所有记录分成个组,所有距离为的记录分在同一组内,并对每一组内的记录进行排序。然后,取,重复上述…

翼支付——风控场景中图模型的范式变迁

目录 风控图深度学习模型 风控图大模型

7-zip下载、安装

7-Zip 官方中文网站 (sparanoid.com) 7-Zip - 程序下载 (sparanoid.com)

idea使用git不提示账号密码登录,而是输入token问题解决

idea 或者 webstream 等全家桶软件 使用git 推送代码时,不提示账号密码登录,而是输入token问题解决 你的代码仓库是gitlab 然后打开修改代码后推送时,会默认使用gitlab插件,所以提示数据token 解决方式就是把gitlab插件取消使用这…

Kafka从0到消费者开发

安装ZK Index of /zookeeper/zookeeper-3.9.2 下载安装包 一定要下载-bin的,不带bin的是源码,没有编译的,无法执行。-bin的才可以执行。 解压 tar -zxvf apache-zookeeper-3.9.2-bin.tar.gz 备份配置 cp zoo_sample.cfg zoo_sample.cfg-b…

Angular入门

Angular版本:Angular 版本演进史概述-天翼云开发者社区 - 天翼云 安装nodejs:Node.js安装与配置环境 v20.13.1(LTS)-CSDN博客 Angular CLI是啥 Angular CLI 是一个命令行接口(Angular Command Line Interface),是开发 Angular 应用的最快、最…

【Java基础】时间相关的类

需要引入类包java.util.Date* 用构造方法得到当前时间 Date d new Date(); 用构造方法将long类型的时间值转成日期 Date dd new Date(time); 得到当前时间的毫秒值 Long System.currentTimeMillis() 把Date类型转为long类型 long getTime() Date d new Date();//重写了…

亚马逊云科技中国峰会:与你开启云计算与前沿技术的探索之旅

亚马逊云科技中国峰会:与你开启云计算与前沿技术的探索之旅 Hello,我是科技博主Maynor,非常高兴地向你们推荐亚马逊云科技中国峰会,这是一场将于 5 月 29 日至 30 日在上海世博中心举办的科技盛会,如果你对云计算、行业发展新趋势…

探索人工智能的深度神经网络:理解、应用与未来

深度神经网络(DNNs)是一种人工智能模型,其灵感来自于人脑神经元之间的连接。它们由多个层次组成,每一层都包含多个神经元,这些神经元通过权重连接在一起。信息通过网络的输入层传递,并经过一系列隐藏层&…

【Leetcode每日一题】 分治 - 交易逆序对的总数(难度⭐⭐⭐)(74)

1. 题目解析 题目链接:LCR 170. 交易逆序对的总数 这个问题的理解其实相当简单,只需看一下示例,基本就能明白其含义了。 2.算法原理 归并排序的基本思路 归并排序将数组从中间分成两部分,在排序的过程中,逆序对的来…

如何把公章盖在电子档文件上?

将公章盖在电子档文件上,尤其是确保其法律效力和安全性,通常涉及以下步骤: 准备工作 获取合法的电子公章:确保你拥有公司或机构正式授权的电子公章图像,且该图像经过了必要的加密或数字签名处理,以确保其…

特征提取与深度神经网络(二)

关键点/角点检测 2011论文-ORB关键点检测,比SIFT与SURF速度更快。 ORB算法可以看出两个部分组成:快速关键点定位BRIEF描述子生成 Fast关键点检测: 选择当前像素点P,阈值T,周围16个像素点,超过连续N12个像素…

SparkSQL概述

1.1. SparkSQL介绍 SparkSQL,就是Spark生态体系中的构建在SparkCore基础之上的一个基于SQL的计算模块。SparkSQL的前身不叫SparkSQL,而是叫做Shark。最开始的时候底层代码优化、SQL的解析、执行引擎等等完全基于Hive,总是Shark的执行速度要比…

SpringCloud:认识微服务

程序员老茶 🙈作者简介:练习时长两年半的Java up主 🙉个人主页:程序员老茶 🙊 P   S : 点赞是免费的,却可以让写博客的作者开心好久好久😎 📚系列专栏:Java全栈&#…

让数据更「高效」一点!IvorySQL在Neon平台上的迅速部署和灵活应用

IvorySQL本身就是一个100%兼容PostgreSQL最新内核的开源数据库系统,而Neon Autoscaling Platform通常支持多种数据库和应用程序。将IvorySQL集成到该平台后,可以进一步增强与其他系统和应用程序的兼容性,同时更全面的体验IvorySQL的Oracle兼容…

lint 代码规范,手动修复,以及vscode的第三方插件eslint自动修复

ESlint代码规范 不是语法规范,是一种书写风格,加多少空格,缩进多少,加不加分号,类似于书信的写作格式 ESLint:是一个代码检查工具,用来检查你的代码是否符合指定的规则(你和你的团队可以自行约定一套规则)…