使用LORA微调RoBERTa

news2025/1/16 5:36:26

模型微调是指在一个已经训练好的模型的基础上,针对特定任务或者特定数据集进行再次训练以提高性能的过程。微调可以在使其适应特定任务时产生显着的结果。

RoBERTa(Robustly optimized BERT approach)是由Facebook AI提出的一种基于Transformer架构的预训练语言模型。它是对Google提出的BERT(Bidirectional Encoder Representations from Transformers)模型的改进和优化。

“Low-Rank Adaptation”(低秩自适应)是一种用于模型微调或迁移学习的技术。一般来说我们只是使用LORA来微调大语言模型,但是其实只要是使用了Transformers块的模型,LORA都可以进行微调,本文将介绍如何利用🤗PEFT库,使用LORA提高微调过程的效率。

LORA可以大大减少了可训练参数的数量,节省了训练时间、存储和计算成本,并且可以与其他模型自适应技术(如前缀调优)一起使用,以进一步增强模型。

但是,LORA会引入额外的超参数调优层(特定于LORA的秩、alpha等)。并且在某些情况下,性能不如完全微调的模型最优,这个需要根据不同的需求来进行测试。

首先我们安装需要的包:

 !pip install transformers datasets evaluate accelerate peft

数据预处理

 import torch
 from transformers import RobertaModel, RobertaTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, DataCollatorWithPadding
 from peft import LoraConfig, get_peft_model
 from datasets import load_dataset
 
 
 
 peft_model_name = 'roberta-base-peft'
 modified_base = 'roberta-base-modified'
 base_model = 'roberta-base'
 
 dataset = load_dataset('ag_news')
 tokenizer = RobertaTokenizer.from_pretrained(base_model)
 
 def preprocess(examples):
     tokenized = tokenizer(examples['text'], truncation=True, padding=True)
     return tokenized
 
 tokenized_dataset = dataset.map(preprocess, batched=True,  remove_columns=["text"])
 train_dataset=tokenized_dataset['train']
 eval_dataset=tokenized_dataset['test'].shard(num_shards=2, index=0)
 test_dataset=tokenized_dataset['test'].shard(num_shards=2, index=1)
 
 
 # Extract the number of classess and their names
 num_labels = dataset['train'].features['label'].num_classes
 class_names = dataset["train"].features["label"].names
 print(f"number of labels: {num_labels}")
 print(f"the labels: {class_names}")
 
 # Create an id2label mapping
 # We will need this for our classifier.
 id2label = {i: label for i, label in enumerate(class_names)}
 
 data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt")

训练

我们训练两个模型,一个使用LORA,另一个使用完整的微调流程。这里可以看到LORA的训练时间和训练参数的数量能减少多少

以下是使用完整微调

 
 training_args = TrainingArguments(
     output_dir='./results',
     evaluation_strategy='steps',
     learning_rate=5e-5,
     num_train_epochs=1,
     per_device_train_batch_size=16,
 )
 

然后进行训练:

 def get_trainer(model):
       return  Trainer(
           model=model,
           args=training_args,
           train_dataset=train_dataset,
           eval_dataset=eval_dataset,
           data_collator=data_collator,
       )
 full_finetuning_trainer = get_trainer(
     AutoModelForSequenceClassification.from_pretrained(base_model, id2label=id2label),
 )
 
 full_finetuning_trainer.train()

下面看看PEFT的LORA

 model = AutoModelForSequenceClassification.from_pretrained(base_model, id2label=id2label)
 
 peft_config = LoraConfig(task_type="SEQ_CLS", inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.1)
 peft_model = get_peft_model(model, peft_config)
 
 print('PEFT Model')
 peft_model.print_trainable_parameters()
 
 peft_lora_finetuning_trainer = get_trainer(peft_model)
 
 peft_lora_finetuning_trainer.train()
 peft_lora_finetuning_trainer.evaluate()

可以看到

模型参数总计:125,537,288,而LORA模型的训练参数为:888,580,我们只需要用LORA训练~0.70%的参数!这会大大减少内存的占用和训练时间。

在训练完成后,我们保存模型:

 tokenizer.save_pretrained(modified_base)
 peft_model.save_pretrained(peft_model_name)

最后测试我们的模型

 from peft import AutoPeftModelForSequenceClassification
 from transformers import AutoTokenizer
 
 # LOAD the Saved PEFT model
 inference_model = AutoPeftModelForSequenceClassification.from_pretrained(peft_model_name, id2label=id2label)
 tokenizer = AutoTokenizer.from_pretrained(modified_base)
 
 
 def classify(text):
   inputs = tokenizer(text, truncation=True, padding=True, return_tensors="pt")
   output = inference_model(**inputs)
 
   prediction = output.logits.argmax(dim=-1).item()
 
   print(f'\n Class: {prediction}, Label: {id2label[prediction]}, Text: {text}')
   # return id2label[prediction]
 
 classify( "Kederis proclaims innocence Olympic champion Kostas Kederis today left hospital ahead of his date with IOC inquisitors claiming his ...")

 classify( "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again.")

模型评估

我们还需要对PEFT模型的性能与完全微调的模型的性能进行对比,看看这种方式有没有性能的损失

 from torch.utils.data import DataLoader
 import evaluate
 from tqdm import tqdm
 
 metric = evaluate.load('accuracy')
 
 def evaluate_model(inference_model, dataset):
 
     eval_dataloader = DataLoader(dataset.rename_column("label", "labels"), batch_size=8, collate_fn=data_collator)
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     inference_model.to(device)
     inference_model.eval()
     for step, batch in enumerate(tqdm(eval_dataloader)):
         batch.to(device)
         with torch.no_grad():
             outputs = inference_model(**batch)
         predictions = outputs.logits.argmax(dim=-1)
         predictions, references = predictions, batch["labels"]
         metric.add_batch(
             predictions=predictions,
             references=references,
         )
 
     eval_metric = metric.compute()
     print(eval_metric)
     

首先是没有进行微调的模型,也就是原始模型

 evaluate_model(AutoModelForSequenceClassification.from_pretrained(base_model, id2label=id2label), test_dataset)

accuracy: 0.24868421052631579‘

下面是LORA微调模型

 evaluate_model(inference_model, test_dataset)

accuracy: 0.9278947368421052

最后是完全微调的模型:

 evaluate_model(full_finetuning_trainer.model, test_dataset)

accuracy: 0.9460526315789474

总结

我们使用PEFT对RoBERTa模型进行了微调和评估,可以看到使用LORA进行微调可以大大减少训练的参数和时间,但是在准确性方面还是要比完整的微调要稍稍下降。

本文代码:

https://avoid.overfit.cn/post/26e401b70f9840dab185a6a83aac06b0

作者:Achilles Moraites

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1448109.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【EAI 020】Diffusion Policy: Visuomotor Policy Learning via Action Diffusion

论文标题:Diffusion Policy: Visuomotor Policy Learning via Action Diffusion 论文作者:Cheng Chi, Siyuan Feng, Yilun Du, Zhenjia Xu, Eric Cousineau, Benjamin Burchfiel, Shuran Song 作者单位:Columbia University, Toyota Research…

C#中implicit和explicit

理解: 使用等号代替构造函数调用的效果以类似重载操作符的形式定义用于类型转换的函数前者类型转换时候直接写等号赋值语法,后者要额外加目标类型的强制转换stirng str -> object o -> int a 可以 int a (int)(str as object)转换通过编译,但没有转换逻辑所以运行会报错…

HCIA-HarmonyOS设备开发认证V2.0-轻量系统内核基础-事件event

目录 一、事件基本概念二、事件运行机制三、事件开发流程四、事件使用说明五、事件接口坚持就有收获 一、事件基本概念 事件是一种实现任务间通信的机制,可用于实现任务间的同步,但事件通信只能是事件类型的通信,无数据传输。一个任务可以等…

LeetCode、452. 用最少数量的箭引爆气球【中等,贪心,区间问题】

文章目录 前言LeetCode、452. 用最少数量的箭引爆气球【中等,贪心,区间问题】题目链接与分类思路贪心,连续区间数量问题 资料获取 前言 博主介绍:✌目前全网粉丝2W,csdn博客专家、Java领域优质创作者,博客…

带你掌握getchar与putchar的基本用法

个人主页(找往期文章包括但不限于本期文章中不懂的知识点):我要学编程(ಥ_ಥ)-CSDN博客 目录 getcahr putchar getchar 与 putchar 的配合使用 getchar相较于scanf的优缺点 putchar相较于printf的优缺点 getcahr 函数原型&#xff1a…

【教程】MySQL数据库学习笔记(二)——数据类型(持续更新)

写在前面: 如果文章对你有帮助,记得点赞关注加收藏一波,利于以后需要的时候复习,多谢支持! 【MySQL数据库学习】系列文章 第一章 《认识与环境搭建》 第二章 《数据类型》 文章目录 【MySQL数据库学习】系列文章一、整…

DFM-无监督图像匹配

DFM:A Performance Baseline for Deep Feature Matching(深度特征匹配的性能基准) 2021.06.14 摘要 提出了一种新的图像匹配方法,利用现成的深度神经网络提取的学习特征来获得良好的图像匹配效果。该方法使用预训练的VGG结构作为…

starknet之 class_hash

文章目录 问题背景什么是Class Hash问题背景 部署合约报错:ReferenceError: Buffer is not defined 什么是Class Hash 官方: https://book.starknet.io/ch04-03-01-deploy-standard-account.html?highlight=class%20hash#finding-the-class-hash 要部署智能合约,您需要在…

【原创 附源码】Flutter集成Apple支付详细流程(附源码)

最近有时间,特意整理了一下之前使用过的Flutter平台的海外支付,附源码及demo可供参考 这篇文章只记录Apple支付的详细流程,其他相关Flutter文章链接如下: 【原创 附源码】Flutter集成谷歌支付详细流程(附源码) 【原创 附源码】F…

PR:熟悉PR工作环境

新建项目 设置自己的页面布局 首选项

【JavaEE】_JavaScript基础语法

目录 1. JavaScript概述 1.1 JavaScript简介 1.2 HTML、CSS、JavaScript的关系 1.3 JavaScrip的组成 2. JavaScript的书写形式 2.1 内嵌式 2.2 行内式 2.3 外部式 3. 输出 3.1 alert 3.2 console.log 4. 变量的使用 4.1 创建变量 4.1.1 使用var 4.1.2 使用let …

java中事务的使用

文章目录 前言一、同一张表1.业务代码2.测试代码3.测试结果 二、不同表1.业务代码2.测试代码3.测试结果 总结 前言 本文将介绍在springboot中使用Transactional注解来完成对数据库事务的操作,保证数据一致性。 一、同一张表 1.业务代码 Controller Controller p…

停止内耗,做有用的事

很多读者朋友跟我交流的时候,都以为我有存稿,于是听到我说每周四现写的时候都很惊讶。其实没什么好惊讶的,每周四我都会把自己关在书房里一整天,断掉一切电话、微信、邮件,从中午写到晚上,直到写完为止。 这…

算法学习——LeetCode力扣回溯篇1

算法学习——LeetCode力扣回溯篇1 77. 组合 77. 组合 - 力扣(LeetCode) 描述 任何顺序 返回答案。 示例 示例 1: 输入:n 4, k 2 输出: [ [2,4], [3,4], [2,3], [1,2], [1,3], [1,4], ] 示例 2: 输…

springboot743二手交易平台

springboot743二手交易平台 获取源码——》公主号:计算机专业毕设大全

《Java 简易速速上手小册》第8章:Java 性能优化(2024 最新版)

文章目录 8.1 性能评估工具 - 你的性能探测仪8.1.1 基础知识8.1.2 重点案例:使用 VisualVM 监控应用性能8.1.3 拓展案例 1:使用 JProfiler 分析内存泄漏8.1.4 拓展案例 2:使用 Gatling 进行 Web 应用压力测试 8.2 JVM 调优 - 魔法引擎的调校8…

第四篇【传奇开心果微博系列】Python微项目技术点案例示例:美女颜值判官

传奇开心果微博系列 系列微博目录Python微项目技术点案例示例系列 微博目录一、微项目目标二、雏形示例代码三、扩展思路四、添加不同类型的美女示例代码五、增加难度等级示例代码六、添加特殊道具示例代码七、设计关卡系统示例代码八、添加音效和背景音乐示例代码九、多人游戏…

【解决】idea控制台不输出trace/debug日志

idea控制台不输出trace日志 问题原因解决 问题 idea控制台不输出trace日志。 pom文件&#xff1a; <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-logging</artifactId></dependency>输出lo…

23种计模式之Python/Go实现

目录 设计模式what?why?设计模式&#xff1a;设计模式也衍生出了很多的新的种类&#xff0c;不局限于这23种创建类设计模式&#xff08;5种&#xff09;结构类设计模式&#xff08;7种&#xff09;行为类设计模式&#xff08;11种&#xff09; 六大设计原则开闭原则里氏替换原…

P3612 [USACO17JAN] Secret Cow Code S题解

题目 奶牛正在试验秘密代码&#xff0c;并设计了一种方法来创建一个无限长的字符串作为其代码的一部分使用。 给定一个字符串&#xff0c;让后面的字符旋转一次&#xff08;每一次正确的旋转&#xff0c;最后一个字符都会成为新的第一个字符&#xff09;。也就是说&#xff0…