Pytorch版本的Ernie Health源码详解

news2025/2/23 23:20:10

Pytorch版本的Ernie Health源码详解

一、目录架构

在这里插入图片描述

二、尝试使用Ernie Health

import torch
# 查看torch版本
torch.__version__
'1.12.0+cpu'
# 查看设备是否有GPU资源
# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('device=', device)
# device= cpu
from transformers import AutoModel, AutoTokenizer,AutoConfig
# config配置模型的参数 从本地文件或URL加载预训练模型的配置文件
# 设置output_hidden_states=True 获取模型的所有隐藏状态 返回元组,第一个元素是最终的输出结果,后面的元素是每一层的隐藏状态。
config = AutoConfig.from_pretrained("premodel/", output_hidden_states=True)
# 根据给定的模型名称或路径自动选择对应的 tokenizer  AutoTokenizer.from_pretrained("model_name_or_path")
# model_name:bert-base-uncased、roberta-large   path:my_model
tokenizer = AutoTokenizer.from_pretrained("premodel/")
# 配置模型 利用config参数初始化
model = AutoModel.from_pretrained("premodel/", config=config)
# input_ids = torch.tensor([tokenizer.encode(text="welcome to ernie pytorch project", add_special_tokens=True)])
model.eval()
token = tokenizer.tokenize('我感觉头晕眼花')
input = tokenizer.encode('我感觉头晕眼花')
print(token)

在这里插入图片描述

三、使用Ernie Health完成文本分类

备注:这里只是想学习Ernie Health做文本分类的过程,数据集采用ChnSentiCorp,与医疗方向无关,后续将会分享Ernie Health在NLP医疗文本上的应用。

1.定义数据集(二分类)

# 加载数据
# load_dataset从数据集中心或本地文件系统中加载数据集。
# load_from_disk用于从本地磁盘中加载已经序列化的数据集,通常是使用 dataset.save_to_disk() 方法序列化后的数据集。
from datasets import load_dataset, load_from_disk
# 1. 定义数据集 二分类
class Dataset(torch.utils.data.Dataset):
    def __init__(self, split):
        #self.dataset = load_dataset(path='seamew/ChnSentiCorp', split=split)
        self.dataset = load_from_disk('./dataset/ChnSentiCorp')[split]
    def __len__(self):
        return len(self.dataset)
    def __getitem__(self, i):
        text = self.dataset[i]['text']
        label = self.dataset[i]['label']
        return text, label
dataset = Dataset('train')

print('长度:', len(dataset))
print('dataset[0]:', dataset[0])

在这里插入图片描述
train中包含两个字段:text文本内容;label文本标签

2.加载字典和分词工具

from transformers import AutoTokenizer
# 2. 加载字典和分词工具
token = AutoTokenizer.from_pretrained('premodel')

3.定义批处理函数

# 3.定义批处理函数
def collate_fn(data):
    sents = [i[0] for i in data]
    labels = [i[1] for i in data]

    #编码
    data = token.batch_encode_plus(batch_text_or_text_pairs=sents,
                                   truncation=True,  # 当句子长度大于max_length时截断操作
                                   padding='max_length',
                                   max_length=500,
                                   return_tensors='pt', # 返回pt类型可取值tf、pt、np分别对应tensorflow pytorch numpy张量,默认None为list
                                   return_length=True)
    print(data)
    #input_ids:编码之后的数字
    #attention_mask:是补零的位置是0,其他位置是1
    input_ids = data['input_ids'].to(device)
    attention_mask = data['attention_mask'].to(device)
    token_type_ids = data['token_type_ids'].to(device)
    labels = torch.LongTensor(labels).to(device)

    #print(data['length'], data['length'].max())
    # token_type_id 第一个句子和特殊符号为0,第二个句子为1
    # special_tokens_mask 特殊符号为1,其他位置为0
    return input_ids, attention_mask, token_type_ids, labels

4.数据加载器

# 4.数据加载器
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=16, # 每一次批次中包含16条数据
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)

for i, (input_ids, attention_mask, token_type_ids,
        labels) in enumerate(loader):
    break
print(len(loader))
input_ids.shape, attention_mask.shape, token_type_ids.shape, labels

5.加载预训练模型

from transformers import AutoModel
# 5.加载预训练模型
pretrained = AutoModel.from_pretrained('premodel')
#需要移动到cuda上
pretrained.to(device)

#不训练,不需要计算梯度
for param in pretrained.parameters():
    param.requires_grad_(False)

6.定义下游任务模型

# 6.定义下游任务模型
class Model(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(768, 2) # 单层神经网络模型

    def forward(self, input_ids, attention_mask, token_type_ids):
        with torch.no_grad():
            out = pretrained(input_ids=input_ids, # 利用预训练模型抽取文本特征
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids)

        out = self.fc(out.last_hidden_state[:, 0]) # 抽取的文本特征放入全连接层中
        out = out.softmax(dim=1)
        return out
model = Model()
#同样要移动到cuda
model.to(device)
#虚拟一批数据,需要把所有的数据都移动到cuda上
input_ids = torch.ones(16, 100).long().to(device)
attention_mask = torch.ones(16, 100).long().to(device)
token_type_ids = torch.ones(16, 100).long().to(device)
labels = torch.ones(16).long().to(device)
#试算
model(input_ids=input_ids,
      attention_mask=attention_mask,
      token_type_ids=token_type_ids).shape
# torch.Size([16, 500, 768])  batch_size,max_length,词编码的维度
# 后面的计算和中文分类完全一样,只是放在了cuda上计算

7.训练

from transformers import AdamW
######### 7.训练
optimizer = AdamW(model.parameters(), lr=5e-4)
criterion = torch.nn.CrossEntropyLoss() # 交叉熵损失

model.train()
for i, (input_ids, attention_mask, token_type_ids,
        labels) in enumerate(loader):
    out = model(input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids)
    loss = criterion(out, labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if i % 5 == 0:
        out = out.argmax(dim=1)
        accuracy = (out == labels).sum().item() / len(labels)
        print(i, loss.item(), accuracy)
    if i == 100:
        break

8.测试

def test():
    model.eval()
    correct = 0
    total = 0

    loader_test = torch.utils.data.DataLoader(dataset=Dataset('validation'),
                                              batch_size=32,
                                              collate_fn=collate_fn,
                                              shuffle=True,
                                              drop_last=True)

    for i, (input_ids, attention_mask, token_type_ids,
            labels) in enumerate(loader_test):
        if i == 5:
            break
        print(i)
        with torch.no_grad():
            out = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        token_type_ids=token_type_ids)
        out = out.argmax(dim=1)
        correct += (out == labels).sum().item()
        total += len(labels)
    print(correct / total)

test()

9.结果

0 0.6799753904342651 0.625
5 0.6905879974365234 0.4375
10 0.6816931366920471 0.5625
15 0.6723385453224182 0.625
20 0.6595535278320312 0.6875
25 0.6725921630859375 0.5
30 0.6362007856369019 0.8125
35 0.6231661438941956 0.9375
40 0.6399248242378235 0.625
45 0.6061363220214844 0.8125
50 0.6639376878738403 0.6875
55 0.6553252339363098 0.625
60 0.5932612419128418 0.9375
65 0.6260586380958557 0.625
70 0.6162962317466736 0.6875
75 0.6004247665405273 0.75
80 0.5238915681838989 0.9375
85 0.5489521622657776 0.875
90 0.5986231565475464 0.8125
95 0.5415515899658203 0.9375
100 0.6254498362541199 0.6875
0
1
2
3
4
测试准确率:0.75625

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/619414.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

I.MX6ULL_Linux_驱动篇(37) linux系统定时器

定时器是我们最常用到的功能,一般用来完成定时功能,本章我们就来学习一下 Linux 内核提供的定时器 API 函数,通过这些定时器 API 函数我们可以完成很多要求定时的应用。 Linux内核也提供了短延时函数,比如微秒、纳秒、毫秒延时函数…

Python selenium爬取影评生成词云图

文章目录 问题描述效果截图如下问题分析前期准备完整代码及解释字体素材 问题描述 通过中文分词、过滤停用词、生成词云图等步骤对评论数据进行处理和可视化。 效果截图如下 非常nice 问题分析 该程序需要使用 Selenium 库来模拟浏览器操作,因此需要下载安装 Chr…

@Autowired VS @Resource

一、两者的区别 首先,两者都是通过注解来实现依赖注入 。不同的话有以下几点: Autowired 是 Spring 提供的注解,所以只有 Spring 的 IoC容器 支持该注解。Resource 是 JSR-250 提供的(是 Java 的标准 ),我…

CnOpenData·A股上市公司标准数据

一、数据简介 按照《中华人民共和国标准化法》的定义,标准是指农业、工业、服务业以及社会事业等领域需要统一的技术要求。标准作为一种通用性的规范语言,在合理利用国家资源、保障产品质量、提高市场信任度、促进商品流通、维护公平竞争、保障安全等方面…

Hash算法的特点、应用和实现方法详解

什么是Hash算法?Hash算法,简称散列算法,也成哈希算法(英译),是将一个大文件映射成一个小串字符。与指纹一样,就是以较短的信息来保证文件的唯一性的标志,这种标志与文件的每一个字节…

企业数字化转型必看的6本书

导读 >> 2023年数据产业将为企业带来新的价值增量,成为企业数字化转型的重要突破口。数字化已经成为商业的一种基本常识,未来企业都将是数字化企业。然而在数字化转型话题热议的当下,真正成果显著的企业仍是少数,2023年企业…

如何阻止Windows Update更新Windows 10中的特定设备驱动程序

如果你想禁用Windows 10驱动程序的自动更新,那么方法有的是,但是如果你想禁用特定设备的驱动程序更新,该怎么办呢? 幸运的是,有一种替代方法可以禁用特定设备的驱动程序更新。你可以通过设置组策略“禁止安装与这些设备ID匹配的设备”来实现这一点。 根据微软的说法: …

在简历上写了“精通”后,我差点被面试官问到窒息....

前言 如果有真才实学,写个精通可以让面试官眼前一亮! 如果是瞎写?基本就要被狠狠地虐一把里! 最近在面试,我现在十分后悔在简历上写了“精通”二字… 先给大家看看我简历上的技能列表: 熟悉软件测试理…

阿里云服务器25565端口开通教程(ECS和轻量)

阿里云服务器25565端口怎么开通?ECS云服务器端口在安全组中开启,轻量应用服务器端口在防火墙中打开,我的世界mc服务器依赖25565端口,阿里云服务器网来详细说下云服务器ECS和轻量应用服务器开通25565端口的方法: 云服务…

学成在线项目note

目录 一、index.html 1、头部header 2、轮播图banner 3、精品推荐 4、精品推荐课程 5、footer 二、index.css 1、重要的代码 一、index.html <!-- 网站的首页, 所有网站的首页都叫index.html, 因为服务器找首页都是找index.html --> <!-- 布局: 从外到内, 从上到…

青岛科技大学|物联网工程|物联网定位技术(第三讲)|15:40

目录 物联网定位技术&#xff08;第三讲&#xff09; 1. 试简述C/A码的作用、构成 请画出C/A码生成电路简图并给予原理性的说明 2. 试简述 P码的作用、构成 请画出P码生成电路简图&#xff0c;并给予原理性的说明 3. GPS信号是如何进行伪码扩频与解扩 请画图给予说明 4…

Java的Object类和深拷贝和浅拷贝(面试题)

1.java.lang.Object类的说明 1.Object类是所有Java类的根父类 2.如果在类的声明中未使用extends关键字指明其父类&#xff0c;则默认父类为java.lang.Object类 3.Object类中的功能(属性、方法)就具通用性。 属性&#xff1a;无 方法&#xff1a;equals() / toString() / ge…

图片识别表格的方法有哪些?试试这几个好用的表格识别软件

随着数字化时代的到来&#xff0c;越来越多的公司和个人需要处理大量的表格数据。这些数据往往以图片的格式存在&#xff0c;而手动输入这些数据非常耗费时间和精力。因此&#xff0c;图片识别表格软件正在成为一个不可或缺的工具。那么&#xff0c;图片识别表格软件哪个好呢&a…

SAP从入门到放弃系列之CRP-Part1

从我学习CRP(Capacity Requirement planning)过程&#xff0c;应该能分三部分来总结。这篇就总结一下我学到的基本配置和概念。 温馨提示 &#xff1a;又臭又长的系统配置内容放在了最后的章节。本文分三个部分&#xff0c;工作中心数据和工艺路线创建&#xff0c;生产订单能力…

【Tomcat 部署及优化】

目录 一、Tomcat 安装部署1、Tomcat 介绍2、Tomcat 核心组件1、Tomcat 功能组件结构&#xff1a;2、Container 结构分析&#xff1a; 3、Tomcat 请求过程&#xff1a; 二、Tomcat 服务部署1.关闭防火墙&#xff0c;将安装 Tomcat 所需软件包传到/opt目录下2.安装JDK3.设置JDK环…

六、docker安装ngxin部署若依前端

1.第一次安装&#xff0c;不进行挂载数据卷&#xff0c; docker run \ -p 8060:80 \ --name nginx \ --privilegedtrue \ --restartalways \ -d nginx:1.17.82. 将配置信息复制到宿主机本地 # 将容器nginx.conf文件复制到宿主机 docker cp nginx:/etc/nginx/nginx.conf /data…

代码随想录算法训练营第五十天|123.买卖股票的最佳时机III|188.买卖股票的最佳时机IV

LeetCode123.买卖股票的最佳时机III 动态规划五部曲&#xff1a; 1&#xff0c;确定dp数组以及下标的含义&#xff1a; 一天一共就有五个状态&#xff0c; 没有操作 &#xff08;其实我们也可以不设置这个状态&#xff09;第一次持有股票第一次不持有股票第二次持有股票第二…

为什么代码签名需要添加时间戳?

如果您是软件发行商或开发人员&#xff0c;那么您就会知道软件的成功通常取决于下载次数&#xff0c;这部分取决于用户对它的信任程度。因此&#xff0c;为了向用户保证并避免在下载或安装时出现不必要的警告消息&#xff0c;您采取了主动步骤&#xff0c;例如使用受信任的代码…

uniapp和springboot微信小程序开发实战:开发环境准备以及技术选型

文章目录 开发工具STSHBuilder X其他工具技术选型前端开发vueelementUI后端springbootmybatisplusJWT和shiro开发工具 STS STS是开发springboot项目的利器,是Eclipse的一个版本,全称是SpringToolSuite STS下载地址 HBuilder X 下载地址

WWDC2023|苹果iOS 17系统更新:可共享AirTag

苹果正在召开的 WWDC 2023 开发者大会上&#xff0c;宣布推出了 iOS 17 系统。包含一些功能上的更新。 共享 AirTag iOS 17 引入了一项备受期待的 AirTag 功能&#xff0c;即与他人共享 AirTag。自推出以来&#xff0c;AirTag 只能由一个人拥有和使用&#xff0c;但在 iOS 17…