LoRA技术详解---附实战代码

news2024/12/22 19:28:02

LoRA技术详解—附实战代码

image-20241008195025957

引言

随着大语言模型规模的不断扩大,如何高效地对这些模型进行微调成为了一个重要的技术挑战。Low-Rank Adaptation(LoRA)技术应运而生,它通过巧妙的低秩分解方法,显著减少了模型微调时需要训练的参数数量,同时保持了良好的性能表现。本文将深入介绍LoRA的原理,并通过详细的PyTorch代码实现来展示其工作机制。

LoRA的核心原理

基本思想

LoRA的核心思想是:在保持预训练模型权重不变的情况下,通过向每个转换器层添加低秩矩阵来实现模型的适应性调整。具体来说,对于原始的权重矩阵 W 0 ∈ R d × k W_0 \in \mathbb{R}^{d \times k} W0Rd×k,LoRA引入了如下的更新机制:

W = W 0 + Δ W = W 0 + B A W = W_0 + \Delta W = W_0 + BA W=W0+ΔW=W0+BA

其中:

  • B ∈ R d × r B \in \mathbb{R}^{d \times r} BRd×r
  • A ∈ R r × k A \in \mathbb{R}^{r \times k} ARr×k
  • r r r 是一个远小于 d d d k k k 的秩

关键特征

  1. 参数高效性:通过引入低秩分解,LoRA显著减少了需要训练的参数量。
  2. 初始化策略 Δ W \Delta W ΔW 在训练开始时被初始化为零矩阵,确保了模型从原始性能开始逐步调整。
  3. 可扩展性:可以轻松应用于不同类型的层,如线性层和嵌入层。

实现细节分析

1. 初始化策略

LoRA的初始化策略非常关键:

  • 对于线性层,矩阵A使用kaiming均匀初始化
  • 对于嵌入层,矩阵A使用正态分布初始化
  • 两种情况下,矩阵B都初始化为零,确保训练开始时 Δ W = B A = 0 \Delta W = BA = 0 ΔW=BA=0

2. 缩放因子

缩放因子 α r \frac{\alpha}{r} rα的引入有两个主要作用:

  1. 控制LoRA更新的幅度
  2. 使得不同秩r的实验结果更具可比性

3. 前向传播

在前向传播中,LoRA的更新通过以下步骤实现:

  1. 计算原始层的输出
  2. 计算低秩更新: ( x @ A T @ B T ) ∗ α r (x @ A^T @ B^T) * \frac{\alpha}{r} (x@AT@BT)rα
  3. 将两部分结果相加

PyTorch实现详解

LoRA线性层实现

class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool, 
                 r: int, alpha: int = None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        
        # 设置缩放因子
        if alpha is None:
            alpha = r
        self.scaling = alpha / r
        
        # 原始权重(冻结)
        self.weight = nn.Parameter(torch.empty((out_features, in_features)))
        self.weight.requires_grad = False
        
        # 偏置项处理
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
            self.bias.requires_grad = False
        else:
            self.bias = None
            
        # LoRA参数初始化
        self.lora_a = nn.Parameter(torch.empty((r, in_features)))
        self.lora_b = nn.Parameter(torch.empty((out_features, r)))
        
        # 初始化
        with torch.no_grad():
            nn.init.kaiming_uniform_(self.lora_a, a=5 ** 0.5)
            nn.init.zeros_(self.lora_b)

    def forward(self, x: torch.Tensor):
        # 原始线性变换
        result = nn.functional.linear(x, self.weight, bias=self.bias)
        # 添加LoRA部分
        result += (x @ self.lora_a.T @ self.lora_b.T) * self.scaling
        return result

LoRA嵌入层实现

class Embedding(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int,
                 r: int, alpha: int = None):
        super().__init__()
        
        # 设置缩放因子
        if alpha is None:
            alpha = r
        self.scaling = alpha / r
        
        # 原始嵌入权重(冻结)
        self.weight = nn.Parameter(torch.empty((num_embeddings, embedding_dim)))
        self.weight.requires_grad = False
        
        # LoRA参数初始化
        self.lora_a = nn.Parameter(torch.empty((r, num_embeddings)))
        self.lora_b = nn.Parameter(torch.empty((embedding_dim, r)))
        
        # 初始化
        with torch.no_grad():
            nn.init.normal_(self.lora_a)
            nn.init.zeros_(self.lora_b)

    def forward(self, x: torch.Tensor):
        # 原始嵌入查找
        result = nn.functional.embedding(x, self.weight)
        # 添加LoRA部分
        result += (nn.functional.embedding(x, self.lora_a.T) @ self.lora_b.T) * self.scaling
        return result

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2198315.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

UNIAPP popper气泡弹层【unibest框架下】vue3+typescript

看了下市场的代码,要么写的不怎么好,要么过于复杂。于是把市场的代码下下来了自己改。200行代码撸了个弹出层组件。兼容H5和APP。 功能: 1)只支持上下左右4个方向的弹层不支持侧边靠齐 2)不对屏幕边界适配 3)支持弹层外边点击自动隐藏 4)支持…

重学SpringBoot3-集成Redis(八)之限时任务(延迟队列)

更多SpringBoot3内容请关注我的专栏:《SpringBoot3》 期待您的点赞👍收藏⭐评论✍ 重学SpringBoot3-集成Redis(八)之限时任务(延迟队列) 1. 延迟任务的场景2. Redis Sorted Set基本原理3. 使用 Redis Sorte…

粗糙表面的仿真和处理软件

首款基于粗糙表面的仿真和处理软件,该软件具有三种方法,主要是二维数字滤波法,相位频谱法和共轭梯度法。可以分别仿真具有高斯和非高斯分布的粗糙表面,其中非高斯表面利用Johnson转换系统进行变换给定偏度和峰度。对生成的粗糙表面…

Mysql高级篇(下)——数据库备份与恢复

Mysql高级篇(下)——数据库备份与恢复 一、物理备份与逻辑备份1、物理备份2、逻辑备份3、对比4、总结 二、mysqldump实现逻辑备份1、mysqldump 常用选项2、mysqldump 逻辑备份语法(1)备份一个数据库(2)备份…

linux自动挂载tf卡

本人使用的是armbian系统,ssh工具使用的是finalshell,挂载的是一张64G TF卡。 1.查看系统所检测到的磁盘,这里的 sda1检测到的硬盘但是没有被挂载 lsblk //查看信息 2.在根目录新建一个目录tfcard用于挂载硬盘,命令如下&#xf…

【万字长文】Word2Vec计算详解(一)

【万字长文】Word2Vec计算详解(一) 写在前面 本文用于记录本人学习NLP过程中,学习Word2Vec部分时的详细过程,本文与本人写的其他文章一样,旨在给出Word2Vec模型中的详细计算过程,包括每个模块的计算过程&a…

电商选品/跟卖| 亚马逊商品类爬取

电商跟卖,最重要是了解哪些商品可以卖, 哪些商品不能卖, 为了更好了解商品信息,我们会经常爬取商品类目的信息. 需求 亚马逊类目信息链接爬虫 打开亚马逊类目信息地址 https://www.amazon.com/gp/new-releases/automotive/refzg_bsnr_nav_automotive_0 一直递归下去&#x…

云原生(四十七) | PHP软件安装部署

文章目录 PHP软件安装部署 一、PHP软件部署步骤 二、安装与配置PHP PHP软件安装部署 一、PHP软件部署步骤 第一步:安装 EPEL 仓库 与 Remi仓库 第二步:启用 Remi 仓库 第三步:安装 PHP、PHP-FPM 第四步:启动并开机启用 PH…

10.8 sql语句查询(未知的)

1.查询结果去重 关键字:distinct (放在查询的后面) AC: select distinct university from user_profile 2.查询结果限制返回行数 关键字:limit AC: select device_id from user_profile limit 0,2 3.将查询后的列重新命名 关键字:as AC: select device_id as user_infos…

wildcard使用教程,解决绝大多数普通人的海外支付难题

许多人可能已经注意到,国外的一些先进AI工具对国内用户并不开放。而想要使用这些工具,我们通常会面临两个主要障碍:一是网络访问的限制,二是支付问题。网络问题很容易解决,难的是如何解决在国内充值海外软件。 今天给大家推荐一个工具——wildcard,用它…

【CSS in Depth 2 精译_046】7.1 CSS 响应式设计中的移动端优先设计原则(下)

当前内容所在位置(可进入专栏查看其他译好的章节内容) 第一章 层叠、优先级与继承(已完结) 1.1 层叠1.2 继承1.3 特殊值1.4 简写属性1.5 CSS 渐进式增强技术1.6 本章小结 第二章 相对单位(已完结) 2.1 相对…

StoryMaker: Towards Holistic Consistent Characters in Text-to-image Generation

https://arxiv.org/pdf/2409.12576v1https://github.com/RedAIGC/StoryMaker 问题引入 针对的是文生图的模型,现在已经有方法可以实现指定人物id的情况下进行生成,但是还没有办法保持包括服装、发型等整体,本文主要解决这个问题&#xff1b…

时间卷积网络(TCN)原理+代码详解

目录 一、TCN原理1.1 因果卷积(Causal Convolution)1.2 扩张卷积(Dilated Convolution) 二、代码实现2.1 Chomp1d 模块2.2 TemporalBlock 模块2.3 TemporalConvNet 模块2.4 完整代码示例 参考文献 在理解 TCN 的原理之前&#xff…

GIS后端工程师岗位职责、技术要求和常见面试题

GIS 后端工程师负责设计、开发与维护地理信息系统的后端服务,包括数据存储、处理、分析以及与前端的交互接口等,以实现高效的地理数据管理和功能支持。 GIS 后端工程师岗位职责 一、系统设计与开发 参与地理信息系统(GIS)项目的…

安装 Petalinux

资料准备 ubuntu 22.04: 运行内存8G 存储空间500G Petalinux:2024.1 安装流程 安装依赖 sudo apt-get update sudo apt-get upgrade sudo apt-get install iproute2 sudo apt-get install gawk sudo apt-get install build-essential sudo apt-ge…

7.3 物联网平台-Thingsboard使用教程

物联网平台-Thingsboard使用教程 目录概述需求: 设计思路实现思路分析 免费下载参考资料和推荐阅读 Survive by day and develop by night. talk for import biz , show your perfect code,full busy,skip hardness,make a better result,wait for chang…

如何使用ssm实现基于web技术的税务门户网站的实现+vue

TOC ssm820基于web技术的税务门户网站的实现vue 绪论 1.1 研究背景 当前社会各行业领域竞争压力非常大,随着当前时代的信息化,科学化发展,让社会各行业领域都争相使用新的信息技术,对行业内的各种相关数据进行科学化&#xff…

基于matlab的语音信号处理

摘要 利用所学习的数字信号处理知识,设计了一个有趣的音效处理系统,首先设计了几种不同的滤波器对声音进行滤波处理,分析了时域和频域的变化,比较了经过滤波处理后的声音与原来的声音有何变化。同时设计实现了语音的倒放&#xf…

从0开始linux(9)——进程(1)进程管理

欢迎来到博主的专栏:从0开始linux 博主ID:代码小豪 文章目录 查看进程进程管理PID与PPIDfork函数 在上一篇中我们了解到:当运行程序时,操作系统会将磁盘中的二进制文件读取到内存当中,程序运行到结束的过程称为进程&am…

【C++ 11】auto 自动类型推导

文章目录 【 1. 基本用法 】【 2. auto 的 应用 】2.0 auto 的限制2.1 简单实例2.2 auto 与指针、引用、const2.4 auto 定义迭代器2.5 auto 用于泛型编程 问题背景 在 C11 之前的版本(C98 和 C 03)中,定义变量或者声明变量之前都必须指明它的…