入门微调预训练Transformer模型

news2024/11/15 21:22:26

大家好,HuggingFace 为众多开源的自然语言处理(NLP)模型提供了强大的支持平台,让这些模型能够通过训练和微调来更好地服务于各种特定的应用场景。在大型语言模型(LLM)迅猛发展的今天,HuggingFace 提供的核心工具,特别是 Trainer 类,极大地优化了 NLP 模型的训练过程,开发者得以更加高效地实现模型定制和优化。

HuggingFace 的 Trainer 类是为 Transformer 模型量身打造的,不仅优化了模型的交互体验,还与 Datasets 和 Evaluate 等库实现了紧密集成,支持更高级的分布式训练,并能无缝对接 Amazon SageMaker 等基础设施服务。通过这种方式,可以更加便捷地进行模型训练和部署。

本文将通过一个实例,展示如何利用 HuggingFace 的 Trainer 类在本地环境中对 BERT 模型进行微调,以处理文本分类任务。并且重点介绍如何使用 HuggingFace 模型中心的预训练模型,而不是深入机器学习的理论基础。

 1.设置

示例将在 SageMaker Studio(https://aws.amazon.com/cn/sagemaker/studio/) 环境下进行操作,利用 ml.g4dn.12xlarge 实例搭载的 conda_python3 内核来完成任务。需要提醒的是,可以选择使用更小型的实例,但这可能会影响训练速度,具体取决于可用的 CPU/工作进程的数量。

使用 HuggingFace 数据集库下载数据集。

import datasets
from datasets import load_dataset

这里指定了训练数据集和评估数据集,会在训练循环中进行使用。

train_dataset = load_dataset("imdb", split="train")
test_dataset = load_dataset("imdb", split="test")
test_subset = test_dataset.select(range(100)) # 取数据的一个子集进行评估

对于任何文本数据,必须指定一个分词器,将数据预处理成模型可以理解的格式。在这种情况下,这里指定了我们使用的 BERT 模型的 HuggingFace 模型中心 ID。

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

# 分词文本数据
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

然后使用内置的 map 函数处理我们的训练和评估数据集。

tokenized_train = train_dataset.map(tokenize_function, batched=True)
tokenized_test = test_subset.map(tokenize_function, batched=True)

图片

预处理后的数据

2.微调 BERT

数据准备就绪后,利用先前选定的模型ID来加载BERT模型。需要注意的是,针对文本分类任务,还定义了标签的总数。在此案例中设定了两个标签,分别用0和1来表示,0代表负面,1代表正面。

from transformers import AutoModelForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)

接下来在训练循环中,需要定义一个TrainingArguments对象。在这个对象中,可以设置训练过程中的各种参数,比如训练周期的数量、分布式训练的策略等。

from transformers import TrainingArguments, Trainer
training_args = TrainingArguments(output_dir="test_trainer", evaluation_strategy="epoch", num_train_epochs=1)

对于评估,使用 Evaluate 库内置的评估函数。

import numpy as np
import evaluate
metric = evaluate.load("accuracy")

# 评估函数
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

然后将 TrainingArguments、分词数据集和评估指标函数传递给 Trainer 对象。可以使用 train 方法启动训练运行,这将需要大约 10-15 分钟的时间,具体取决于现有硬件。

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test, # 使用测试作为评估
    compute_metrics=compute_metrics,
    tokenizer=tokenizer
)
trainer.train()

图片

训练完成

对于推理,可以直接使用微调后的 trainer 对象,并在用于评估的分词测试数据集上进行预测:

trainer.predict(tokenized_test)

图片

输出

在更为实际的应用场景中,可以使用 trainer 对象将模型工件保存到本地目录中。

trainer.save_model("./custom_model")

图片

模型工件

然后可以加载这些模型工件,指定训练的模型类型,并在单个数据点上进行推理。

loaded_model = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path="custom_model/")

# 样本推理
encoding = tokenizer("I am super delighted", return_tensors="pt")
res = loaded_model(**encoding)
predicted_label_classes = res.logits.argmax(-1)
predicted_label_classes

图片

正面分类

在现实应用场景中,可以将训练好的模型工件部署到像 Amazon SageMaker 这样的服务堆栈上。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1575269.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

100 个网络基础知识,看完成半个网络高手

100 个网络基础知识,看完成半个网络高手。 1)什么是链接? 链接是指两个设备之间的连接。它包括用于一个设备能够与另一个设备通信的电缆类型和协议。 2)OSI 参考模型的层次是什么? 有 7 个 OSI 层:物理层,数据链路层,网络层&…

Rustdesk二次编译,新集成AI功能开源Gpt小程序为远程协助助力,全网首发

环境: Rustdesk1.1.9 sciter版 问题描述: Rustdesk二次编译,新集成AI功能开源Gpt小程序为远程协助助力,全网首发 解决方案: Rustdesk二次编译,新集成开源AI功能Gpt小程序,为远程协助助力&#xff0c…

何为HTTP状态码?一文清楚基本概念。

在客户端与服务器之间的信息传输过程中,我们可以将其比喻为客户与快递员之间的包裹传递。那么服务器是如何通知客户端,操作是成功还是失败?或者有其他的一些情况呢?(就像客户可以查询快递的状态) 而这背后…

C++分析程序各模块耗时-perf火焰图

C分析程序各模块耗时-perf火焰图 1. 简介2. 安装3. 测试示例4. 从火焰图可以获得的信息5. 生成火焰图常见问题 Reference: Perf Wiki【性能】perf 火焰图分析软件性能瓶颈【火焰图🔥】Linux C/C性能优化分析工具Perf使用教程 perf: Linux profiling with perform…

【Python】还在用print进行调试,你Out了!!!

1. 引言 Python 中最常用的函数是什么?像在大多数编程语言中,print() 函数是最常用的。我相信大多数开发者都会像我一样,在开发过程中多次使用它将信息进行打印。 当然,没有其他方法可以完全取代print()函数。不过,当…

鱼塘钓鱼(c++实现)

题目 有 N 个鱼塘排成一排,每个鱼塘中有一定数量的鱼,例如:N5 时,如下表: 即:在第 1 个鱼塘中钓鱼第 1 分钟内可钓到 10 条鱼,第 2 分钟内只能钓到 8 条鱼,……,第 5 分…

【JavaEE】_Spring MVC项目获取Cookie

目录 1. Cookie与Session基础知识 1.1 Cookie与Session的区别 2. 使用servlet原生方法获取Cookie 2.2 关于λ表达式遍历法的空指针问题 2.3 Cookie的伪造 3. 使用Spring注解获取Cookie 3.1 获取单个Cookie 3.2 获取多个Cookie 1. Cookie与Session基础知识 在本专栏HTT…

【随笔】Git 高级篇 -- 整理提交记录(上)cherry-pick(十五)

💌 所属专栏:【Git】 😀 作  者:我是夜阑的狗🐶 🚀 个人简介:一个正在努力学技术的CV工程师,专注基础和实战分享 ,欢迎咨询! 💖 欢迎大…

docker-compose安装adguard给局域网提供dns加速服务

启动配置 docker-compose.yaml配置文件 version: 3.3 services:adguard:image: adguard/adguardhome:latestcontainer_name: adguardrestart: unless-stoppedvolumes:- ./workdir:/opt/adguardhome/work- ./confdir:/opt/adguardhome/confports:- 53:53/tcp- 53:53/udp- 81:8…

【fdisk 相关分区命令记录】

目的 记录下新磁盘下刚刚分配的系统(安装系统后未操作或者新扩容的)的分区格式化及挂载,比如这里运维分配了100G 步骤: 1.查看新硬盘 lsblk -f查看,sdb就是新分配的硬盘,无任何相关的属性信息 2、分区明细查看 fd…

鸿蒙内核源码分析 (双向链表篇) | 谁是内核最重要结构体

双向链表是什么&#xff1f; 谁是鸿蒙内核最重要的结构体 &#xff1f; 一定是: LOS_DL_LIST(双向链表)&#xff0c; 它长这样。 typedef struct LOS_DL_LIST {struct LOS_DL_LIST *pstPrev; /**< Current nodes pointer to the previous node | 前驱节点(左手)*/struct L…

ZYNQ学习Linux 基础外设的使用

基本都是摘抄正点原子的文章&#xff1a;《领航者 ZYNQ 之嵌入式Linux 开发指南 V3.2.pdf》&#xff0c;因初次学习&#xff0c;仅作学习摘录之用&#xff0c;有不懂之处后续会继续更新~ 工程的创建参考&#xff1a;《ZYNQ学习之Petalinux 设计流程实战》 一、GPIO 之 LED 的使…

Open CASCADE学习|旋转变换

物体在三维空间中的旋转变换操作通常可以通过三种不同的方式来表示&#xff1a;矩阵&#xff08;Matrix&#xff09;、欧拉角&#xff08;Euler Angles&#xff09;和四元数&#xff08;Quaternion&#xff09;。下面详细解释这三种表示方法。 矩阵&#xff08;Matrix&#xf…

SpringCloud学习(10)-SpringCloudAlibaba-Nacos服务注册、配置中心

Spring Cloud Alibaba 参考文档 Spring Cloud Alibaba 参考文档 nacos下载Nacos 快速开始 直接进入bin包 运行cmd命令&#xff1a;startup.cmd -m standalone 运行成功后通过http://localhost:8848/nacos进入nacos可视化页面&#xff0c;账号密码默认都是nacos Nacos服务注…

全景化工厂虚拟场景VR在线编辑突破传统束缚

数字化时代来临&#xff0c;让很多行业发生了天翻地覆的变化&#xff0c;更多人和企业接纳和亲近VR/AI/3D等技术&#xff0c;虚拟仿真VR内容编辑器系统不仅在畜牧培训领域大放异彩&#xff0c;更在其他多个行业领域展现出广泛的应用前景。 相比传统的VR虚拟现实应用程序开发依赖…

如何使用开源情报跟踪一个人?在线访问网站以及使用方法介绍

如何使用开源情报跟踪一个人&#xff1f;在线访问网站以及使用方法介绍。 开源情报&#xff08;OSINT&#xff09;是一门关于收集和分析公开可用信息的独特技艺&#xff0c;它致力于构建个人或团体的详尽档案。 这一过程中&#xff0c;信息搜集者会利用多元化的信息源&#xff…

如何使用 langchain 与 openAI 连接

上一篇写了如何安装 langchain https://www.cnblogs.com/hailexuexi/p/18087602 这里主要说一个 langchain的使用 创建一个目录 langchain &#xff0c;在这个目录下创建两个文件 main.py 这段python代码&#xff0c;用到了openAI&#xff0c;需要openAI及FQ。这里只做…

【NLP】隐马尔可夫(HMM)与条件随机场(CRF)简介

一. HMM 隐马尔可夫模型&#xff08;Hidden Markov Model, HMM&#xff09;是一种用于处理含有隐藏状态的序列数据的统计学习模型。通过建模隐藏状态之间的转移关系以及隐藏状态与观测数据的生成关系&#xff0c;HMM能够在仅观察到部分信息的情况下进行状态推理、概率计算、序…

Spring Security——06,授权_封装权限信息

授权_封装权限信息 一、权限系统的作用二、授权基本流程三、限制访问资源所需权限四、封装权限信息4.1 权限信息封装到LoginUser4.2 LoginUser 添加权限4.3 过滤器封装权限信息 五、断点测试5.1 有权限的访问5.2 没有权限的访问 一键三连有没有捏~~ 一、权限系统的作用 例如一…

数据结构(3)----栈和队列

目录 一.栈 1.栈的基本概念 2.栈的基本操作 3.顺序栈的实现 •顺序栈的定义 •顺序栈的初始化 •进栈操作 •出栈操作 •读栈顶元素操作 •若使用另一种方式: 4.链栈的实现 •链栈的进栈操作 •链栈的出栈操作 •读栈顶元素 二.队列 1.队列的基本概念 2.队列的基…