LoRA微调方法详解

news2025/1/21 17:54:01

本文要介绍的是大模型的微调训练方法之一----LoRA。

0 背景

现在大模型非常火爆,大家都在想方设法应用大模型。 当前很多大模型虽说可以zero-shot直接使用, 但是在具体应用上一般还是微调一下效果更好, 也就是常说的finetune。 在小模型时代, finetune不是个问题。 但大模型时代, finetune是个大问题。 这是因为现在的大模型参数动辄10B起, 训练的代价非常高昂,即使是finetune也对计算资源有很高要求(finetune只是训练的步数少, 对显存等计算资源的占用并没有少)。 没个上百G的显存是玩不动的, 这对普通人的门槛实在太高了。

那么高效的finetune方式就非常必要了。LoRA就是高效finefune方法的一种。

1 LoRA原理

LoRA论文: LoRA: Low-Rank Adaptation of Large Language Models

在这里插入图片描述
LoRA的原理非常简单, 先上一张图, 其实从图上已经能清楚地看到大致的原理的。 通俗地讲, 它的原理是这样的:大模型都是过参数化的, 当用于特定任务时, 其实只有一小部分参数起主要作用。 也就是参数矩阵维度很高, 但可以用低维矩阵分解近似。其实这个思想与矩阵特征向量, 主成分分析, 压缩感知等有异曲同工之妙。

具体做法是, 在网络中增加一个旁路结构,旁路是A和B两个矩阵相乘。 A矩阵的维度是dxr, B 矩阵的维度是rxd, 其中r<<d, 一般r取1,2,4,8就够了。那么这个旁路的参数量将远远小于原来网络的参数W。LoRA训练时, 我们冻结原来网络的参数W, 只训练旁路参数AB。 由于AB的参数量远远小于W, 那么训练时需要的显存开销就大约等于推理时的开销。 对采用Adam优化器来说, 需要的显存就大约相当于全参数finetune的1/3, 极大地减小了训练的代价。

论文中作者的实验也证明了这一点。 在GPT-3 175B的finetune中, 采用LoRA微调显存的消耗从1.2TB 降低到了350GB, 大约是三分之一

其实采用这种旁路相加的方式, 与ResNet的跳连方式也有异曲同工之妙。 原网络的参数不变, 在旁路上做些微小改变, 适应特定新任务。 这样就可以让网络基本保持原来的能力, 在特定任务上更精进了一步。

值得注意的是, LoRA微调并没有改变原有的预训练参数, 只是针对特定任务微调出了新的少量参数, 新的这些参数要与原有的预训练参数配合使用(实际使用时, 都是把旁路的参数和原来的参数直接合并, 也就是参数相加, 这样就完全不会增加推理时间)。这是非常方便的, 针对不同的任务, 都可以训练出自己的LoRA参数, 然后与原本的预训练参数结合, 做成插件式的应用。 这就是最近大火的SD + LoRA。全参数微调一般没这个条件, 但LoRA微调还是可以的。 目前Civitai上有上万LoRA的模型, 并且还在迅速增加。

2 代码详解

LoRA代码: https://github.com/microsoft/LoRA

LoRA原理很简单, 代码实现也不复杂。 简单地说,在模型实现上, 要在特定的模块上加一个旁路, 这个旁路就是两个矩阵相乘的形式。这些特定的模块理论上可以是任何模块, 目前作者实现的是在Linear, Embeding, Conv, Attention(只改其中的q和v)这些模块上加。

具体实现见:https://github.com/microsoft/LoRA/blob/main/loralib/layers.py

拿其中的Linear做个简单分析吧, 其他都是类似的。

class LoRALayer():
    def __init__(
        self, 
        r: int, 
        lora_alpha: int, 
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        # Optional dropout
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        # Mark the weight as unmerged
        self.merged = False
        self.merge_weights = merge_weights


class Linear(nn.Linear, LoRALayer):
    # LoRA implemented in a dense layer
    def __init__(
        self, 
        in_features: int, 
        out_features: int, 
        r: int = 0, 
        lora_alpha: int = 1, 
        lora_dropout: float = 0.,
        fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)

        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def reset_parameters(self):
        nn.Linear.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode: bool = True):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
                self.merged = True       

    def forward(self, x: torch.Tensor):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        if self.r > 0 and not self.merged:
            result = F.linear(x, T(self.weight), bias=self.bias)            
            result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
            return result
        else:
            return F.linear(x, T(self.weight), bias=self.bias)

 

在Linear层的实现上,多继承了一个LoRALayer, LoRALayer中就是设置了一些参数, 最主要的就是上面的讲道的矩阵的秩r了,其他就是一些辅助参数, 如控制训练和推理时主路参数和旁路参数是否合并等等。 在Linear层中, 多定义了A和B两个可训练的参数矩阵, 然后在forward中把主路和旁路输出相加, 基本上就是完全按照原理来的。

3 使用

实际使用LoRA微调时, 也不用自己向上面那样实现了。上面的loralib库已经实现好了, 直接使用就好了。具体而言, 就是把网络中原来使用nn.Linear用loralib库中的Linear替换就可以了, 其他的模块同理。

实际上, 还有更简洁的方式,huggingface pert库很贴心地把各种finetune方式都做了集成, 更加简单和方便。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/883709.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

栈存储结构详解

目录 栈存储结构详解 进栈和出栈 栈的具体实现 栈的应用 什么是队列&#xff08;队列存储结构&#xff09; 栈存储结构详解 同顺序表和链表一样&#xff0c;栈也是用来存储逻辑关系为 "一对一" 数据的线性存储结构&#xff0c;如图 1 所示。 图 1 栈存储结构示意…

【string】基本用法

目录 前言: string常用接口 一、string的创建,拼接与拷贝构造 1.创建 2.拼接 3.拷贝构造 二、string遍历 方式一&#xff1a;operator[ ]重载 方式二&#xff1a;迭代器 1.正向迭代器&#xff1a; 2.反向迭代器 3.const正向迭代器 4.const反向迭代器 方式三&#…

PyQt5资源的加载和使用,即如何使用Pyrcc

1、打开QtDesigner&#xff0c;选择编辑资源 2、新建资源文件&#xff0c;随便找个地方保存 3、按照自己的喜好命名&#xff0c;然后添加资源 4、保存并退出 5、我们创建一个QLabel&#xff0c;在这里添加资源 6、我们保存界面文件&#xff0c;并编译为py文件&#xff0c;然后…

【C语言】调试技巧

目录 一、什么是bug? 二、调试 1.一般调试的步骤 2.Debug 和 Release 三、调试环境准备 四、调试时要查看的信息 1.查看临时变量的值 2.查看内存信息 3.查看调用堆栈 4.查看反汇编信息 5.查看寄存器 五、练习 六、常见的coding技巧 七、const的作用 八、编程常见…

时序预测 | MATLAB实现基于RF随机森林的时间序列预测-递归预测未来(多指标评价)

时序预测 | MATLAB实现基于RF随机森林的时间序列预测-递归预测未来(多指标评价) 目录 时序预测 | MATLAB实现基于RF随机森林的时间序列预测-递归预测未来(多指标评价)预测结果基本介绍程序设计参考资料 预测结果 基本介绍 MATLAB实现基于RF随机森林的时间序列预测-递归预测未来…

Linux命令200例:kill用来终止或者结束进程(常用)

&#x1f3c6;作者简介&#xff0c;黑夜开发者&#xff0c;全栈领域新星创作者✌。CSDN专家博主&#xff0c;阿里云社区专家博主&#xff0c;2023年6月csdn上海赛道top4。 &#x1f3c6;数年电商行业从业经验&#xff0c;历任核心研发工程师&#xff0c;项目技术负责人。 &…

服务链路追踪

一、基础概念 1.背景 对于一个大型的几十个、几百个微服务构成的微服务架构系统&#xff0c;通常会遇到下面一些问题&#xff0c;比如&#xff1a; 如何串联整个调用链路&#xff0c;快速定位问题&#xff1f;如何理清各个微服务之间的依赖关系&#xff1f;如何进行各个微服…

uniapp-微信小程序篇

uniapp-微信小程序篇 一、创建项目(以Vue3TS 项目为示例) 可以通过命令行的方式创建也可以通过HBuilderX进行创建&#xff08;通过HBuilderX创建的项目建议选择最简单的模板&#xff09;&#xff0c;个人建议使用命令行方式。 (1) 命令行方式&#xff1a; npx degit dcloudio…

WebRTC音视频通话-WebRTC视频自定义RTCVideoCapturer相机

WebRTC音视频通话-WebRTC视频自定义RTCVideoCapturer相机 在之前已经实现了WebRTC调用ossrs服务&#xff0c;实现直播视频通话功能。但是在使用过程中&#xff0c;RTCCameraVideoCapturer类提供的方法不能修改及调节相机的灯光等设置&#xff0c;那就需要自定义RTCVideoCaptur…

认识excel篇2之如何快速输入数据

一、快速输入数据&#xff08;快捷键功能的使用&#xff09; 1、鼠标左键填充&#xff1a;复制填充、等差序列填充&#xff08;行、列是一样的&#xff09; 步骤&#xff1a;选中单元格&#xff0c;鼠标放置到单元格右下角待鼠标箭头变成实心十字架&#xff0c;左键向下拖拽&…

List和ObservableCollection和ListBinding在MVVM模式下的对比

List和ObservableCollection和ListBinding在MVVM模式下的对比 List 当对List进行增删操作后&#xff0c;并不会对View进行通知。 //Employee public class Employee : INotifyPropertyChanged {public event PropertyChangedEventHandler? PropertyChanged;public string N…

VMware vCenter 低版本存在未授权任意文件读取漏洞

VMware vCenter 低版本存在未授权任意文件读取漏洞,Arbitrary File Read vulnerability in VMware vCenter(Unauthenticated)。 Poc from:https://twitter.com/ptswarm/status/1316016337550938122 Name :VMware vCenter Server Arbitrary File Read Vulnerability Categor…

STM32F4X NVIC中断概念

STM32F4X NVIC中断概念 CPU查询状态两种方式轮询查询中断查询 STM32有关中断的概念中断向量表系统中断外设中断中断号中断优先级 STM32F4X NVIC控制器NVIC控制器简介NVIC寄存器优先级分组 STM32F4X中断配置优先级分组设置配置外设中断 CPU查询状态两种方式 在讲解中断的概念之…

软件测试常用工具总结(测试管理、单元测试、接口测试、自动化测试、性能测试、负载测试等)

前言 在软件测试的过程中&#xff0c;多多少少都是会接触到一些测试工具&#xff0c;作为辅助测试用的&#xff0c;以提高测试工作的效率&#xff0c;使用好了测试工具&#xff0c;能对测试起到一个很好的作用&#xff0c;同时&#xff0c;有些公司&#xff0c;也会要求掌握一…

docker 学习--03 环境安装(本人使用的win10 Linux也是在win10下模拟)

docker 学习–03 环境安装&#xff08;本人使用的win10 Linux也是在win10下模拟&#xff09; 文章目录 docker 学习--03 环境安装&#xff08;本人使用的win10 Linux也是在win10下模拟&#xff09;[TOC](文章目录) 1. windows10 安装docker1.1 访问官网 点击下载1.2.点击下载的…

算法通关村第九关 | 有序数组转搜索二叉树

有序数组转搜索二叉树 二叉搜索树概念&#xff1a; 若它的左子树不为空&#xff0c;则左子树上的所有节点的值均小于它根节点的值&#xff1b; 若它的右子树不为空&#xff0c;则右子树上所有节点的值均大于它的根节点的值&#xff1b; 它的左右子树也分别为二叉树。下面给出…

ORB_SLAM3 Relocalization

Relocalization Relocalization主要的作用是在跟踪失败时&#xff0c;通过词袋在关键帧数据库KeyFrameDatabase中寻找和当前帧相似的关键帧作为匹配帧&#xff0c;进而恢复当前帧的位姿 计算当前帧的Bow&#xff0c;参考ORB_SLAM3 TrackReferenceKeyFrame中计算当前帧的描述子…

【SLAM】ORBSLAM34macOS: ORBSLAM3 Project 4(for) macOS Platform

文章目录 配置ORBSLAM34macOS 版本运行步骤&#xff1a;版本修复问题记录&#xff1a;编译 fix运行 fix 配置 硬件&#xff1a;MacBook Pro Intel CPU 系统&#xff1a;macOS Ventura 13.4.1 ORBSLAM34macOS 版本 https://github.com/phdsky/ORB_SLAM3/tree/macOS 运行步骤&…

TALKS:解决模型-数据差异的系统框架

资料收集于网络&#xff0c;仅供学习使用 TALKS: A systematic framework for resolving model-data discrepancies https://doi.org/10.1016/j.envsoft.2023.105668 问题现状 TALKS框架 TALKS(Trigger, Articulate, List, Knowledge elicitation, Solve)&#xff0c;作为解…