如何从头开始编写LoRA代码,这有一份教程

news2024/11/24 20:52:49

    ChatGPT狂飙160天,世界已经不是之前的样子。
新建了免费的人工智能中文站https://ai.weoknow.com
新建了收费的人工智能中文站https://ai.hzytsoft.cn/

更多资源欢迎关注


作者表示:在各种有效的 LLM 微调方法中,LoRA 仍然是他的首选。

LoRA(Low-Rank Adaptation)作为一种用于微调 LLM(大语言模型)的流行技术,最初由来自微软的研究人员在论文《 LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS 》中提出。不同于其他技术,LoRA 不是调整神经网络的所有参数,而是专注于更新一小部分低秩矩阵,从而大大减少了训练模型所需的计算量。

由于 LoRA 的微调质量与全模型微调相当,很多人将这种方法称之为微调神器。自发布以来,相信很多人都对这项技术感到好奇,想要从头开始编写代码从而更好的理解该研究。以前苦于没有合适的文档说明,现在,教程来了。

这篇教程的作者是知名机器学习与 AI 研究者 Sebastian Raschka,他表示在各种有效的 LLM 微调方法中,LoRA 仍然是自己的首选。为此,Sebastian 专门写了一篇博客《Code LoRA From Scratch》,从头开始构建 LoRA,在他看来,这是一种很好的学习方法。

图片

简单来说,本文通过从头编写代码的方式来介绍低秩自适应(LoRA),实验中 Sebastian 对 DistilBERT 模型进行了微调,并用于分类任务。

LoRA 与传统微调方法的对比结果显示,使用 LoRA 方法在测试准确率上达到了 92.39%,这与仅微调模型最后几层相比(86.22% 的测试准确率)显示了更好的性能。

Sebastian 是如何实现的,我们接着往下看。

从头开始编写 LoRA

用代码的方式表述一个 LoRA 层是这样的:

图片

其中,in_dim 是想要使用 LoRA 修改的层的输入维度,与此对应的 out_dim 是层的输出维度。代码中还添加了一个超参数即缩放因子 alpha,alpha 值越高意味着对模型行为的调整越大,值越低则相反。此外,本文使用随机分布中的较小值来初始化矩阵 A,并用零初始化矩阵 B。

值得一提的是,LoRA 发挥作用的地方通常是神经网络的线性(前馈)层。举例来说,对于一个简单的 PyTorch 模型或具有两个线性层的模块(例如,这可能是 Transformer 块的前馈模块),其前馈(forward)方法可以表述为:

图片

在使用 LoRA 时,通常会将 LoRA 更新添加到这些线性层的输出中,又得到代码如下:

图片

如果你想通过修改现有 PyTorch 模型来实现 LoRA ,一种简单方法是将每个线性层替换为 LinearWithLoRA 层:

图片

以上这些概念总结如下图所示:

图片

为了应用 LoRA,本文将神经网络中现有的线性层替换为结合了原始线性层和 LoRALayer 的 LinearWithLoRA 层。

如何上手使用 LoRA 进行微调

LoRA 可用于 GPT 或图像生成等模型。为了简单说明,本文采用一个用于文本分类的小型 BERT(DistilBERT) 模型来说明。

图片

由于本文只训练新的 LoRA 权重,因而需要将所有可训练参数的 requires_grad 设置为 False 来冻结所有模型参数:

图片

接下来,使用 print (model) 检查一下模型的结构:

图片

由输出可知,该模型由 6 个 transformer 层组成,其中包含线性层:

图片

此外,该模型有两个线性输出层:

图片

通过定义以下赋值函数和循环,可以选择性地为这些线性层启用 LoRA:

图片

使用 print (model) 再次检查模型,以检查其更新的结构:

图片

正如上面看到的,线性层已成功地被 LinearWithLoRA 层取代。

如果使用上面显示的默认超参数来训练模型,则会在 IMDb 电影评论分类数据集上产生以下性能:

  • 训练准确率:92.15%

  • 验证准确率:89.98%

  • 测试准确率:89.44%

在下一节中,本文将这些 LoRA 微调结果与传统微调结果进行了比较。

与传统微调方法的比较

在上一节中,LoRA 在默认设置下获得了 89.44% 的测试准确率,这与传统的微调方法相比如何?

为了进行比较,本文又进行了一项实验,以训练 DistilBERT 模型为例,但在训练期间仅更新最后 2 层。研究者通过冻结所有模型权重,然后解冻两个线性输出层来实现这一点:

图片

只训练最后两层得到的分类性能如下:

  • 训练准确率:86.68%

  • 验证准确率:87.26%

  • 测试准确率:86.22%

结果显示,LoRA 的表现优于传统微调最后两层的方法,但它使用的参数却少了 4 倍。微调所有层需要更新的参数比 LoRA 设置多 450 倍,但测试准确率只提高了 2%。

优化 LoRA 配置

前面讲到的结果都是 LoRA 在默认设置下进行的,超参数如下:

图片

假如用户想要尝试不同的超参数配置,可以使用如下命令:

图片

不过,最佳超参数配置如下:

图片

在这种配置下,得到结果:

  • 验证准确率:92.96%

  • 测试准确率:92.39%

值得注意的是,即使 LoRA 设置中只有一小部分可训练参数(500k VS 66M),但准确率还是略高于通过完全微调获得的准确率。

    ChatGPT狂飙160天,世界已经不是之前的样子。
新建了免费的人工智能中文站https://ai.weoknow.com
新建了收费的人工智能中文站https://ai.hzytsoft.cn/

更多资源欢迎关注


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1580718.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

机器学习-随机森林算法预测温度

文章目录 算法简介解决问题获取数据集探索性数据分析查看数据集字段信息查看数据集综合统计结果查看特征值随时间变化趋势 数据预处理处理缺失数据字符列编码数据集分割训练集、验证集、测试集数据集分割 构建模型并训练结果分析与评估进一步优化实际使用经验总结 算法简介 随…

基于遗传优化的SVD水印嵌入提取算法matlab仿真

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.本算法原理 5.完整程序 1.程序功能描述 基于遗传优化的的SVD水印嵌入提取算法。对比遗传优化前后SVD水印提取性能,并分析不同干扰情况下水印提取效果。 2.测试软件版本以及运行结果展示 MA…

深度学习实践(一)基于Transformer英译汉模型

本文目录 前述一、环境依赖二、数据准备1. 数据加载2. 构建单词表程序解析(1)将列表里每个子列表的所有单词合并到一个新列表(没有子列表)中。(2)Counter()-- 统计迭代对象各元素出现…

【Spring AOP】@Aspect结合案例详解(一): @Pointcut使用@annotation + 五种通知Advice注解(已附源码)

文章目录 前言AOP与Spring AOPAspect简单案例快速入门 一、Pointcutannotation 二、五种通知Advice1. Before前置通知2. After后置通知3. AfterRunning返回通知4. AfterThrowing异常通知5. Around环绕通知 总结 前言 在微服务流行的当下,在使用SpringCloud/Springb…

Mogdb双网卡同步最佳实践

大家都知道Oracle数据库无论是单机还是RAC集群在进行生产部署实施时,我们都会对网卡做冗余考虑,比如使用双网卡,比如public、心跳网络。这样的目的主要是为了安全,避免淡点故障。当然也网卡Bond不仅是可以做主备还可以支持负载均衡…

redis分布式锁+redisson框架

目录 🧂1.锁的类型 🌭2.基于redis实现分布式 🥓3. 基于redisson实现分布式锁 1.锁的类型 1.本地锁:synchronize、lock等,锁在当前进程内,集群部署下依旧存在问题2.分布式锁:redis、zookeeper等…

OLAP介绍

OLAP OLAP介绍 Rollup OLAP(在线分析处理)的上下文中,"Rollup"是一个重要的概念,它指的是在多维数据集中自动地聚合数据到更高的层次或维度的过程。这种操作通常用于快速计算和展示汇总数据,以便于用户进…

包和final.Java

1,包 包就是文件夹。用来管理不同功能的Java类,方便后期代码的维护。 (1)包名的规则是什么? 公司域名反写报的作用,需要全部英文小写,见名知意。com.itheima.domain (2&#xff…

15.队列集

1.简介 在使用队列进行任务之间的“沟通交流”时,一个队列只允许任务间传递的消息为同一种数据类型,如果需要在任务间传递不同数据类型的消息时,那么就可以使用队列集。FreeRTOS提供的队列集功能可以对多个队列进行“监听”,只要…

Redis高级-分布式缓存

分布式缓存 – 基于Redis集群解决单机Redis存在的问题 单机的Redis存在四大问题: 0.目标 1.Redis持久化 Redis有两种持久化方案: RDB持久化AOF持久化 1.1.RDB持久化 RDB全称Redis Database Backup file(Redis数据备份文件)…

QT drawPixmap和drawImage处理图片模糊问题

drawPixmap和drawImage显示图片时,如果图片存在缩放时,会出现模糊现象,例如将一个100x100 的图片显示到30x30的区域,这个时候就会出现模糊。如下: 实际图片: 这个问题就是大图显示成小图造成的像素失真。 当…

FPGA(Verilog)实现按键消抖

实现按键消抖功能: 1.滤除按键按下时的噪声和松开时的噪声信号。 2.获取已消抖的按键按下的标志信号。 3.实现已消抖的按键的连续功能。 Verilog实现 模块端口 key_filter(input wire clk ,input wire rst_n ,input wire key_in , //按下按键时为0output …

[NKCTF2024]-PWN:leak解析(中国剩余定理泄露libc地址,汇编覆盖返回地址)

查看保护 查看ida 先放exp 完整exp: from pwn import* from sympy.ntheory.modular import crt context(log_leveldebug,archamd64)while True:pprocess(./leak)ps[101,103,107,109,113,127]p.sendafter(bsecret\n,bytes(ps))cs[0]*6for i in range(6):cs[i]u32(p…

6.模板初阶(函数模板、类模板、类模板声明与定义分离)

1. 泛型编程 如何实现一个通用的交换函数呢? 使用函数重载虽然可以实现,但是有一下几个不好的地方: 重载的函数仅仅是类型不同,代码复用率比较低,只要有新类型出现时,就需要用户自己增加对应的函数代码的…

线性、逻辑回归算法学习

1、什么是一元线性回归 线性:两个变量之间的关系是一次函数,也是数据与数据之间的关系。 回归:人们在测试事物的时候因为客观条件所限,求的都是测试值,而不是真实值,为了无限接近真实值,无限次的…

HarmonyOS开发实例:【状态管理】

状态管理 ArkUI开发框架提供了多维度的状态管理机制,和UI相关联的数据,不仅可以在组件内使用,还可以在不同组件层级间传递,比如父子组件之间,爷孙组件之间等,也可以是全局范围内的传递,还可以是…

【考研数学】1800还是660还是880?

关于这几本习题册如何选择,肯定是根据他们的不同特点以及我们的需求结合选择,给大家的建议如下: 1800适合初期,可以帮助你熟悉数学公式和基础定义,迅速上手用。刚开始觉得难很正常,存在一个上手的过程&…

VRRP虚拟路由实验(思科)

一,技术简介 VRRP(Virtual Router Redundancy Protocol)是一种网络协议,用于实现路由器冗余,提高网络可靠性和容错能力。VRRP允许多台路由器共享一个虚拟IP地址,其中一台路由器被选为Master,负…

【Erlang】【RabbitMQ】Linux(CentOS7)安装Erlang和RabbitMQ

一、系统环境 查版本对应,CentOS-7,选择Erlang 23.3.4,RabbitMQ 3.9.16 二、操作步骤 安装 Erlang repository curl -s https://packagecloud.io/install/repositories/rabbitmq/erlang/script.rpm.sh | sudo bash安装 Erlang package s…

扫描IP开放端口该脚本用于对特定目标主机进行常见端口扫描(加载端口字典)或者指定端口扫描,判断目标主机开

扫描IP开放端口该脚本用于对特定目标主机进行常见端口扫描(加载端口字典)或者指定端口扫描,判断目标主机开 #/bin/bash #该脚本用于对特定目标主机进行常见端口扫描(加载端口字典)或者指定端口扫描,判断目标主机开放来哪些端口 #用telnet方式 IP$1 #IP119.254.3.28 #获得IP的前…