使用自动模型

news2025/1/12 0:57:48

本文通过文本分类任务演示了HuggingFace自动模型使用方法,既不需要手动计算loss,也不需要手动定义下游任务模型,通过阅读自动模型实现源码,提高NLP建模能力。

一.任务和数据集介绍
1.任务介绍
前面章节通过手动方式定义下游任务模型,HuggingFace也提供了一些常见的预定义下游任务模型,如下所示:

说明:包括预测下一个词,文本填空,问答任务,文本摘要,文本分类,命名实体识别,翻译等。

2.数据集介绍
本文使用ChnSentiCorp数据集,不清楚的可以参考中文情感分类介绍。一些样例如下所示:

二.准备数据集
1.使用编码工具

def load_encode_tool(pretrained_model_name_or_path):
    """
    加载编码工具
    """
    tokenizer = BertTokenizer.from_pretrained(Path(f'{pretrained_model_name_or_path}'))
    return tokenizer
if __name__ == '__main__':
    # 测试编码工具
    pretrained_model_name_or_path = r'L:/20230713_HuggingFaceModel/bert-base-chinese'
    tokenizer = load_encode_tool(pretrained_model_name_or_path)
    print(tokenizer)

输出结果如下所示:

BertTokenizer(name_or_path='L:\20230713_HuggingFaceModel\bert-base-chinese', vocab_size=21128, model_max_length=1000000000000000019884624838656, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True)

2.定义数据集
直接使用HuggingFace数据集对象,如下所示:

def load_dataset_from_disk():
    pretrained_model_name_or_path = r'L:\20230713_HuggingFaceModel\ChnSentiCorp'
    dataset = load_from_disk(pretrained_model_name_or_path)
    return dataset
if __name__ == '__main__':
    # 加载数据集
    dataset = load_dataset_from_disk()
    print(dataset)

输出结果如下所示:

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 9600
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 1200
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 1200
    })
})

3.定义计算设备

# 定义计算设备
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
# print(device)

4.定义数据整理函数

def collate_fn(data):
    sents = [i['text'] for i in data]
    labels = [i['label'] for i in data]
    #编码
    data = tokenizer.batch_encode_plus(batch_text_or_text_pairs=sents, # 输入文本
            truncation=True, # 是否截断
            padding=True, # 是否填充
            max_length=512, # 最大长度
            return_tensors='pt') # 返回的类型
    #转移到计算设备
    for k, v in data.items():
        data[k] = v.to(device)
    data['labels'] = torch.LongTensor(labels).to(device)
    return data

5.定义数据集加载器

# 数据集加载器
loader = torch.utils.data.DataLoader(dataset=dataset['train'], batch_size=16, collate_fn=collate_fn, shuffle=True, drop_last=True)
print(len(loader))

# 查看数据样例
for i, data in enumerate(loader):
    break
for k, v in data.items():
    print(k, v.shape)

输出结果如下所示:

600
input_ids torch.Size([16, 200])
token_type_ids torch.Size([16, 200])
attention_mask torch.Size([16, 200])
labels torch.Size([16])

三.加载自动模型
使用HuggingFace的AutoModelForSequenceClassification工具类加载自动模型,来实现文本分类任务,代码如下:

# 加载预训练模型
model = AutoModelForSequenceClassification.from_pretrained(Path(f'{pretrained_model_name_or_path}'), num_labels=2)
model.to(device)
print(sum(i.numel() for i in model.parameters()) / 10000)

四.训练和测试
1.训练
需要说明自动模型本身包括loss计算,因此在train()中就不再需要手工计算loss,如下所示:

def train():
    # 定义优化器
    optimizer = AdamW(model.parameters(), lr=5e-4)
    # 定义学习率调节器
    scheduler = get_scheduler(name='linear', # 调节器名称
                              num_warmup_steps=0, # 预热步数
                              num_training_steps=len(loader), # 训练步数
                              optimizer=optimizer) # 优化器
    # 将模型切换到训练模式
    model.train()
    # 按批次遍历训练集中的数据
    for i, data in enumerate(loader):
        # print(i, data)
        # 模型计算
        out = model(**data)
        # 计算1oss并使用梯度下降法优化模型参数
        out['loss'].backward() # 反向传播
        optimizer.step() # 优化器更新
        scheduler.step() # 学习率调节器更新
        optimizer.zero_grad() # 梯度清零
        model.zero_grad() # 梯度清零
        # 输出各项数据的情况,便于观察
        if i % 10 == 0:
            out_result = out['logits'].argmax(dim=1)
            accuracy = (out_result == data.labels).sum().item() / len(data.labels)
            lr = optimizer.state_dict()['param_groups'][0]['lr']
            print(i, out['loss'].item(), lr, accuracy)

其中,out数据结构如下所示:

2.测试

def test():
    # 定义测试数据集加载器
    loader_test = torch.utils.data.DataLoader(dataset=dataset['test'],
                                              batch_size=32,
                                              collate_fn=collate_fn,
                                              shuffle=True,
                                              drop_last=True)
    # 将下游任务模型切换到运行模式
    model.eval()
    correct = 0
    total = 0
    # 按批次遍历测试集中的数据
    for i, data in enumerate(loader_test):
        # 计算5个批次即可,不需要全部遍历
        if i == 5:
            break
        print(i)
        # 计算
        with torch.no_grad():
            out = model(**data)
        # 统计正确率
        out = out['logits'].argmax(dim=1)
        correct += (out == data.labels).sum().item()
        total += len(data.labels)
    print(correct / total)

五.深入自动模型源代码
1.加载配置文件过程
在执行AutoModelForSequenceClassification.from_pretrained(Path(f'{pretrained_model_name_or_path}'), num_labels=2)时,实际上调用了AutoConfig.from_pretrained(),该函数返回的config对象内容如下所示:

config对象如下所示:

BertConfig {
  "_name_or_path": "L:\\20230713_HuggingFaceModel\\bert-base-chinese",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "position_embedding_type": "absolute",
  "transformers_version": "4.32.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 21128
}

(1)_name_or_path=bert-base-chinese:模型名字。
(2)attention_probs_DropOut_prob=0.1:注意力层DropOut的比例。
(3)hidden_act=gelu:隐藏层的激活函数。
(4)hidden_DropOut_prob=0.1:隐藏层DropOut的比例。
(5)hidden_size=768:隐藏层神经元的数量。
(6)layer_norm_eps=1e-12:标准化层的eps参数。
(7)max_position_embeddings=512:句子的最大长度。
(8)model_type=bert:模型类型。
(9)num_attention_heads=12:注意力层的头数量。
(10)num_hidden_layers=12:隐藏层层数。
(11)pad_token_id=0:PAD的编号。
(12)pooler_fc_size=768:池化层的神经元数量。
(13)pooler_num_attention_heads=12:池化层的注意力头数。
(14)pooler_num_fc_layers=3:池化层的全连接神经网络层数。
(15)vocab_size=21128:字典的大小。

2.初始化模型过程
BertForSequenceClassification类构造函数包括一个BERT模型和全连接神经网络,基本思路为通过BERT提取特征,通过全连接神经网络进行分类,如下所示:

def __init__(self, config):
    super().__init__(config)
    self.num_labels = config.num_labels
    self.config = config

    self.bert = BertModel(config)
    classifier_dropout = (
        config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
    )
    self.dropout = nn.Dropout(classifier_dropout)
    self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    # Initialize weights and apply final processing
    self.post_init()

通过forward()函数可证明以上推测,根据问题类型为regression(MSELoss()损失函数)、single_label_classification(CrossEntropyLoss()损失函数)和multi_label_classification(BCEWithLogitsLoss()损失函数)选择损失函数。

参考文献:
[1]HuggingFace自然语言处理详解:基于BERT中文模型的任务实战
[2]https://github.com/ai408/nlp-engineering/blob/main/20230625_HuggingFace自然语言处理详解/第12章:使用自动模型.py

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/986713.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2020年12月 C/C++(二级)真题解析#中国电子学会#全国青少年软件编程等级考试

C/C++编程(1~8级)全部真题・点这里 第1题:数组指定部分逆序重放 将一个数组中的前k项按逆序重新存放。例如,将数组8,6,5,4,1前3项逆序重放得到5,6,8,4,1。 时间限制:1000 内存限制:65536 输入 输入为两行: 第一行两个整数,以空格分隔,分别为数组元素的个数n(1 < n…

Mybatis传参parameterType为List<Map>

这里分别记录使用过的五种传参方式&#xff1a; 1、在入参只有一个的情况下,Mapper.java中直接传即可 2、而在参数有两三个的情况下,Mapper.java中可以用Param注解来指定入参 程序就知道哪个参对应拼接完SQL的哪个条件字段 并且Mapper.xml中parameterType不用去指定&#xff…

controller接口上带@PreAuthorize的注解如何访问 (postman请求示例)

1. 访问接口 /*** 查询时段列表*/RateLimiter(time 10,count 10)ApiOperation("查询时段列表")PreAuthorize("ss.hasPermi(ls/sy:time:list)")GetMapping("/list")public TableDataInfo list(LsTime lsTime){startPage();List<LsTime> l…

【 Tkinter界面-练习04】 画板作画详细揭示

一、说明 对画布的掌握分三个部分&#xff0c;将图形paint到画布、动画move、鼠标画&#xff1b;本篇将侧重于鼠标画的功能&#xff0c;提起鼠标画实现&#xff0c;将涉及一系列组合操作才能完成&#xff0c;这里将一一加以介绍。 Canvas 小部件具有大量功能&#xff0c;我们不…

这是公司最糟糕的程序员,但是我坚决要留住他!

我在一家著名的软件咨询公司工作&#xff0c;有一天&#xff0c;公司决定对开发人员的个人绩效进行度量。 这个目标很美好&#xff1a;评估个人能力&#xff0c;帮助开发人员成长。 指标经过层层分解&#xff0c;来到我们团队&#xff0c;经过经理的认真讨论&#xff0c;决定不…

云备份客户端——数据管理模块

数据管理模块设计之前&#xff0c;我们需要先明确该模块的信息是用来做什么的。根据上文分析该模块信息主要用于判断一个文件是否需要备份&#xff0c;判断条件有两个&#xff1a;1.新增文件 2.被修改过的文件 新增文件好判断&#xff0c;由于我们获得新文件后是先上传文件&…

有效利用云测试的关键要素是什么

云测试是一种基于云计算平台的软件测试方法&#xff0c;它将测试环境和资源部署在云端&#xff0c;通过网络连接来执行测试任务。云测试提供了弹性的计算能力和资源管理&#xff0c;可以根据需求快速扩展和缩减测试环境&#xff0c;使测试过程更加灵活和高效。那么&#xff0c;…

如何在国内安装Bitdefender

我一直有关注国外的antivirus的情况&#xff0c;之前一直用ESET&#xff0c;但是最近一直关注到 Bitdefender 可以和卡巴斯基旗鼓相当&#xff0c;于是抱着试试看的精神&#xff0c;在win10和win11安装了一遍。外国软件大都服务部署在AWS&#xff0c;但是我们这儿的运营商和某种…

使用极域电子教室控制学员机开机问题

遇到问题&#xff1a; 昨天晚上试了一下从网上下载的“极域电子教室”软件&#xff0c;首先保证教师机和学员机器在同一局域网下&#xff0c;然后我发现&#xff1a;教师机可以控制学员机 关机、重启&#xff0c;但是不能控制学员机 开机。 解决办法&#xff1a; 按下电脑开机…

磁盘分析 wiztree[win32] baobab[linux]

磁盘分析 wiztree[win32] && baobab[linux] wiztree[win32]baobab 又叫 Disk Usage Analyzer[linux]安装使用 参考 wiztree[win32] baobab 又叫 Disk Usage Analyzer[linux] baobab 又叫 Disk Usage Analyzer&#xff0c;是 Ubuntu 系统默认自带的磁盘分析工具&#x…

原生js之dom添加表单验证

第一种,在form表单中加入onsubmit事件,进入事件后,可以通过dom.forms[父formname][子formname].value,然后测试这个别名是否为空,在这个判断语句中即可放入想要的表单验证 第二种,在input中加入required,这个是浏览器默认的校验,如果说input中加入required,则默认它生效. <…

【Arduino27】DHT11温湿度传感器模拟值实验

硬件准备 DHT11温湿度&#xff1a;1个 面包板&#xff1a;1个 杜邦线&#xff1a;3根 硬件连线 VDD引脚接 5V 电源 DATE引脚接 4号 接口 GND引脚接 GND 接口 软件程序 #include<DHT.h>#define DHT11_pin 4 //温湿度传感器引脚DHT dht(DHT11_pin,DHT11);float tem…

GO语言实战之接口实现与方法集

写在前面 嗯&#xff0c;学习GO&#xff0c;所以有了这篇文章博文内容为《GO语言实战》读书笔记之一主要涉及知识 接口是什么方法集(值接收和指针接收)多态 傍晚时分&#xff0c;你坐在屋檐下&#xff0c;看着天慢慢地黑下去&#xff0c;心里寂寞而凄凉&#xff0c;感到自己的…

FastDFS介绍

文章目录 一、什么是FastDFS二、FastDFS的架构2.1 跟踪服务器&#xff08;tracker server&#xff09;2.2 存储服务器&#xff08;storage server&#xff09;2.3 客户端&#xff08;client&#xff09; 三、 FastDFS功能逻辑分析3.1 upload file&#xff08;上传文件&#xff…

shell入门运算符操作、条件判断

♥️作者&#xff1a;小刘在C站 ♥️个人主页&#xff1a; 小刘主页 ♥️努力不一定有回报&#xff0c;但一定会有收获加油&#xff01;一起努力&#xff0c;共赴美好人生&#xff01; ♥️学习两年总结出的运维经验&#xff0c;以及思科模拟器全套网络实验教程。专栏&#xf…

分类算法系列⑤:决策树

目录 1、认识决策树 2、决策树的概念 3、决策树分类原理 基本原理 数学公式 4、信息熵的作用 5、决策树的划分依据之一&#xff1a;信息增益 5.1、定义与公式 5.2、⭐手动计算案例 5.3、log值逼近 6、决策树的三种算法实现 7、API 8、⭐两个代码案例 8.1、决策树…

springboot使用切面记录接口访问日志

前言 当我们开发和维护一个复杂的应用程序时&#xff0c;了解应用程序的运行情况变得至关重要。特别是在生产环境中&#xff0c;我们需要追踪应用程序的各个方面&#xff0c;以确保它正常运行并能够及时发现潜在的问题。其中之一关键的方面是记录应用程序的接口访问日志。 Sp…

Linux系统离线安装RabbitMQ

安装rabbitmq 1、下载安装包 首先进入官网进行安装包的下载&#xff0c;在下载时一定要注意erlong版本和rabbitmq-server版本匹配 rabbitmq版本对应关系&#xff1a;传送门 Erlong下载地址:传送门 rabbitmq-server下载地址:传送门 socat 不同版本 centos7:传送门 cent…

理解项目开发(寺庙小程序)

转载自&#xff1a;历经一年&#xff0c;开发一个寺庙小程序&#xff01; (qq.com) 破防了&#xff01;为方丈开发一款纪念小程序&#xff01; (qq.com) 下面内容转载自&#xff1a;程序员5K为青岛啤酒节开发个点餐系统&#xff01; (qq.com) 看一个人如何完成一个项目的开发…

SpringBoot集成Swagger的使用

Swagger是一个规范和完整的框架,用于生成、描述、调用和可视化RESTful风格的Web服务。目标是使客户端和文件系统作为服务器以同样的速度来更新文件的方法,参数和模型紧密集成到服务器。 Swagger能够在线自动生成 RESTFul接口的文档&#xff0c;同时具备测试接口的功能。 简单…