LoRA 理解

news2024/9/19 21:15:29

LLM的参数量对于时间和显存要求都带来很大的挑战。现存的两种显著范式:

  • 增加adapter:主要问题在于推理时带来的额外计算量和延迟。
  • 优化prompt: 前缀微调(Prefix Tuning)较难优化,而且随着参数量增长性能并非单调变化。

那有什么方法可以 解决这个问题么?:图像生成领域 的 lora

1. 介绍

lora是大模型的低秩适配器,或者就简单的理解为适配器,在图像生成中可以将lora理解为某种图像风格(比如SD社区中的各种漂亮妹子的lora,可插拔式应用,甚至组合式应用实现风格的融合)的适配器,在NLP中可以将其理解为某个任务的适配器(比如最近各种开源chatgpt复现中使用的lora技术,不过限于文本领域的特性,目前组合式应用似乎还不多)。

2. 做法

  • 在原模型旁边增加一个旁路,通过低秩分解(先降维再升维)来模拟参数的更新量;
  • 训练时,原模型固定,只训练降维矩阵A和升维矩阵B;
  • 推理时,可将BA加到原参数上,不引入额外的推理延迟;
  • 初始化,A采用高斯分布初始化,B初始化为全0,保证训练开始时旁路为0矩阵;
  • 可插拔式的切换任务,当前任务W0+B1A1,将lora部分减掉,换成B2A2,即可实现任务切换;

3. 原理

过度参数化的模型实际上位于一个低内在维度的空间,因此lora作者提出假设,微调时权重的变化同样有一个低的"内在维度",因此可以进行低秩矩阵分解。
在这里插入图片描述
大模型(LLM)训练微调综述学习

4. 总结

一句话总结 lora:固定大模型,增加低秩分解的矩阵来适配下游任务。

5. 优点

  • 一个中心模型服务多个下游任务,节省参数存储量
  • 推理阶段不引入额外计算量
  • 与其它参数高效微调方法正交,可有效组合
  • 训练任务比较稳定,效果比较好

6. 缺点

生成任务上效果 欠佳

7. 总览

在这里插入图片描述

8. 实战

安装

pip install loralib

可以选择用loralib中实现的对应层来替换一些层。目前loralib只支持 nn.Linear、nn.Embedding 和 nn.Conv2d。loralib还支持一个 MergedLinear,用于单个 nn.Linear 代表一个以上的层的情况。


# ===== Before =====  
# layer = nn.Linear(in_features, out_features)
    # ===== After ======
    import loralib as lora
    # Add a pair of low-rank adaptation matrices with rank r=16
    layer = lora.Linear(in_features, out_features, r=16)

在训练之前,设置仅LorA模块的参数可被训练

 import loralib as lora
 model = BigModel()
 # This sets requires_grad to False for all parameters without the string "lora_" in their names
 lora.mark_only_lora_as_trainable(model)
 # Training loop
 for batch in dataloader:
 ...

在保存checkpoint时,生成一个仅包含LoRA参数的state_dict

 # ===== Before =====
 # torch.save(model.state_dict(), checkpoint_path)
 # ===== After =====
 torch.save(lora.lora_state_dict(model), checkpoint_path)

当载入checkpoint时,设置strict为False

# Load the pretrained checkpoint first
model.load_state_dict(torch.load('ckpt_pretrained.pt'), strict=False)
# Then load the LoRA checkpoint
model.load_state_dict(torch.load('ckpt_lora.pt'), strict=False)

lora.MergedLinear的使用

# ===== Before =====
# qkv_proj = nn.Linear(d_model, 3*d_model)
# ===== After =====
# Break it up (remember to modify the pretrained checkpoint accordingly)
q_proj = lora.Linear(d_model, d_model, r=8)
k_proj = nn.Linear(d_model, d_model)
v_proj = lora.Linear(d_model, d_model, r=8)
# Alternatively, use lora.MergedLinear (recommended)
qkv_proj = lora.MergedLinear(d_model, 3*d_model, r=8, enable_lora=[True, False, True])

可以在调用mark_only_lora_as_trainable时,通过给bias= 传递 "all "或 "lora_only "来标记一些bias为可训练。

# ===== Before =====
# lora.mark_only_lora_as_trainable(model) # Not training any bias vectors
# ===== After =====
# Training all bias vectors associated with modules we apply LoRA to 
lora.mark_only_lora_as_trainable(model, bias='lora_only')
# Alternatively, we can train *all* bias vectors in the model, including LayerNorm biases
lora.mark_only_lora_as_trainable(model, bias='all')
# When saving a checkpoint, use the same bias= ('all' or 'lora_only')
torch.save(lora.lora_state_dict(model, bias='all'), checkpoint_path)
Apply to GPT

参见:LoRA/examples/NLG/src/model.py

class Attention(nn.Module):
    def __init__(self, nx, n_ctx, config, scale=False):
        super(Attention, self).__init__()
        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        
        assert n_state % config.n_head == 0
        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale
        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_attn = lora.MergedLinear(
            nx, n_state * 3, 
            r=config.lora_attn_dim, 
            lora_alpha=config.lora_attn_alpha, 
            lora_dropout=config.lora_dropout, 
            enable_lora=[True, False, True], 
            fan_in_fan_out=True,
            merge_weights=False
        )
        self.c_proj = Conv1D(n_state, nx)

        self.config = config
源代码解读

总的来说loralib的源代码比较简洁,可以在LORA/loralib/layers.py 查看

Class LoRALayer
class LoRALayer():
    def __init__(
        self, 
        r: int, 
        lora_alpha: int, 
        lora_dropout: float,
        merge_weights: bool,
    ):
        
        self.r = r
        self.lora_alpha = lora_alpha
        # Optional dropout
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        # Mark the weight as unmerged
        self.merged = False
        self.merge_weights = merge_weights

LoRA layer可以添加到任何一个可以有参数训练的层里。但文章中也提到了we only apply LoRA to Wq and Wv in most experiments for simplicity

LoRA Embedding

(注释在代码块中)

“During training, W0 is frozen and does not receive gradient updates, while A and B contain trainable parameters. ”

“We use a random Gaussian initialization for A and zero for B”

class Embedding(nn.Embedding, LoRALayer):
    # LoRA implemented in a dense layer
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        r: int = 0,
        lora_alpha: int = 1,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=0,
                           merge_weights=merge_weights)
        
        # Lora 部分
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, num_embeddings)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
            #  scale ∆W x by α/r 
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            #冻结pre-trained 参数
            self.weight.requires_grad = False
        self.reset_parameters()

    def reset_parameters(self):
        nn.Embedding.reset_parameters(self)
        if hasattr(self, 'lora_A'):
          #初始化
          # We use a random Gaussian initialization for A and zero for B, so ∆W = BA is zero at the beginning of training.
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.zeros_(self.lora_A)
            nn.init.normal_(self.lora_B)

    def train(self, mode: bool = True):
        nn.Embedding.train(self, mode)
        if self.merge_weights and self.merged:
            # self.merged = Ture
            # Make sure that the weights are not merged
            # weight=weight-B * A * scale 需要剪掉merge的部分
            if self.r > 0:
                self.weight.data -= (self.lora_B @ self.lora_A).T * self.scaling
            self.merged = False
    
    def eval(self):
        nn.Linear.eval(self)
        if self.merge_weights and not self.merged:
            # Merge the weights and mark it
            # self.merged= False
            if self.r > 0:
                self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
            self.merged = True

    def forward(self, x: torch.Tensor):
        if self.r > 0 and not self.merged:
           # self.merged= False
            result = nn.Embedding.forward(self, x)
            if self.r > 0:
                after_A = F.embedding(
                    x, self.lora_A.T, self.padding_idx, self.max_norm,
                    self.norm_type, self.scale_grad_by_freq, self.sparse
                )  # W0x + BAx
                result += (after_A @ self.lora_B.T) * self.scaling
            return result
        else:
            return nn.Embedding.forward(self, x)
Class Linear

kaiming_uniform_ kaiming 初始化
fin in fin out 含义

https://towardsdatascience.com/understand-kaiming-initialization-and-implementation-detail-in-pytorch-f7aa967e9138

因为加了fin in fin out 参数的原因,比之前的embedding层多了一个def T

class Linear(nn.Linear, LoRALayer):
    # LoRA implemented in a dense layer
    def __init__(
        self, 
        in_features: int, 
        out_features: int, 
        r: int = 0, 
        lora_alpha: int = 1, 
        lora_dropout: float = 0.,
        fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)

        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.T

    def reset_parameters(self):
        nn.Linear.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode: bool = True):
        def T(w):
            return w.T if self.fan_in_fan_out else w
        nn.Linear.train(self, mode)
        if self.merge_weights and self.merged:
            # Make sure that the weights are not merged
            if self.r > 0:
                self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
            self.merged = False
    
    def eval(self):
        def T(w):
            return w.T if self.fan_in_fan_out else w
        nn.Linear.eval(self)
        if self.merge_weights and not self.merged:
            # Merge the weights and mark it
            if self.r > 0:
                self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
            self.merged = True

    def forward(self, x: torch.Tensor):
        def T(w):
            return w.T if self.fan_in_fan_out else w

            #Merge = False
        if self.r > 0 and not self.merged:
            result = F.linear(x, T(self.weight), bias=self.bias)
            if self.r > 0:
                result += (self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scaling
            return result
        #Merge =True
        else:
            return F.linear(x, T(self.weight), bias=self.bias)
Class MergedLinear

这个针对self- attention模块使用

class MergedLinear(nn.Linear, LoRALayer):
    # LoRA implemented in a dense layer
    def __init__(
        self, 
        in_features: int, 
        out_features: int, 
        r: int = 0, 
        lora_alpha: int = 1, 
        lora_dropout: float = 0.,
        enable_lora: List[bool] = [False],
        fan_in_fan_out: bool = False,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        assert out_features % len(enable_lora) == 0, \
            'The length of enable_lora must divide out_features'
        
        #一个true false list
        self.enable_lora = enable_lora
        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0 and any(enable_lora):
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r * sum(enable_lora), in_features)))
            self.lora_B = nn.Parameter(
                self.weight.new_zeros((out_features // 
                                       len(enable_lora) * sum(enable_lora), r))
            ) # weights for Conv1D with groups=sum(enable_lora) 计算有几个True
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False

            #因为针对像attention计算中需要Wq Wk Wv 三种linear merge 一起的情况
            # Compute the indices
            # input (out_features) output (len(enable_lora) , out_features/len(enable_lora))
            self.lora_ind = self.weight.new_zeros(
                (out_features, ), dtype=torch.bool
            ).view(len(enable_lora), -1)
            #对应的那一行就设为True
            self.lora_ind[enable_lora, :] = True
            self.lora_ind = self.lora_ind.view(-1)
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.T

    def reset_parameters(self):
        nn.Linear.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def zero_pad(self, x):
        result = x.new_zeros((*x.shape[:-1], self.out_features))
        result = result.view(-1, self.out_features)
        result[:, self.lora_ind] = x.reshape(
            -1, self.out_features // len(self.enable_lora) * sum(self.enable_lora)
        )
        return result.view((*x.shape[:-1], self.out_features))

    def train(self, mode: bool = True):
        def T(w):
            return w.T if self.fan_in_fan_out else w
        nn.Linear.train(self, mode)
        if self.merge_weights and self.merged:
            # Make sure that the weights are not merged
            if self.r > 0 and any(self.enable_lora):
                delta_w = F.conv1d(
                    self.lora_A.data.unsqueeze(0), 
                    self.lora_B.data.unsqueeze(-1), 
                    groups=sum(self.enable_lora)
                ).squeeze(0)
                self.weight.data -= self.zero_pad(T(delta_w * self.scaling))
            self.merged = False
    
    def eval(self):
        def T(w):
            return w.T if self.fan_in_fan_out else w
        nn.Linear.eval(self)
        if self.merge_weights and not self.merged:
            # Merge the weights and mark it
            if self.r > 0 and any(self.enable_lora):
                delta_w = F.conv1d(
                    self.lora_A.data.unsqueeze(0), 
                    self.lora_B.data.unsqueeze(-1), 
                    groups=sum(self.enable_lora)
                ).squeeze(0)
                self.weight.data += self.zero_pad(T(delta_w * self.scaling))
            self.merged = True

    def forward(self, x: torch.Tensor):
        def T(w):
            return w.T if self.fan_in_fan_out else w
        if self.merged:
            return F.linear(x, T(self.weight), bias=self.bias)
        else:
            result = F.linear(x, T(self.weight), bias=self.bias)
            if self.r > 0:
                after_A = F.linear(self.lora_dropout(x), self.lora_A)
                after_B = F.conv1d(
                    after_A.transpose(-2, -1), 
                    self.lora_B.unsqueeze(-1), 
                    groups=sum(self.enable_lora)
                ).transpose(-2, -1)
                result += self.zero_pad(after_B) * self.scaling
            return result

参考

LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS
Intrinsic Dimensionality Explains the Effectiveness of Language Model Fine-Tuning
Measuring the Intrinsic Dimension of Objective Landscapes.
peft/tuners/lora.py
microsoft/LoRA
LoRA:大模型的低秩适配-最近大火的lora到底是什么东西?为啥stable diffusion和开源ChatGPT复现都在用?
论文阅读:LORA-大型语言模型的低秩适应
LLaMA模型详解-当前开源ChatGPT复现中使用最多的基础模型
论文速读:LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS
论文阅读:LORA-大型语言模型的低秩适应
微软LoRA: Low-Rank Adaptation of Large Language Models 代码解读

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/466700.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

一篇简单的文章带你玩转SpringBoot 之定时任务详解

序言 使用SpringBoot创建定时任务非常简单,目前主要有以下三种创建方式: 一、基于注解(Scheduled)二、基于接口(SchedulingConfigurer) 前者相信大家都很熟悉,但是实际使用中我们往往想从数据库中读取指定时间来动态…

Linux网络编程TCP连接的建立和终止

文章目录 前言一、TCP的三路握手二、TCP连接终止总结 前言 本篇文章将讲解TCP的连接的建立和终止,主要就是讲解TCP的三路握手和TCP连接断开内部发生的一些机制和事件。 一、TCP的三路握手 TCP三路握手所交换的三个分节: (1)服务器必须准备好接受外来…

C++题解之对顶堆:中位数

中位数 题目链接:洛谷P1168 中位数 题目描述 给定一个长度为 N N N 的非负整数序列 A A A,对于前奇数项求中位数。 输入格式 第一行一个正整数 N N N。 第二行 N N N 个正整数 A 1 … N A_{1\dots N} A1…N​。 输出格式 共 ⌊ N 1 2 ⌋ …

【是C++,不是C艹】 省缺参数 | 函数重载 | 内联函数

💞💞欢迎来到 Claffic 的博客 💞💞 👉 专栏:《是C,不是C艹》👈 前言: 上期,我带大家给C打了招呼,捎带着认识了命名空间和输入输出,那…

LeetCode——链表简单题题解

83. 删除排序链表中的重复元素 题目描述 给定一个已排序的链表的头 head , 删除所有重复的元素,使每个元素只出现一次 。返回 已排序的链表 。 输入:head [1,1,2] 输出:[1,2] 解题思路:用一个指向节点类型的指针保…

Vscode配置C/C++开发环境

下载Vscode进行安装。 下载MinGW-W64 GCC最新版本, 选择x86_64-win32-seh进行下载。解压放入自定义目录(英文路径)后,添加$:\mingw64\bin到系统Path环境变量。安装C/C插件。 在.c后缀文件中按Ctr Shift P,选择C/C …

Windows系统文件被faust勒索病毒加密勒索病毒解密恢复,电脑中病毒了怎么修复?

恶意软件的攻击已经让电脑用户变得更加谨慎了。在最近的一波攻击中,faust勒索病毒已经对使用Windows系统的计算机造成了广泛的破坏。该病毒利用加密技术锁定用户的文件,只有在支付一定数额的赎金后才会解锁这些文件。如果你的计算机中也受到了这种勒索病…

MaxScript编写bone转换biped工具

一、制作转换工具的缘由 大家好,我是阿赵。我经常从各种渠道得到了一些角色模型,这些模型得到之后,会发现是带有蒙皮和骨骼,甚至带有动作的。   不过这些资源很多都是从游戏截取出来的,导入到3DsMax之后,…

信息安全复习十:Web与电子商务安全

一、章节梗概 1.信息安全的学科内容 2.Web和电子商务安全问题提出 3.安全套接字协议SSL与传输层安全协议TLS 4.安全电子交易(SET)简要介绍 复习: 密码学内容:对称密钥密码、公开密钥密码、报文鉴别 PKI:数字签名、数字证书、信任关系 身份认…

MES管理系统助力企业数字化与实体经济实现“数实融合“

中国企业评估协会日前公布了2022年度“中国新经济500强”的名单。在这些企业中,先进制造业企业的比例超过了半数,尤其是互联网与现代信息技术服务业、新能源产业、新型生活性服务业等,在这些行业中占据了重要地位。由此可以看出,绿…

C/C++每日一练(20230427) 二叉树专场(5)

目录 1. 从中序与后序遍历序列构造二叉树 🌟🌟 2. 从先序与中序遍历序列构造二叉树 🌟🌟 3. 二叉树展开为链表 🌟🌟 🌟 每日一练刷题专栏 🌟 Golang每日一练 专栏 Python每…

微信小程序php+python+nodejs+vue 高校工资管理系统

在线公益知识练习及测试系统是随着计算机技术和互联网技术的发展而产生的一种的新的练习及测试模式,与传统的纸质化练习及测试不同,在线系统提高了练习或测试的效率,减少了纸张的浪费,减轻了教师的评卷压力,同时也为参…

易观千帆 | 金融机构如何保证用户体验长期可持续?

易观:用户体验正逐渐成为金融机构的命脉。 数字经济时代的到来,金融机构面临着来自内部和外部的双重压力。一方面,互联网金融企业凭借强大的技术能力以及人才优势,通过互联网运营的模式迅速响应客户需求,吸引了大量用户…

自定义Feign日志

文章目录 开篇定位Feign是怎么打印日志的?自定义FeignLogger实现 开篇 在上一篇Feign打印日志文章中,已经成功打印了FeignClient请求服务的日志信息,但是默认打印的日志太过零散,不是我们想要的。怎么能自定义日志打印的格式和内…

系统分析师《企业信息化战略与实施》高频知识点二

企业信息化战略与实施---系统建模 组织结构是一个企业内部部门的划分以及相互之间的关系,每个企业都有自己的组织结构图,它将企业分成若干部分,标明行政隶属关系。组织结构图是一种树结构,树的分支是根据上下级和行政隶属关系绘制…

二叉树 + 技巧

题目难度备注2471. 逐层排序二叉树所需的最少操作数目1635BFS 置换环 离散化2641. 二叉树的堂兄弟节点 II1677BFS 文章目录 周赛二叉树问题[2471. 逐层排序二叉树所需的最少操作数目](https://leetcode.cn/problems/minimum-number-of-operations-to-sort-a-binary-tree-by-l…

【AI绘画】云服务器部署stable-diffusion-webui保姆级教程

1.背景 之前给大家写过Mac苹果笔记本上部署stable-diffusion-webui的教程,知乎链接: 【奶奶看了也不会】AI绘画 Mac安装stable-diffusion-webui绘制AI妹子保姆级教程 但是安装过程就花了一天的时间,各种问题处理起来真是苦不堪言。。。而且…

如何解决生产缺料问题?

对于一个生产型企业来说,生产缺料是管理中的疑难问题之一。 生产缺料导致的最直接的危害就有两个: 第一就是不能准时交货,降低企业信用,情况严重可能导致根据合同赔款,甚至丢失客户。 第二个危害就是增加了生产成本…

部署LVS-DR 集群及实验

一、LVS-DR工作原理 LVS-DR(Linux Virtual Server Director Server)工作模式,是生产环境中最常用的一种工作模式。 #①LVS-DR 模式,Director Server 作为群集的访问入口,不作为网关使用; #②节点 Directo…

ueditor富文本编辑器上传木马图片或木马文件漏洞

漏洞分析: ueditor插件目录一般都自带index.html文件(完整demo文件),这个文件原来是用来测试插件功能的,但粗心的程序员们,开发好项目后,都忘记删除这个demo文件了,导致黑客可以很轻…