使用训练工具

news2024/12/25 0:42:41

HuggingFace上提供了很多已经训练好的模型库,如果想针对特定数据集优化,那么就需要二次训练模型,并且HuggingFace也提供了训练工具。

一.准备数据集
1.加载编码工具
加载hfl/rbt3编码工具如下所示:

def load_encode():
    # 1.加载编码工具
    # 第6章/加载tokenizer
    from transformers import AutoTokenizer
    pretrained_model_name_or_path = r'L:\20230713_HuggingFaceModel\rbt3'
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
    # 第6章/试编码句子
    result = tokenizer.batch_encode_plus(
        ['明月装饰了你的窗子', '你装饰了别人的梦'],
        truncation=True,
    )
    print(result)

输出结果如下所示:

{'input_ids': [[101, 3209, 3299, 6163, 7652, 749, 872, 4638, 4970, 2094, 102], [101, 872, 6163, 7652, 749, 1166, 782, 4638, 3457, 102]], 
'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 
'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}

2.准备数据集
ChnSentiCorp是谭松波收集整理了一个较大规模的酒店评论语料。7000多条酒店评论数据,5000多条正向评论,2000多条负向评论[3]。

def f1(data):
    # 通过编码工具将文字编码为数据
    from transformers import AutoTokenizer
    from pathlib import Path
    pretrained_model_name_or_path = r'L:\20230713_HuggingFaceModel\rbt3'
    tokenizer = AutoTokenizer.from_pretrained(Path(f'{pretrained_model_name_or_path}'))
    return tokenizer.batch_encode_plus(data['text'], truncation=True)

def f2(data):
    # 过滤太长的句子
    return [len(i) <= 512 for i in data['input_ids']]

def load_dataset_from_disk():
    # 方法1:从HuggingFace加载数据集,然后本地保存
    # from datasets import load_dataset
    # dataset = load_dataset(path='seamew/ChnSentiCorp')
    # print(dataset)
    # dataset.save_to_disk(dataset_dict_path='./data/ChnSentiCorp')

    # 方法2:从本地加载数据集
    from datasets import load_from_disk
    mode_name_or_path = r'L:\20230713_HuggingFaceModel\ChnSentiCorp'
    dataset = load_from_disk(mode_name_or_path)
    # 缩小数据规模,便于测试
    dataset['train'] = dataset['train'].shuffle().select(range(2000))
    dataset['test'] = dataset['test'].shuffle().select(range(100))

    # batched=True表示批量处理
    # batch_size=1000表示每次处理1000个样本
    # num_proc=8表示使用8个线程操作
    # remove_columns=['text']表示移除text列
    dataset = dataset.map(f1, batched=True, batch_size=1000, num_proc=8, remove_columns=['text'])

    return dataset

由于模型对输入文本的长度有限制,不能处理长度大于512词的文本,因此把长度超过512个词的句子过滤掉。过滤前的dataset为:

DatasetDict({
    train: Dataset({
        features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 2000
    })
    validation: Dataset({
        features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 1200
    })
    test: Dataset({
        features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 100
    })
})

过滤后的dataset为:

DatasetDict({
    train: Dataset({
        features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 1982
    })
    validation: Dataset({
        features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 1190
    })
    test: Dataset({
        features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 99
    })
})

二.定义模型和训练工具
1.加载预训练模型
加载预训练模型代码如下所示:

def load_pretrained_mode():
    """
    加载预训练模型
    """
    from transformers import AutoModelForSequenceClassification
    import torch
    pretrained_model_name_or_path = r'L:\20230713_HuggingFaceModel\rbt3'
    model = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path, num_labels=2)
    # 统计模型参数量
    print(sum([i.nelement() for i in model.parameters()]) / 10000)

    # 模拟一批数据
    data = {
        'input_ids': torch.ones(4, 10, dtype=torch.long),
        'token_type_ids': torch.ones(4, 10, dtype=torch.long),
        'attention_mask': torch.ones(4, 10, dtype=torch.long),
        'labels': torch.ones(4, dtype=torch.long)
    }
    # 模型试算
    out = model(**data)
    print(out['loss'], out['logits'].shape)

输出结果如下所示:

3847.8338
tensor(0.3911, grad_fn=<NllLossBackward0>) torch.Size([4, 2])

(1)hfl/rbt3模型
由哈尔滨工业大学讯飞联合实验室(HFL)基于中文文本数据训练的BERT模型。
(2)model数据结构

2.定义评价函数
定义评价函数代码如下所示:

def compute_metrics(eval_pred):
    """
    定义评价函数
    """
    from datasets import load_metric
    metric = load_metric('accuracy')
    logits, labels = eval_pred
    logits = logits.argmax(axis=1)
    return metric.compute(predictions=logits, references=labels)
    
if __name__ == '__main__':
    # 定义评价函数
    # 模拟输出
    from transformers.trainer_utils import EvalPrediction
    import numpy as np
    eval_pred = EvalPrediction(
        predictions=np.array([[0, 1], [2, 3], [4, 5], [6, 7]]),
        label_ids=np.array([1, 1, 0, 1]),
    )
    accuracy = compute_metrics(eval_pred)
    print(accuracy)

输出结果如下所示:

{'accuracy': 0.75}

3.定义训练超参数
可通过TrainingArguments对象来封装超参数:

#第6章/定义训练参数
from transformers import TrainingArguments
#定义训练参数
args = TrainingArguments(
#定义临时数据保存路径
output_dir='./output_dir',
#定义测试执行的策略,可取值为no、epoch、steps
evaluation_strategy='steps',
#定义每隔多少个step执行一次测试
eval_steps=30,
#定义模型保存策略,可取值为no、epoch、steps
save_strategy='steps',
#定义每隔多少个step保存一次
save_steps=30,
#定义共训练几个轮次
num_train_epochs=1,
#定义学习率
learning_rate=1e-4,
#加入参数权重衰减,防止过拟合
weight_decay=1e-2,
#定义测试和训练时的批次大小
per_device_eval_batch_size=16,
per_device_train_batch_size=16,
#定义是否要使用GPU训练
no_CUDA=True,
)

4.定义训练器
Trainer参数包括要训练的模型、超参数对象、训练和验证数据集、评价函数,以及数据整理函数。

from transformers import Trainer
from transformers.data.data_collator import DataCollatorWithPadding
#定义训练器
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test'],
    compute_metrics=compute_metrics,
    data_collator=DataCollatorWithPadding(tokenizer),
)

5.数据整理函数介绍
通过DataCollatorWithPadding对象把一个批次中长短不一的句子补充成统一的长度(对句子的尾部补充PAD),长度取决于这个批次中最长的句子有多长,如下所示:

def test_DataCollator(tokenizer, dataset):
    """
    数据整理函数
    """
    from transformers import DataCollatorWithPadding
    # 第6章/测试数据整理函数
    data_collator = DataCollatorWithPadding(tokenizer)
    # 获取一批数据
    data = dataset['train'][:5]
    # 输出这些句子的长度
    for i in data['input_ids']:
        print(len(i))
    # 调用数据整理函数
    data = data_collator(data)
    # 查看整理后的数据
    for k, v in data.items():
        print(k, v.shape)

if __name__ == '__main__':
    from transformers import AutoTokenizer
    from pathlib import Path
    pretrained_model_name_or_path = r'L:\20230713_HuggingFaceModel\rbt3'
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=Path(f'{pretrained_model_name_or_path}'))
    # 得到dataset
    dataset = load_dataset_from_disk()
    dataset = dataset.filter(f2, batched=True, batch_size=1000, num_proc=8)
    test_DataCollator(tokenizer, dataset)

结果输出如下所示:

175
136
121
34
160
input_ids torch.Size([5, 175])
token_type_ids torch.Size([5, 175])
attention_mask torch.Size([5, 175])
labels torch.Size([5])

三.训练和测试
1.训练模型
评价和训练模型代码如下所示:

trainer.evaluate() #评价模型
trainer.train()    #训练模型

在output_dir文件夹中可以找到4个文件夹,即checkpoint-30、checkpoint-60、checkpoint-90、checkpoint-120,分别是对应步数保存的检查点,每个文件夹中都有一个PyTorch_model.bin文件,这个文件就是模型的参数。每个文件夹包括文件如下所示:

config.json
optimizer.pt
pytorch_model.bin
rng_state.pth
scheduler.pt
trainer_state.json
training_args.bin

运行结果格式如下所示:

{'eval_loss': 0.48926153779029846, 'eval_accuracy': 0.8181818181818182, 'eval_runtime': 62.1286, 'eval_samples_per_second': 1.593, 'eval_steps_per_second': 0.113, 'epoch': 0.48}

如果模型在训练过程中断了,那么可以从中间检查点继续训练,如下所示:

trainer.train(resume_from_checkpoint='./output_dir/checkpoint-90')

2.模型的保存和加载
模型的保存和加载代码如下所示:

# 手动保存模型参数
trainer.save_model(output_dir='./output_dir/save_model')
# 手动加载模型参数
import torch
model.load_state_dict(torch.load('./output_dir/save_model/PyTorch_model.bin'))

3.使用模型预测
使用模型预测代码如下所示:

# 在模型的评估模式下,模型不再对输入进行梯度计算,并且一些具有随机性的操作(如Dropout)会被固定
model.eval()
for i, data in enumerate(trainer.get_eval_dataloader()):
    data = data.to('cuda')
    out = model(**data)
    out = out['logits'].argmax(dim=1)
    for j in range(8):
        print(tokenizer.decode(data['input_ids'][j], skip_special_tokens=True))
        print('label=', data['labels'][j].item())
        print('predict=', out[j].item())
    break

结果输出如下所示:

酒 店 有 点 偏 , ( 没 有 地 铁 站 ) , 19 : 30 后 就 没 有 shuttle bus 了 。 大 堂 很 小 , 也 没 有 什 么 设 施 。 不 过 , 房 间 很 好 , 也 有 海 景 。
label= 1
predict= 0
哈 哈 哈 哈..... 居 然 还 可 以 继 续 评 论 啊 那 就 给 满 分 了 下 次 去 了 继 续 住 忘 记 说 了, 有 房 内 按 摸 的 服 务 的 可 惜 没 时 间 去 试 了, 下 次 去 还 会 住 的......
label= 1
predict= 0
......
也 许 这 不 算 一 个 很 好 的 理 由, 但 是 我 之 所 以 喜 欢 读 书 而 不 是 看 网 上 的 资 料 什 么 的, 就 是 喜 欢 闻 着 书 香. 这 本 书 可 能 是 印 刷 的 油 墨 不 好 还 是 什 么 原 因, 感 觉 臭 臭 的 不 好 闻. 里 面 是 一 些 关 于 中 式 英 语 的 小 趣 闻, 有 些 小 乐 趣, 但 感 觉 对 于 有 浓 重 中 式 思 维 习 惯 说 英 说 的 人 来 说 才 比 较 有 点 用 处.
label= 0
predict= 0

参考文献:
[1]https://huggingface.co/datasets/seamew/ChnSentiCorp/tree/main
[2]文本数据集的下载与各种操作:https://blog.csdn.net/Wang_Dou_Dou_/article/details/127459760
[3]ChnSentiCorp:https://github.com/SophonPlus/ChineseNlpCorpus/blob/master/datasets/ChnSentiCorp_htl_all/intro.ipynb
[4]https://github.com/ai408/nlp-daily-record/tree/main/20230625_HuggingFace自然语言处理详解

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/929723.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

用C/C++修改I2C默认的SDA和SCL针脚

首先要说明一点&#xff1a;Pico 有两个 I2C&#xff0c;也就是两套 SDA 和 SCL。这点你可以在针脚图中名字看出&#xff0c;比如下图的 Pin 4 和 Pin 5是 I2C1 的&#xff0c;而默认的 Pin 6 和 Pin 7 是 I2C0 的。 默认情况下是只开启了第一个 I2C&#xff0c;也就是只有 I2C…

数据库——缓存数据

文章目录 缓存数据的处理流程是怎样的&#xff1f;为什么要用 Redis/为什么要用缓存&#xff1f; 缓存数据的处理流程是怎样的&#xff1f; 简单来说就是: 如果用户请求的数据在缓存中就直接返回。缓存中不存在的话就看数据库中是否存在。数据库中存在的话就更新缓存中的数据。…

基于云原生网关的流量防护实践

作者&#xff1a;涂鸦 背景 在分布式系统架构中&#xff0c;每个请求都会经过很多层处理&#xff0c;比如从入口网关再到 Web Server 再到服务之间的调用&#xff0c;再到服务访问缓存或 DB 等存储。在下图流量防护体系中&#xff0c;我们通常遵循流量漏斗原则进行流量防护。…

数字孪生赋能工业制造,为制造业带来新机遇与挑战

数字孪生技术是利用模拟仿真技术将实体对象数字化的技术。它基于虚拟现实、人工智能和云计算等技术&#xff0c;能够创建与真实物体相同的数字模型&#xff0c;并通过实时监测和分析手段&#xff0c;为制造企业提供关于该物体的全面数据&#xff0c;从而优化产品开发和生产过程…

《Dive into Deep Learning》

《Dive into Deep Learning》&#xff1a;https://d2l.ai/ Interactive deep learning book with code, math, and discussionsImplemented with PyTorch, NumPy/MXNet, JAX, and TensorFlowAdopted at 500 universities from 70 countries 《动手学深度学习》中文版&#xff1…

dji uav建图导航系列()ROS中创建dji_sdk节点包(一)项目结构

文章目录 1、整体项目结构1.1、 目录launch1.2、文件CMakeLists.txt1.3、文件package.xml1.4、目录include1.4、目录srv在ROS框架下创建一个无人机的节点dji_sdk,实现必需的订阅(控制指令)、发布(无人机里程计)、服务(无人机起飞降落、控制权得很)功能,就能实现一个类似…

C#-集合小例子

目录 背景&#xff1a; 过程: 1.添加1-100数: 2.求和: 3.平均值: 4.代码:​ 总结: 背景&#xff1a; 往集合里面添加100个数&#xff0c;首先得有ArrayList导入命名空间&#xff0c;这个例子分为3步&#xff0c;1.添加1-100个数2.进行1-100之间的总和3.求总和的平均值&…

03.sqlite3学习——数据类型

目录 sqlite3学习——数据类型 SQL语句的功能 SQL语法 SQL命令 SQL数据类型 数字类型 整型 浮点型 定点型decimal 浮点型 VS decimal 日期类型 字符串类型 CHAR和VARCHAR BLOB和TEXT SQLite 数据类型 SQLite 存储类 SQLite 亲和类型(Affinity)及类型名称 Boo…

【微服务】04-Polly实现失败重试和限流熔断

文章目录 1. Polly实现失败重试1.1 Polly组件包1.2 Polly的能力1.3 Polly使用步骤1.4 适合失败重试的场景1.5 最佳实践 2.Polly实现熔断限流避免雪崩效应2.1 策略类型2.2 组合策略 1. Polly实现失败重试 1.1 Polly组件包 PollyPolly.Extensions.HttpMicrosoft.Extensions.Htt…

MaBatis中的分页插件以及特殊字符处理

目录 一、PageHelper介绍 二、PageHelper使用 1. 导入pom依赖 2. Mybatis.cfg.xml 配置拦截器 配置sql映射文件 测试代码 特殊字符处理 2. 使用CDATA 区段 一、PageHelper介绍 PageHelper 是 Mybatis 的一个插件&#xff0c;这里就不扯了&#xff0c;就是为了更加便捷的进…

记录一次“top负1”比赛经历

获奖啦&#xff01; 比赛题目&#xff1a;中文语义病句识别与纠正挑战赛 比赛链接&#xff1a;https://challenge.xfyun.cn/topic/info?typeidentification-and-correction&optionphb“请介绍你们团队” “各位评委老师&#xff0c;我是来自WOT团队的选手AMBT&#xff0…

Python|爬虫和测试|selenium框架的安装和初步使用(一)

前言&#xff1a; Python作为一门胶水语言来说&#xff0c;可以说是十分的优秀&#xff0c;什么事情都可以干&#xff0c;并且在某些领域还能干的非常不错&#xff0c;尤其是在爬虫和测试领域&#xff0c;该语言可以说是没有对手。 这么说的原因是因为如果你要使用爬虫爬取某…

4.网络设计与redis、memcached、nginx组件(二)

系列文章目录 第四章 网络设计与redis、memcached、nginx组件(一) 第五章 网络设计与redis、memcached、nginx组件(二) 文章目录 系列文章目录[TOC](文章目录) 前言一、reactor模型&#xff1f;二、Reactor 开发1.建立连接 三、典型reactor 模型单reactor 模型典型 readisradi…

C++避坑——most vexing parse问题

1."坑"的问题是什么&#xff1f; 先看一段代码&#xff1a; class Functor { public:void operator()(){std::cout << "我是线程的初始函数" << std::endl;} };int main() {std::thread t(Functor());// 强制高速编译器这是一个构造函数!t.j…

Linux:权限

目录 一、shell运行原理 二、权限 1.权限的概念 2.文件访问权限的相关设置方法 三、常见的权限问题 1.目录权限 2.umsk(权限掩码) 3.粘滞位 一、shell运行原理 1.为什么我们不是直接访问操作系统&#xff1f; ”人“不善于直接使用操作系统如果让人直接访问操作系统&a…

lnmp架构-nginx

6.nginx基础配置 证书 重定向&#xff08;80重定向到443&#xff09; 当访问http时 直接到 https 自动索引&#xff1a; 下载方便 Nginx缓存配置 &#xff1a;缓存可以降低网站带宽&#xff0c;加速用户访问 日志轮询 禁用不必要的日志记录 以节省磁盘IO的消耗 监控的信息 监…

基于Android的垃圾分类系统 微信小程序 uniapp

随着网络科技的发展&#xff0c;移动智能终端逐渐走进人们的视线&#xff0c;相关应用越来越广泛&#xff0c;并在人们的日常生活中扮演着越来越重要的角色。因此&#xff0c;关键应用程序的开发成为影响移动智能终端普及的重要因素&#xff0c;设计并开发实用、方便的应用程序…

多态(C++)

多态 一、初识多态概念“登场”1>. 多态的构成条件2>. 虚函数3>. 虚函数重写&#xff08;覆盖&#xff09;4>. 虚函数重写的两个例外1. 协变 一 基类和派生类虚函数返回值类型不同2. 析构函数重写&#xff08;基类和派生类析构函数名不同&#xff09; 小结 二、延伸…

JavaScript函数调用其他函数

在JavaScript中&#xff0c;函数可以调用其他函数。这通常被称为函数组合&#xff0c;它允许你通过将较简单的函数组合在一起来创建更复杂的功能。 例如&#xff1a;还是以之前的水果加工举例&#xff0c;但是现在我们需要输出&#xff0c;这个苹果有几块&#xff0c;橘子有几块…

微信小程序分享后真机参数获取不到和部分参数不能获取问题问题解决

微信小程序的很多API&#xff0c;都是BUG&#xff0c;近期开发小程序就遇到了分享后开发工具可以获取参数&#xff0c;但是真机怎么都拿不到参数的问题 一、真机参数获取不到问题解决 解决方式&#xff1a; 在onLoad(options) 中。 onLoad方法中一定要有options 这个参数。…