4. 认识 LoRA:从线性层到注意力机制

news2024/11/13 13:01:38

如果你有使用过 AI 生图,那你一定对 LoRA 有印象,下图来自Civitai LoRA,上面有很多可供下载的LoRA模型。

LoRA 模型示例
你可能也曾疑惑于为什么只导入 LoRA 模型不能生图,读下去,你会解决它。

文章目录

    • 为什么需要 LoRA?
    • LoRA 的核心思想
      • 低秩分解
      • 应用到神经网络中的线性层
        • 参数量对比
        • 直观示意图
      • 代码实现:线性层的 LoRA
    • LoRA 在注意力机制中的应用
      • 代码实现:带 LoRA 的注意力
    • 回到最初的问题:为什么只导入 LoRA 模型不能生图?
    • 总结
    • 推荐阅读

这篇文章将从基础的线性层开始,带你一步步了解 LoRA 的核心思想,并深入探索它在注意力机制中的应用。

LoRA,全称 Low-Rank Adaptation,是一种用于微调大型预训练模型的技术。它的核心思想是通过 低秩分解(常见的形式是奇异值分解)减少微调时的参数量,而不牺牲模型的性能。

论文原文:LoRA: Low-Rank Adaptation of Large Language Models

为什么需要 LoRA?

大型预训练模型的出现,为我们带来了强大的自然语言处理和计算机视觉能力,这是一个推动时代的成功。但大模型的“大”,不仅体现在其参数量上,更体现在我们无法轻松进行微调 : ),全量微调一个预训练大模型的代价非常高,而且一般的设备根本训练不动。而 LoRA 提供了一种高效的微调方法,使得在小型设备上微调大模型成为可能。

根据论文中的描述:

  • Compared to GPT-3 175B fine-tuned with Adam, LoRA can reduce the number of trainable parameters by 10,000 times and the GPU memory requirement by 3 times.

相比于对 GPT-3 175B 模型使用全量参数的微调,LoRA 减少了训练参数量的 10,000 倍,GPU 显存需求的 3 倍。

  • LoRA performs on-par or better than fine-tuning in model quality on RoBERTa, DeBERTa, GPT-2, and GPT-3, despite having fewer trainable parameters, a higher training throughput, and, unlike adapters, no additional inference latency.

LoRA 的可训练参数更少,但在 RoBERTa、DeBERTa、GPT-2 和 GPT-3 上的模型质量与全量微调相当甚至更好,而且不会增加推理延迟。

LoRA 的核心思想

LoRA 的核心在于利用低秩分解来近似模型权重的更新。

低秩分解

在线性代数中,任何矩阵都可以分解为多个低秩矩阵的乘积。例如,一个大的矩阵 W W W 可以近似表示为两个小矩阵 B B B A A A 的乘积:

Δ W = B A \Delta W = BA ΔW=BA

其中:

  • A ∈ R r × in_features A \in \mathbb{R}^{r \times \text{in\_features}} ARr×in_features r r r 是低秩值, in_features \text{in\_features} in_features 是输入特征维度。
  • B ∈ R out_features × r B \in \mathbb{R}^{\text{out\_features} \times r} BRout_features×r out_features \text{out\_features} out_features 是输出特征维度。

通过训练这两个小矩阵,我们可以近似地更新原始权重矩阵 W W W,而无需训练整个大的 W W W

应用到神经网络中的线性层

在线性层中,前向传播的计算为:

y = W x + b y = Wx + b y=Wx+b

其中:

  • x ∈ R in_features x \in \mathbb{R}^{\text{in\_features}} xRin_features 是输入向量。
  • W ∈ R out_features × in_features W \in \mathbb{R}^{\text{out\_features} \times \text{in\_features}} WRout_features×in_features 是权重矩阵。
  • b ∈ R out_features b \in \mathbb{R}^{\text{out\_features}} bRout_features 是偏置向量。
  • y ∈ R out_features y \in \mathbb{R}^{\text{out\_features}} yRout_features 是输出向量。

在微调过程中,通常需要更新 W W W b b b。但在 LoRA 中,我们可以冻结原始的 W W W,仅仅在其基础上添加一个可训练的增量 Δ W \Delta W ΔW

y = ( W + Δ W ) x + b y = (W + \Delta W)x + b y=(W+ΔW)x+b

其中:

Δ W = B A \Delta W = BA ΔW=BA

通过训练 A A A B B B,我们大大减少了需要更新的参数数量。

参数量对比

假设(回归论文的符号):

  • in_features = d \text{in\_features} = d in_features=d
  • out_features = k \text{out\_features} = k out_features=k
  • 低秩值为 r r r(通常 r ≪ min ⁡ ( d , k ) r \ll \min(d, k) rmin(d,k)

全量微调:

  • 需要训练的参数数量为 k × d + k k \times d + k k×d+k,其中:
    • k × d k \times d k×d 是权重矩阵 W W W 的参数数量。
    • k k k 是偏置向量 b b b 的参数数量。

使用 LoRA 微调:

  • 需要训练的参数数量为 r × d + k × r + k r \times d + k \times r + k r×d+k×r+k,其中:
    • r × d r \times d r×d 是矩阵 A A A 的参数数量。
    • k × r k \times r k×r 是矩阵 B B B 的参数数量。
    • k k k 是偏置向量 b b b 的参数数量。

参数量减少的比例:

  • 计算:
    减少比例 = LoRA 参数量 全量微调参数量 = r d + k r + k k d + k \text{减少比例} = \frac{\text{LoRA 参数量}}{\text{全量微调参数量}} = \frac{r d + k r + k}{k d + k} 减少比例=全量微调参数量LoRA 参数量=kd+krd+kr+k

    为了简化,我们可以将偏置参数忽略(因为它们相对于权重参数来说数量很小),得到:

    减少比例 ≈ r ( d + k ) k d \text{减少比例} \approx \frac{r(d + k)}{k d} 减少比例kdr(d+k)

    如果假设 k ≈ d k \approx d kd,则有:

    减少比例 ≈ r ( 2 d ) d 2 = 2 r d \text{减少比例} \approx \frac{r(2d)}{d^2} = \frac{2r}{d} 减少比例d2r(2d)=d2r

    所以,当 k ≈ d k \approx d kd 时,参数减少比例近似为 2 r d \frac{2r}{d} d2r

  • 由于 r ≪ d r \ll d rd,所以参数量大幅减少。

举例说明:

假设:

  • 输入特征维度 in_features = d = 1024 \text{in\_features} = d = 1024 in_features=d=1024
  • 输出特征维度 out_features = k = 1024 \text{out\_features} = k = 1024 out_features=k=1024
  • 低秩值 r = 4 r = 4 r=4

全量微调参数量:

  • 权重参数: 1024 × 1024 = 1 , 048 , 576 1024 \times 1024 = 1,048,576 1024×1024=1,048,576
  • 偏置参数: 1024 1024 1024
  • 总参数量: 1 , 048 , 576 + 1024 = 1 , 049 , 600 1,048,576 + 1024 = 1,049,600 1,048,576+1024=1,049,600

使用 LoRA 微调参数量:

  • 矩阵 A A A 参数: 4 × 1024 = 4 , 096 4 \times 1024 = 4,096 4×1024=4,096
  • 矩阵 B B B 参数: 1024 × 4 = 4 , 096 1024 \times 4 = 4,096 1024×4=4,096
  • 偏置参数: 1024 1024 1024
  • 总参数量: 4 , 096 + 4 , 096 + 1024 = 9 , 216 4,096 + 4,096 + 1024 = 9,216 4,096+4,096+1024=9,216

参数量对比:

  • 全量微调: 1 , 049 , 600 1,049,600 1,049,600 参数
  • LoRA 微调: 9 , 216 9,216 9,216 参数
  • 参数减少比例: 9 , 216 1 , 049 , 600 ≈ 0.0088 \frac{9,216}{1,049,600} \approx 0.0088 1,049,6009,2160.0088

也就是说,使用 LoRA 后,参数量减少了约 114 114 114,即参数量仅为原来的 0.88 % 0.88\% 0.88%

直观示意图

论文中的这张图直观地展示了这一点,为了更符合直觉,我们使用在后续继续使用 in_features \text{in\_features} in_features out_features \text{out\_features} out_features 替代 d d d k k k 进行描述:

LoRA 原理示意图

代码实现:线性层的 LoRA

下面我们来实现一个带有 LoRA 的线性层。

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    def __init__(self, in_features, out_features, r):
        super(LoRALinear, self).__init__()
        self.in_features = in_features  # 对应 d
        self.out_features = out_features  # 对应 k
        self.r = r  # 低秩值

        # 原始权重矩阵,冻结
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.weight.requires_grad = False  # 冻结

        # LoRA 部分的参数,初始化为零
        self.A = nn.Parameter(torch.zeros(r, in_features))  # 形状为 (r, d)
        self.B = nn.Parameter(torch.zeros(out_features, r))  # 形状为 (k, r)

        # 偏置项,可选
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        # 原始部分
        original_output = torch.nn.functional.linear(x, self.weight, self.bias)
        # LoRA 增量部分
        delta_W = torch.matmul(self.B, self.A)  # 形状为 (k, d)
        lora_output = torch.nn.functional.linear(x, delta_W)
        # 总输出
        return original_output + lora_output

在这个实现中,self.weight 是原始的权重矩阵,被冻结不参与训练。self.Aself.B 是可训练的低秩矩阵。

LoRA 在注意力机制中的应用

Transformer 模型的核心是注意力机制,其中涉及到 Query、Key、Value 的计算,这些都是线性变换。

在标准的注意力机制中,计算公式为:

Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V

其中 Q Q Q K K K V V V 的计算为:

Q = X Q W Q , K = X K W K , V = X V W V Q = X_Q W_Q, \quad K = X_K W_K, \quad V = X_V W_V Q=XQWQ,K=XKWK,V=XVWV

X Q X_Q XQ X K X_K XK X V X_V XV 的输入可以相同,也可以不同。例如,在 Cross-Attention 中,解码器的隐藏状态作为 X Q X_Q XQ,编码器的输出作为 X K X_K XK X V X_V XV

LoRA 可以应用到 W Q W_Q WQ W K W_K WK W V W_V WV 上,采用与线性层类似的方式。

代码实现:带 LoRA 的注意力

下面我们实现一个带有 LoRA 的单头注意力层。

import torch
import torch.nn as nn

class LoRAAttention(nn.Module):
    def __init__(self, embed_dim, r):
        super(LoRAAttention, self).__init__()
        self.embed_dim = embed_dim  # 对应 d_model
        self.r = r  # 低秩值

        # 原始的 QKV 权重,冻结
        self.W_Q = nn.Linear(embed_dim, embed_dim)
        self.W_K = nn.Linear(embed_dim, embed_dim)
        self.W_V = nn.Linear(embed_dim, embed_dim)
        self.W_O = nn.Linear(embed_dim, embed_dim)

        for param in self.W_Q.parameters():
            param.requires_grad = False
        for param in self.W_K.parameters():
            param.requires_grad = False
        for param in self.W_V.parameters():
            param.requires_grad = False

        # LoRA 的 Q 部分
        self.A_Q = nn.Parameter(torch.zeros(r, embed_dim))  # 形状为 (r, d_model)
        self.B_Q = nn.Parameter(torch.zeros(embed_dim, r))  # 形状为 (d_model, r)

        # LoRA 的 K 部分
        self.A_K = nn.Parameter(torch.zeros(r, embed_dim))
        self.B_K = nn.Parameter(torch.zeros(embed_dim, r))

        # LoRA 的 V 部分
        self.A_V = nn.Parameter(torch.zeros(r, embed_dim))
        self.B_V = nn.Parameter(torch.zeros(embed_dim, r))

    def forward(self, query, key, value):
        """
        query, key, value: 形状为 (batch_size, seq_length, embed_dim)
        """
        # 计算原始的 Q、K、V
        Q = self.W_Q(query)  # (batch_size, seq_length, embed_dim)
        K = self.W_K(key)
        V = self.W_V(value)

        # 计算 LoRA 增量部分
        delta_Q = torch.matmul(query, self.A_Q.t())  # (batch_size, seq_length, r)
        delta_Q = torch.matmul(delta_Q, self.B_Q.t())  # (batch_size, seq_length, embed_dim)
        delta_K = torch.matmul(key, self.A_K.t())
        delta_K = torch.matmul(delta_K, self.B_K.t())
        delta_V = torch.matmul(value, self.A_V.t())
        delta_V = torch.matmul(delta_V, self.B_V.t())

        # 更新后的 Q、K、V
        Q = Q + delta_Q
        K = K + delta_K
        V = V + delta_V

        # 计算注意力得分
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.embed_dim ** 0.5)
        attn_weights = torch.nn.functional.softmax(scores, dim=-1)
        context = torch.matmul(attn_weights, V)

        # 输出层
        output = self.W_O(context)

        return output

代码解释:

  • 原始权重W_QW_KW_V 被冻结,不参与训练。
  • LoRA 参数A_QB_QA_KB_KA_VB_V 是可训练的低秩矩阵。
  • 前向传播
    • 首先计算原始的 Q、K、V。
    • 然后计算 LoRA 的增量部分,并添加到原始的 Q、K、V 上。
    • 接着按照注意力机制进行计算。

回到最初的问题:为什么只导入 LoRA 模型不能生图?

在理解了 LoRA 的核心思想后,相信你已经可以回答。

原因是:LoRA 模型只是对原始模型的权重更新进行了低秩近似,存储了权重的增量部分 Δ W \Delta W ΔW,而不是完整的模型权重 W W W

  • LoRA 模型本身不包含原始模型的权重参数,只包含微调时训练的增量参数 A A A B B B
  • 在推理(如生成图像)时,必须将 LoRA 的增量参数与原始预训练模型的权重相加,才能得到完整的模型权重。
  • 因此,仅仅加载 LoRA 模型是无法进行推理的,必须结合原始的预训练模型一起使用。

打个比方,LoRA 模型就像是给一幅画添加的“修改指令”,但这些指令需要在原始画作的基础上才能生效。如果你只有修改指令(LoRA 模型),却没有原始的画作(预训练模型),那么你就无法得到最终的作品。

所以,要使用 LoRA 模型生成图像,必须同时加载预训练的基础模型和对应的 LoRA 模型。

总结

LoRA 通过将权重更新分解为两个低秩矩阵 A A A B B B 的乘积,极大地减少了微调过程中需要训练的参数量。在不牺牲模型性能的前提下,降低了计算资源的需求,使得在资源受限的环境中微调大型预训练模型成为可能。

这真的是一个很理所当然的想法,不由得感叹数学的重要性。

推荐阅读

题外话:LoRA 的灵感其实涉及到了线性代数的知识,对于想深入学习线性代数的同学们,推荐一本很好的自学教材:《线性代数及其应用》作者是 David C. Lay、Steven R. Lay 和 Judi J. McDonald,英文名为:《Linear Algebra and Its Applications》。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2139816.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

预训练数据指南:衡量数据年龄、领域覆盖率、质量和毒性的影响

前言 原论文:A Pretrainer’s Guide to Training Data: Measuring the Effects of Data Age, Domain Coverage, Quality, & Toxicity 摘要 预训练是开发高性能语言模型(LM)的初步和基本步骤。尽管如此,预训练数据的设计却严…

STM32 HAL freertos零基础(十一)中断管理

1、简介 在FreeRTOS中,中断管理是一个重要的方面,尤其是在嵌入式系统中。正确地处理中断可以确保系统的实时响应能力,并且能够在中断服务程序(ISR)中执行关键操作。FreeRTOS提供了一些机制来帮助开发者管理中断,并确保在多任务环境下中断处理的安全性和高效性。 任何中…

【AI大模型】Transformer模型:Postion Embedding概述、应用场景和实现方式的详细介绍。

一、位置嵌入概述 \1. 什么是位置嵌入? 位置嵌入是一种用于编码序列中元素位置信息的技术。在Transformer模型中,输入序列中的每个元素都会被映射到一个高维空间中的向量表示。然而,传统的自注意力机制并不包含位置信息,因此需要…

3CCD的工作原理

昨天看编辑送的一本《计算机视觉》中3CCD的工作原理错了,其实是百度百科错了,所以我想有人就照搬照抄错了。专业问题不要问百度,百度就是骗子一样的存在,这么多年就从来没有把心思放在做事上。3CCD通过光学棱镜分光后就已经是单色…

智能摄像头MP4格式化恢复方法

如果说生孩子扎堆,那很显然最近智能摄像头多碎片的恢复也扎堆了,这次恢复的是一个不知名的小品牌。其采用了mp4视频文件方案,不过这个案例的特殊之处在于其感染了病毒且不只一次,我们来看看这个小品牌的智能恢复头格式化的恢复方法…

Oracle发邮件功能:设置的步骤与注意事项?

Oracle发邮件配置教程?如何实现Oracle发邮件功能? Oracle数据库作为企业级应用的核心,提供了内置的发邮件功能,使得数据库管理员和开发人员能够通过数据库直接发送邮件。AokSend将详细介绍如何设置Oracle发邮件功能。 Oracle发邮…

基于web的 BBS论坛管理系统设计与实现

博主介绍:专注于Java .net php phython 小程序 等诸多技术领域和毕业项目实战、企业信息化系统建设,从业十五余年开发设计教学工作 ☆☆☆ 精彩专栏推荐订阅☆☆☆☆☆不然下次找不到哟 我的博客空间发布了1000毕设题目 方便大家学习使用 感兴趣的可以…

Linux 基本使用和 web 程序部署 ( 8000 字 Linux 入门 )

一:Linux 背景知识 1.1. Linux 是什么 Linux 是一个操作系统. 和 Windows 是 “并列” 的关系,经过这么多年的发展, Linux 已经成为世界第一大操作系统,安卓系统本质上就是 Linux. 1.2 Linux 发行版 Linux 严格意义来说只是一个 “操作系…

【楚怡杯】职业院校技能大赛 “云计算应用” 赛项样题三

某企业根据自身业务需求,实施数字化转型,规划和建设数字化平台,平台聚焦“DevOps开发运维一体化”和“数据驱动产品开发”,拟采用开源OpenStack搭建企业内部私有云平台,开源Kubernetes搭建云原生服务平台,选…

高亮下位机温湿度

效果如下: 如何对QTextEditor中的内容进行高亮和格式化显示: 首先我们要自定义一个类WenshiduHighlighter,继承自QSyntaxHighlighter实现构造函数,在构造函数中将需要匹配的正则和对应的格式创建,存到成员变量中重写父类的void h…

DNS应答报文分析

目录 DNS应答以太网数据帧 1. 数据链路层 1.1 以太网首部:(目的MAC地址6字节)(源MAC地址6字节)(帧类型2字节)共14字节 1.2 以太网首部数据 2. 网络层 2.1 IP协议头部共20个字节 2.2 IP协议头部数据 3. 传输层 3.1 UDP头部共8字节 3.2 UDP头部数据 4. 应用层 4.1 D…

低空经济第一站:无人机飞手人才培养技术详解

在低空经济蓬勃发展的背景下,无人机飞手作为直接操作者和应用者,其人才培养技术成为推动这一新兴经济形态持续健康发展的关键。以下是对无人机飞手人才培养技术的详细解析: 一、培养目标 无人机飞手的培养旨在培养具备扎实无人机操作技能、…

_Array类,类似于Vector,其实就是_string

例子&#xff1a; using namespace lf; using namespace std;int main() {_Array<int> a(10, -1);_Array<_string> s { _t("one"), _t("two") };_pcn(a);_pcn(s);} 结果&#xff1a; 源代码_Array.h&#xff1a; /***********************…

el-table 的单元格 + 图表 + 排序

<el-table border :data"tableDataThree" height"370px" style"width: 100%"><el-table-column :key"activeName 8" width"50" type"index" label"序号" align"center"></el…

macOS系统Homebrew工具安装及使用

1.打开Homebrew — The Missing Package Manager for macOS (or Linux) 2.复制安装命令到终端执行 复制 执行 3. 开始自动安装过程 4.安装成功 5.使用brew安装wget工具

第L6周:机器学习-随机森林(RF)

&#x1f368; 本文为&#x1f517;365天深度学习训练营 中的学习记录博客&#x1f356; 原作者&#xff1a;K同学啊 目标&#xff1a; 1.什么是随机森林&#xff08;RF&#xff09; 随机森林&#xff08;Random Forest, RF&#xff09;是一种由 决策树 构成的 集成算法 &#…

WebSocket vs. Server-Sent Events:选择最适合你的实时数据流技术

引言&#xff1a; 在当今这个信息爆炸的时代&#xff0c;用户对于网页应用的实时性要求越来越高。从即时通讯到在线游戏&#xff0c;再到实时数据监控&#xff0c;WebSocket技术因其能够实现浏览器与服务器之间的全双工通信而受到开发者的青睐。 WebSocket技术为现代Web应用…

java计算机毕设课设—电子政务网系统(附源码、文章、相关截图、部署视频)

这是什么系统&#xff1f; 资源获取方式在最下方 java计算机毕设课设—电子政务网系统(附源码、文章、相关截图、部署视频) 电子政务网系统主要用于提升政府机关的政务管理效率&#xff0c;核心功能包括前台网站展示、留言板管理、后台登录与密码修改、网站公告发布、政府部…

高级Java程序员必备的技术点:你准备好了吗?

在Java编程的世界里&#xff0c;成为一名高级程序员不仅需要深厚的基础知识&#xff0c;还需要掌握一系列高级技术和最佳实践。这些技术点是通向技术专家之路的敲门砖&#xff0c;也是应对复杂项目挑战的利器。本文将探讨高级Java程序员必备的技术点&#xff0c;帮助你自我提升…

VS code 安装使用配置 Continue

Continue 插件介绍 Continue 是一款高效的 VS Code 插件&#xff0c;提供类似 GitHub Copilot 的功能&#xff0c;旨在提升开发者的编程效率。其配置简单&#xff0c;使用体验流畅&#xff0c;深受开发者喜爱。 主要功能特点 智能代码补全 Continue 能够基于当前代码上下文生…