【llm对话系统】大模型 Llama 源码分析之 LoRA 微调

news2025/2/2 14:03:57

1. 引言

微调 (Fine-tuning) 是将预训练大模型 (LLM) 应用于下游任务的常用方法。然而,直接微调大模型的所有参数通常需要大量的计算资源和内存。LoRA (Low-Rank Adaptation) 是一种高效的微调方法,它通过引入少量可训练参数,固定预训练模型的权重,从而在保持性能的同时大大减少了计算开销。

本文将深入分析 LoRA 的原理,并结合 Llama 源码解读其实现逻辑,最后探讨 LoRA 的优势。

2. LoRA 原理

LoRA 的核心思想是:预训练模型中已经包含了大量的低秩 (low-rank) 特征,微调时只需要对这些低秩特征进行微调即可。

具体来说,LoRA 假设权重更新矩阵 ΔW 也是低秩的。对于一个预训练的权重矩阵 W ∈ R^(d×k),LoRA 将其更新表示为:

W' = W + ΔW = W + BA

其中:

  • W 是预训练的权重矩阵。
  • ΔW 是权重更新矩阵。
  • B ∈ R^(d×r)A ∈ R^(r×k) 是两个低秩矩阵,r 远小于 dkr 被称为 LoRA 的秩 (rank)。

在训练过程中,W 被冻结,只有 AB 是可训练的。

直观理解:

可以将 W 看作一个编码器,将输入 x 编码成一个高维表示 Wx。LoRA 认为,在微调过程中,我们不需要完全改变这个编码器,只需要通过 BA 对其进行一个低秩的调整即可。

3. Llama 中 LoRA 的实现

虽然 Llama 官方代码没有直接集成 LoRA,但我们可以使用一些流行的库 (例如 peft by Hugging Face) 来实现 Llama 的 LoRA 微调。peft 库提供了 LoraConfigget_peft_model 等工具,可以方便地将 LoRA 应用于各种 Transformer 模型。

3.1 使用 peft 库实现 Llama 的 LoRA 微调

以下是一个使用 peft 库实现 Llama 的 LoRA 微调的简化示例:

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, LoraConfig, TaskType

# 加载预训练的 Llama 模型和分词器
model_name = "meta-llama/Llama-2-7b-hf"  # 假设使用 Llama 2 7B
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# LoRA 配置
config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,  # LoRA 的秩
    lora_alpha=32,  # LoRA 的缩放因子
    lora_dropout=0.1,  # Dropout 比例
    target_modules=["q_proj", "v_proj"], # 需要应用 LoRA 的模块
)

# 获取支持 LoRA 的模型
model = get_peft_model(model, config)

# 打印可训练参数的比例
model.print_trainable_parameters()

# ... (加载数据,进行训练) ...

代码解释:

  1. 加载预训练模型:使用 transformers 库加载预训练的 Llama 模型和分词器。
  2. LoRA 配置:创建一个 LoraConfig 对象,指定 LoRA 的配置参数:
    • task_type:任务类型,这里是因果语言模型 (Causal Language Modeling)。
    • r:LoRA 的秩。
    • lora_alpha:LoRA 的缩放因子,用于控制 LoRA 模块的权重。
    • lora_dropout:Dropout 比例。
    • target_modules: 指定需要应用 LoRA 的模块, 通常是注意力层中的 q_proj, v_proj, 还可以是k_proj, o_proj, gate_proj, up_proj, down_proj等。不同的模型需要根据实际情况配置。
  3. 获取支持 LoRA 的模型:使用 get_peft_model 函数将原始的 Llama 模型转换为支持 LoRA 的模型。
  4. 打印可训练参数:使用 model.print_trainable_parameters() 可以查看模型中可训练参数的比例,通常 LoRA 的可训练参数比例非常小。

3.2 peft 库中 LoRA 的实现细节 (部分)

peft 库中 LoraModel 类的部分代码 (为了清晰起见,已进行简化):

class LoraModel(torch.nn.Module):
    # ...
    def _find_and_replace(self, model):
        # ... (遍历模型的每一层) ...
        if isinstance(module, nn.Linear) and name in self.config.target_modules:
            new_module = Linear(
                module.in_features,
                module.out_features,
                bias=module.bias is not None,
                r=self.config.r,
                lora_alpha=self.config.lora_alpha,
                lora_dropout=self.config.lora_dropout,
            )
            # ... (将原模块的权重赋值给新模块) ...
        # ...

class Linear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        **kwargs,
    ):
        super().__init__(in_features, out_features, **kwargs)
        # LoRA 参数
        self.r = r
        self.lora_alpha = lora_alpha
        # 初始化 A 和 B
        if r > 0:
            self.lora_A = nn.Parameter(torch.randn(r, in_features))
            self.lora_B = nn.Parameter(torch.zeros(out_features, r)) # B 初始化为全 0
            self.scaling = self.lora_alpha / self.r

    def forward(self, x: torch.Tensor):
        result = F.linear(x, self.weight, bias=self.bias) # W @ x
        if self.r > 0:
            result += (
                self.lora_B @ self.lora_A @ x.transpose(-2, -1) # (B @ A) @ x
            ).transpose(-2, -1) * self.scaling
        return result

代码解释:

  1. _find_and_replace 函数:遍历模型的每一层,找到需要应用 LoRA 的线性层 (例如,q_proj, v_proj),并将其替换为 Linear 层。
  2. Linear 类:继承自 nn.Linear,并添加了 LoRA 的参数 lora_Alora_B
    • lora_A 初始化为随机值。
    • lora_B 初始化为全 0,这是为了保证在训练开始时,LoRA 部分的输出为 0,不影响预训练模型的原始行为。
    • scaling 是一个缩放因子,用于控制 LoRA 模块的权重。
  3. forward 函数:
    • F.linear(x, self.weight, bias=self.bias) 计算原始的线性变换 W @ x
    • (self.lora_B @ self.lora_A @ x.transpose(-2, -1)).transpose(-2, -1) * self.scaling 计算 LoRA 部分的输出 (B @ A) @ x,并乘以缩放因子。
    • 将两者相加,得到最终的输出。

4. LoRA 的优势

  • 高效的参数利用:LoRA 只需微调少量的参数 (A 和 B),而冻结了预训练模型的大部分参数,大大减少了训练时的内存占用和计算开销。
  • 快速的训练速度:由于可训练参数较少,LoRA 的训练速度通常比全量微调快得多。
  • 防止过拟合:LoRA 的低秩约束起到了一定的正则化作用,有助于防止过拟合。
  • 性能相当:在许多任务上,LoRA 可以达到与全量微调相当的性能。
  • 易于部署:训练完成后,可以将 WBA 相加,得到新的权重矩阵 W',然后像使用原始的预训练模型一样进行部署,无需额外的计算开销。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2289806.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

蓝桥杯刷题DAY2:二维前缀和 一维前缀和 差分数组

闪耀的灯光 📌 题目描述 蓝桥公园是一个适合夜间散步的好地方,公园可以被视为由 n m 个矩形区域构成。每个区域都有一盏灯,初始亮度为 a[i][j]。 小蓝可以选择一个大的矩形区域,并按下开关一次,这将使得该区域内每盏…

C++初阶 -- 手撕string类(模拟实现string类)

目录 一、string类的成员变量 二、构造函数 2.1 无参版本 2.2 有参版本 2.3 缺省值版本 三、析构函数 四、拷贝构造函数 五、c_str函数 六、operator重载 七、size函数 八、迭代器iterator 8.1 正常版本 8.2 const版本 九、operator[] 9.1 正常版本 9.2 const版…

【Unity3D】实现2D角色/怪物死亡消散粒子效果

核心:这是一个Unity粒子系统自带的一种功能,可将粒子生成控制在一个Texture图片网格范围内,并且粒子颜色会自动采样图片的像素点颜色,之后则是粒子编辑出消散效果。 Particle System1物体(爆发式随机速度扩散10000个粒…

85.[1] 攻防世界 WEB easyphp

进入靶场 属于代码审计 <?php // 高亮显示当前 PHP 文件的源代码&#xff0c;常用于调试或展示代码 highlight_file(__FILE__);// 初始化两个标志变量&#xff0c;用于后续条件判断 $key1 0; $key2 0;// 从 GET 请求中获取参数 a 和 b $a $_GET[a]; $b $_GET[b];// 检…

pytorch图神经网络处理图结构数据

人工智能例子汇总&#xff1a;AI常见的算法和例子-CSDN博客 图神经网络&#xff08;Graph Neural Networks&#xff0c;GNNs&#xff09;是一类能够处理图结构数据的深度学习模型。图结构数据由节点&#xff08;vertices&#xff09;和边&#xff08;edges&#xff09;组成&a…

海外问卷调查,最常用到的渠道查有什么特殊之处

市场调研&#xff0c;包含市场调查和市场研究两个步骤&#xff0c;是企业和机构根据经营方向而做出的决策问题&#xff0c;最终通过海外问卷调查中的渠道查&#xff0c;来系统地设计、收集、记录、整理、分析、研究市场反馈的工作流程。 市场调研的工作流程包括&#xff1a;确…

【Uniapp-Vue3】解决uni-popup弹窗在安全区显示透明问题

我们在使用uni-popup时&#xff0c;如果想要给弹出内容添加一个背景颜色&#xff0c;我们会发现在安全区域是不显示该背景颜色的。 首先根据如下的目录结构找到uni-popup.vue文件 在该文件中找到bottom配置&#xff0c;将红箭头所指代码注释掉 下面的安全区域就没有了&#xff…

项目练习:重写若依后端报错cannot be cast to com.xxx.model.LoginUser

文章目录 一、情景说明二、解决办法 一、情景说明 在重写若依后端服务的过程中 使用了Redis存放LoginUser对象数据 那么&#xff0c;有存就有取 在取值的时候&#xff0c;报错 二、解决办法 方法1、在TokenService中修改如下 getLoginUser 方法中&#xff1a;LoginUser u…

核心集:DeepCore: A Comprehensive Library for CoresetSelection in Deep Learning

目录 一、TL&#xff1b;DR 二、为什么研究核心集&#xff1f; 三、问题定义和如何做 3.1 问题定义 3.2 业界方法 3.2.1 基于几何的方法 3.2.2 基于不确定性的方法 3.2.3 基于误差/损失的方法 3.2.5 GraNd 和 EL2N 分数 3.2.6 重要性采样 3.2.7 基于决策边界的办法 …

Hot100之矩阵

73矩阵置零 题目 思路解析 收集0位置所在的行和列 然后该行全部初始化为0 该列全部初始化为0 代码 class Solution {public void setZeroes(int[][] matrix) {int m matrix.length;int n matrix[0].length;List<Integer> list1 new ArrayList<>();List<…

可视化相机pose colmap形式的相机内参外参

目录 内参外参转换 可视化相机pose colmap形式的相机内参外参 内参外参转换 def visualize_cameras(cameras, images):fig plt.figure()ax fig.add_subplot(111, projection3d)for image_id, image_data in images.items():qvec image_data[qvec]tvec image_data[tvec]#…

数据库内存与Buffer Pool

数据库内存与Buffer Pool 文章目录 数据库内存与Buffer Pool一&#xff1a;MySQL内存结构1&#xff1a;MySQL工作组件2&#xff1a;工作线程的本地内存3&#xff1a;共享内存区域4&#xff1a;存储引擎缓冲区 二&#xff1a;InnoDB的核心&#xff1a;Buffer Pool1&#xff1a;数…

程序员学英文之At the Airport Customs

Dialogue-1 Making Airline Reservation预定机票 My cousin works for Xiamen Airlines. 我表哥在厦航上班。I’d like to book an air ticket. 我想预定一张机票。Don’t judge a book by its cover. 不要以貌取人。I’d like to book / re-serve a table for 10. 我想预定一…

Redis代金卷(优惠卷)秒杀案例-单应用版

优惠卷表:优惠卷基本信息,优惠金额,使用规则 包含普通优惠卷和特价优惠卷(秒杀卷) 优惠卷的库存表:优惠卷的库存,开始抢购时间,结束抢购时间.只有特价优惠卷(秒杀卷)才需要填写这些信息 优惠卷订单表 卷的表里已经有一条普通优惠卷记录 下面首先新增一条秒杀优惠卷记录 { &quo…

51单片机 01 LED

一、点亮一个LED 在STC-ISP中单片机型号选择 STC89C52RC/LE52RC&#xff1b;如果没有找到hex文件&#xff08;在objects文件夹下&#xff09;&#xff0c;在keil中options for target-output- 勾选 create hex file。 如果要修改编程 &#xff1a;重新编译-下载/编程-单片机重…

MusicFree-开源的第三方音乐在线播放和下载工具, 支持歌单导入[对标落雪音乐]

MusicFree 链接&#xff1a;https://pan.xunlei.com/s/VOI0RrVLTTWE9kkpt0U7ofGBA1?pwd4ei6#

消息队列篇--原理篇--常见消息队列总结(RabbitMQ,Kafka,ActiveMQ,RocketMQ,Pulsar)

1、RabbitMQ 特点&#xff1a; AMQP协议&#xff1a;RabbitMQ是基于AMQP&#xff08;高级消息队列协议&#xff09;构建的&#xff0c;支持多种消息传递模式&#xff0c;如发布/订阅、路由、RPC等。多语言支持&#xff1a;支持多种编程语言的客户端库&#xff0c;包括Java、P…

nacos 配置管理、 配置热更新、 动态路由

文章目录 配置管理引入jar包添加 bootstrap.yaml 文件配置在application.yaml 中添加自定义信息nacos 配置信息 配置热更新采用第一种配置根据服务名确定配置文件根据后缀确定配置文件 动态路由DynamicRouteLoaderNacosConfigManagerRouteDefinitionWriter 路由配置 配置管理 …

(笔记+作业)书生大模型实战营春节卷王班---L0G2000 Python 基础知识

学员闯关手册&#xff1a;https://aicarrier.feishu.cn/wiki/QtJnweAW1iFl8LkoMKGcsUS9nld 课程视频&#xff1a;https://www.bilibili.com/video/BV13U1VYmEUr/ 课程文档&#xff1a;https://github.com/InternLM/Tutorial/tree/camp4/docs/L0/Python 关卡作业&#xff1a;htt…

SpringBoot中Excel表的导入、导出功能的实现

文章目录 一、easyExcel简介二、Excel表的导出2.1 添加 Maven 依赖2.2 创建导出数据的实体类4. 编写导出接口5. 前端代码6. 实现效果 三、excel表的导出1. Excel表导入的整体流程1.1 配置文件存储路径 2. 前端实现2.1 文件上传组件 2.2 文件上传逻辑3. 后端实现3.1 文件上传接口…