改变LoRA的初始化方式,北大新方法PiSSA显著提升微调效果

news2024/12/23 18:00:51

    ChatGPT狂飙160天,世界已经不是之前的样子。
新建了免费的人工智能中文站https://ai.weoknow.com
新建了收费的人工智能中文站https://ai.hzytsoft.cn/

更多资源欢迎关注


随着大模型的参数量日益增长,微调整个模型的开销逐渐变得难以接受。

为此,北京大学的研究团队提出了一种名为 PiSSA 的参数高效微调方法,在主流数据集上都超过了目前广泛使用的 LoRA 的微调效果。

图片

  • 论文: PiSSA: Principal Singular Values and Singular Vectors Adaptation of Large Language Models

  • 论文链接: https://arxiv.org/pdf/2404.02948.pdf

  • 代码链接: https://github.com/GraphPKU/PiSSA

如图 1 所示,PiSSA (图 1c) 在模型架构上和 LoRA [1] 完全一致 (图 1b),只是初始化 Adapter 的方式不同。LoRA 使用高斯噪声初始化 A,使用 0 初始化 B。而 PiSSA 使用主奇异值和奇异向量 (Principal Singular values and Singular vectors) 来初始化 Adapter 来初始化 A 和 B。

图片

图 1)从左到右依次为全参数微调、LoRA、以及 PiSSA。蓝色代表冻结的参数,橘黄色代表可训练参数及它们的初始化方式。相比全参数微调,LoRA 和 PiSSA 都大幅节省了可训练参数量。对于相同输入,这三种方法的初始输出完全相等。然而,PiSSA 冻结模型的次要成分,直接微调主成分(前 r 个奇异值和奇异向量);而 LoRA 可看作冻结模型的主要部分,而去微调 noise 部分。

在不同的任务上对比 PiSSA、LoRA 的微调效果

研究团队使用 llama 2-7B、Mistral-7B 以及 Gemma-7B 作为基础模型,通过微调提升它们的数学、代码和对话能力。其中包括:在 MetaMathQA 上训练,在 GSM8K 和 MATH 数据集上验证模型的数学能力;在 CodeFeedBack 上训练,在 HumanEval 和 MBPP 数据集上验证模型的代码能力;在 WizardLM-Evol-Instruct 上训练,在 MT-Bench 上验证模型的对话能力。从下表的实验结果可以看出,使用相同规模的可训练参数,PiSSA 的微调效果显著超越了 LoRA,甚至超越了全参数微调。

图片

对比 PiSSA、LoRA 在不同的可训练参数量下微调的效果

研究团队在数学任务上对模型的可训练参数量和效果之间的关系进行消融实验。从图 2.1 发现在训练初期,PiSSA 的训练 loss 下降特别快,而 LoRA 存在不下降,甚至略有上升的阶段。此外,PiSSA 的训练 loss 全程低于 LoRA,说明对训练集拟合得更好;从图 2.2、2.3、2.4 可以看出在每种 setting 下,PiSSA 的 loss 始终比 LoRA 低,准确率始终比 LoRA 高,PiSSA 能够使用更少的可训练参数追赶上全参数微调的效果。

图片

图 2.1) 当秩为 1 时 PiSSA、LoRA 在训练过程中的 loss。每幅图的右上角是前 100 步迭代放大的曲线。其中 PiSSA 用橙色线表示,LoRA 用蓝色线表示,全参数微调用绿线展示了最终的 loss 作为参考。秩为 [2,4,8,16,32,64,128] 时的现象与此一致,详见文章附录。

图片

图 2.2)使用秩为 [1,2,4,8,16,32,64,128] 的 PiSSA 和 LoRA 的最终 training loss。

图片

图 2.3)使用秩为 [1,2,4,8,16,32,64,128] 的 PiSSA 和 LoRA 微调的模型在 GSM8K 上的准确率。

图片

图 2.4)使用秩为 [1,2,4,8,16,32,64,128] 的 PiSSA 和 LoRA 微调的模型在 MATH 上的准确率。

PiSSA 方法详解

受到 Intrinsic SAID [2]“预训练大模型参数具有低秩性” 的启发,PiSSA 对预训练模型的参数矩阵

图片

进行奇异值分解,其中前 r 个奇异值和奇异向量用来初始化适配器 (adapter) 的两个矩阵

图片

图片

图片

;剩余的奇异值和奇异向量用来构造残差矩阵

图片

,使得

图片

。因此,适配器中的参数包含了模型的核心参数,而残差矩阵中的参数是修正参数。通过微调参数量较小的核心适配器 A、B,冻结参数量较大的残差矩阵

图片

,就达成了用很少的参数近似全参数微调的效果。

尽管同样受到 Intrinsic SAID [1] 启发,PiSSA 和 LoRA 背后的原理却截然不同。

LoRA 认为大模型微调前后矩阵的变化 △W 具有很低的本征秩 r,因此通过

图片

图片

相乘得到的低秩矩阵来模拟模型的变化 △W。初始阶段,LoRA 使用高斯噪声初始化 A,使用 0 初始化 B,因此

图片

,以此保证模型初始能力没有变化,并微调 A 和 B 实现对 W 进行更新。与此相比,PiSSA 不关心 △W,而是认为 W 具有很低的本征秩 r。因此直接对 W 进行奇异值分解,分解成主成分 A、B,以及残差项

图片

,使得

图片

。假设 W 的奇异值分解为

图片

,A、B 使用 SVD 分解后奇异值最大的 r 个奇异值、奇异向量进行初始化:

图片

残差矩阵使用其余的奇异值、奇异向量进行初始化:

图片

PiSSA 直接对 W 的低秩主成分 A、B 进行微调,冻结次要的修正项。相比 LoRA 用高斯噪声以及 0 初始化适配器参数、冻结核心模型参数,PiSSA 收敛更快、效果更好。

PiSSA 的发音类似 “披萨”(pizza)--- 如果把整个大模型类比为一个完整的披萨,PiSSA 切掉其中一角,而且是馅料最丰富的一角(主奇异值、奇异向量),重新烘焙(在下游任务上微调)成喜欢的口味。

由于 PiSSA 采用了和 LoRA 完全相同的架构,其可以作为 LoRA 的一种可选初始化方式,在 peft 包中很方便的进行修改和调用 (如以下代码所示)。相同的架构也使得 PiSSA 继承了大多数 LoRA 的优点,如:对残差模型使用 4bit 量化 [3],减小训练开销;微调完成后适配器能合并进残差模型,不改变推理过程的模型架构;无需分享完整模型参数,只需要分享参数量很少的 PiSSA 模块,使用者直接加载 PiSSA 模块就能自动进行奇异值分解以及赋值;一个模型可以同时使用多个 PiSSA 模块等等。一些对 LoRA 方法的改进,也能与 PiSSA 进行结合:比如不固定每层的秩,通过学习找到最佳的秩 [4];用 PiSSA 指导的更新 [5],从而突破秩的限制等等。

# 在 peft 包中 LoRA 的初始化方式后面增加了一种 PiSSA 初始化选项:
if use_lora:
  nn.init.normal_(self.lora_A.weight, std=1 /self.r)
  nn.init.zeros_(self.lora_B.weight) 
elif use_pissa:
  Ur, Sr, Vr = svd_lowrank (self.base_layer.weight, self.r, niter=4) 
  # 注意:由于 self.base_layer.weight 的维度是 (out_channel,in_channel, 所以 AB 的顺序相比图示颠倒了一下)
  self.lora_A.weight = torch.diag (torch.sqrt (Sr)) @ Vh.t ()
  self.lora_B.weight = Ur @ torch.diag (torch.sqrt (Sr)) 
  self.base_layer.weight = self.base_layer.weight - self.lora_B.weight @ self.lora_A.weight

对比高中低奇异值微调效果实验

为了验证使用不同大小奇异值、奇异向量初始化适配器对模型的影响,研究人员分别使用高、中、低奇异值初始化 LLaMA 2-7B、Mistral-7B-v0.1、Gemma-7B 的适配器,然后在 MetaMathQA 数据集上进行微调,实验结果展示在图 3 中。从图中可以看出,使用主要奇异值初始化的方法训练损失最小,在 GSM8K 和 MATH 验证集上的准确率更高。这一现象验证了微调主要奇异值、奇异向量的有效性。

图片

图 3)从左到右依次为训练 loss、在 GSM8K 上的准确率、在 MATH 上的准确率。其中蓝色表示最大奇异值、橙色表示中等奇异值、绿色表示最小奇异值。

快速奇异值分解

PiSSA 继承了 LoRA 的优点,使用起来方便,效果超越 LoRA。代价是在初始化阶段,需要对模型进行奇异值分解。虽然仅需要在初始化时分解一次,但是仍然可能需要几分钟甚至几十分钟的开销。因此,研究人员使用一种快速奇异值分解 [6] 方法替代标准的 SVD 分解,通过下表的实验可以看出,仅需几秒钟的时间,就能逼近标准 SVD 分解的训练集拟合效果。其中 Niter 表示迭代次数,Niter 越大,时间越久但是误差越小。Niter = ∞表示标准 SVD。表格中的平均误差表示快速奇异值分解与标准 SVD 得到的 A、B 之间的平均 L_1 距离。

图片

总结与展望

本工作对预训练模型的权重进行奇异值分解,通过将其中最重要的参数用于初始化一个名为 PiSSA 的适配器,微调这个适配器来近似微调完整模型的效果。实验表明,PiSSA 比 LoRA 收敛更快,最终效果更好,唯一的代价仅是需要几秒的 SVD 初始化过程。

那么,您愿意为了更好的训练效果,多花几秒钟时间,一键更改 LoRA 的初始化为 PiSSA 吗?

    ChatGPT狂飙160天,世界已经不是之前的样子。
新建了免费的人工智能中文站https://ai.weoknow.com
新建了收费的人工智能中文站https://ai.hzytsoft.cn/

更多资源欢迎关注


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1590252.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

RestTemplate—微服务远程调用—案例解析

简介:总结来说,微服务之间的调用方式有多种,选择哪种方式取决于具体的业务需求、技术栈和架构设计。RESTful API和HTTP客户端是常见的选择,而Feign和Ribbon等辅助库可以简化调用过程。RPC和消息队列适用于特定的场景,如…

FPGA - 以太网UDP通信(三)

一,引言 前文链接:FPGA - 以太网UDP通信(一) FPGA - 以太网UDP通信(二) 在以上文章中介绍了以太网简介,以太网UDP通信硬件结构,以及PHY芯片RGMII接口-GMII接口转换逻辑&#xff0c…

Node.js从基础到高级运用】二十三、Node.js中自动重启服务器

引言 在Node.js开发过程中,我们经常需要修改代码后重启服务器来应用这些更改。手动重启不仅效率低下,而且会打断开发流程。幸运的是,有一些工具可以帮助我们自动化这个过程。本文将介绍如何使用nodemon来实现Node.js服务器的自动重启。 什么是…

清楚明了的凸松弛最优潮流!基于混合整数二阶锥规划的主动配电网最优潮流研究程序代码!

前言 最优潮流(optimal power flow,OPF)问题,是电力系统中最常见、最基础的一类优化问题。在满足基尔霍夫定律、线路容量约束以及运行安全约束等电力网络物理约束的前提下,OPF问题旨在寻找一个最优的潮流稳态工作点,使得在该工作…

【LAMMPS学习】八、基础知识(2.5)恒压器

8. 基础知识 此部分描述了如何使用 LAMMPS 为用户和开发人员执行各种任务。术语表页面还列出了 MD 术语,以及相应 LAMMPS 手册页的链接。 LAMMPS 源代码分发的 examples 目录中包含的示例输入脚本以及示例脚本页面上突出显示的示例输入脚本还展示了如何设置和运行各…

WebSocket一篇讲清楚

文章目录 WebSocket简介WebSocket与HTTP的区别WebSocket的工作原理WebSocket的应用场景WebSocket的使用WebSocket 属性WebSocket 事件WebSocket 方法 WebSocket的心跳机制WebSocket 的安全性和跨域问题如何处理?有哪些好用的客户端WebSocket第三方库总结 WebSocket简…

代码随想录图论

1. 所有可能的路径 class Solution:def allPathsSourceTarget(self, graph: List[List[int]]) -> List[List[int]]:def dfs(graph, result, path, root): #result 返回结果, path记录路径, root记录遍历到了第几个节点if root len(graph) - 1: #如果遍历到最后…

C#Winform使用扩展方法自定义富文本框(RichTextBox)字体颜色

实现效果 调用方法 rtxtLog.AppendTextColorful(richTextBox1,DateTime.Now.ToString(), Color.Red); 完整代码如下 using System; using System.Collections.Generic; using System.ComponentModel; using System.Data; using System.Drawing; using System.Linq; using Sys…

Java 基于微信小程序的汽车4S店客户管理小程序,附源码

博主介绍:✌IT徐师兄、7年大厂程序员经历。全网粉丝15W、csdn博客专家、掘金/华为云//InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇&#x1f3…

SpringCloud框架 服务拆分和远程调用

数据库隔离避免耦合度过高,不同模块将自己的业务暴露为接口,供其他微服务调用 微服务远程调用技术Rest 在后端实现发送http请求 1.在启动类/配置类里注册RestTemplate启动对象 2.注入Bean对象使用

【力扣】17.04.消失的数字

这道题的题目意思就是从0-n中的数字中找出缺失的那一个,n是数组的长度,因此我的想法就是先将数组进行排序,往sort()里面一扔,完了以后看前一个与后一个之差中哪个不是等于1的,就求出来即可。 法…

去除pycharm运行pytest的默认参数--no-header --no-summary -q

进入pycharm设置(Settings),找到高级设置(Advanced Settings)—>Python–>Pytest:不添加"–no-header --no-summary -q"(Pytest:do not add “–no-header --no-summary -q”)

R语言计算:t分布及t检验

t分布理论基础 t分布也称Student’s t-distribution,主要出现在小样本统计推断中,特别是当样本量较小且总体标准差未知时,用于估计正态分布的均值。其定义基于正态分布和 X 2 X^{2} X2分布(卡方分布)。如果随机变量X服…

pytorch-多分类实战之手写数字识别

目录 1. 网络设计2. 代码实现2.1 网络代码2.2 train 3. 完整代码 1. 网络设计 输入是手写数字图片28x28,输出是10个分类0~9,有两个隐藏层,如下图所示: 2. 代码实现 2.1 网络代码 第一层将784降维到200,第二次使用…

Linux的学习之路:7、yum与git

摘要 本章主要是说一下yum和git的操作 目录 摘要 一、什么是yum 二、yum三板斧 1、list 2、install 3、remove 三、怎么创建仓库 四、git三板斧 1、add 2、commit 3、push 4、pull 五、思维导图 一、什么是yum YUM是Yellowdog Updater Modified的简称&#xf…

三方库移植之NAPI开发(三)通过IDE开发NAPI工程

在三方库移植之NAPI开发[1]—Hello OpenHarmony NAPI一文中,笔者开发的是一个rom包的napi工程。该工程需要编译烧录固件,C 的动态库会集成到开发板的ROM中。在本篇文章中,笔者使用三方库移植之NAPI开发[1]—Hello OpenHarmony NAPI中一样的he…

zabbix监控配置(添加主机、主机组和添加监控项等)

zabbix监控配置 文章目录 zabbix监控配置1.添加主机组2.添加主机(linux)3.添加主机(windows)4.监控项配置(通过模板添加)5.监控项配置(手动添加) 1.添加主机组 2.添加主机&#xff0…

【Github】PwGen用户友好的Web应用密码生成器

弱密码问题一直是网络安全领域的一个重大挑战。许多人为了方便记忆,倾向于使用简单、常见的密码,如“123456”、“password”或者他们的生日等,这些密码很容易被猜测或通过暴力破解方法攻破。弱密码的使用大大增加了账户被黑客入侵的风险&…

【深入解析spring cloud gateway】13 Reactive Feign的使用

问题引入 在gateway中如果使用feignClient的话,会报如下错误 java.lang.IllegalStateException: block()/blockFirst()/blockLast() are blocking, which is not supported in thread reactor-http-nio-3at reactor.core.publisher.BlockingSingleSubscriber.bloc…

C++实现一个自定义字符串类(string)

本博客将详细介绍如何在C中实现一个自定义的字符串类 string,这个类模仿了标准库中 std::string 的关键功能。这个过程将涵盖从声明到定义的每一步,重点介绍内存管理、操作符重载以及提供一些关键的实现细节。 首先:我们采用函数的声明与定义…