ChatGLM-LLaMA-chinese-insturct 学习记录(含LoRA的源码理解)

news2025/2/22 12:02:56



前言

介绍:探索中文instruct数据在ChatGLM, LLaMA等LLM上微调表现,结合PEFT等方法降低资源需求。
Github: https://github.com/27182812/ChatGLM-LLaMA-chinese-insturct
补充学习:https://kexue.fm/archives/9138


一、实验记录

1.1 环境配置

优雅下载hugging face模型和数据集

conda update -n base -c defaults conda
curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash
apt-get install git-lfs
git lfs install

git clone [模型|数据集|地址]

配置conda 环境

conda env create -f env.yml -n bab
conda activate bab
pip install git+https://github.com/huggingface/peft.git

数据集
belle数据集 和 自己收集的中文指令数据集
指令数据集

{
	"context": Instruction:[举一个使用以下对象的隐喻示例]\nInput:[星星]\nAnswer:, 
	"target": Answer:星星是夜空中闪烁的钻石。
}

1.2 代码理解

函数:gradient_checkpointing_enable
如何理解 gradient_checkpoint, 时间换空间,使得模型显存占用变小,但训练时长增加
PEFT的相关介绍
大模型训练——PEFT与LORA介绍
你也可以动手参数有效微调:LoRA、Prefix Tuning、P-Tuning、Prompt Tuning

1.2.1 LoRA

在这里插入图片描述
对于左右两个部分,右侧看起来像是左侧原有矩阵W WW的分解,将参数量从d ∗ d 变成了d ∗ r + d ∗ r ,在
r < < d的情况下,参数量就大大地降低了。LORA保留了原来的矩阵W,但是不让W参与训练,所以需要计算梯度的部分就只剩下旁支的A和B两个小矩阵。
蓝色部分的目标函数为:
在这里插入图片描述
加入LoRA之后:
在这里插入图片描述
但是相应地,引入LORA部分的参数,并不会在推理阶段加速(不是单纯的橙色部分进行计算),因为在前向计算的时候, 蓝色部分还是需要参与计算的,而Θ部分是凭空增加了的参数,所以理论上,推理阶段应该比原来的计算量增大一点。
在这里插入图片描述
技术细节:
在这里插入图片描述
α 可以理解我们在调整lr, α/r 实在缩放蓝色部分的输出,有助于减少训练的超参数

相关参数:
在这里插入图片描述
那么如何使用PEFT的LoRA

from peft import get_peft_model, LoraConfig, TaskType

peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,   
        r=finetune_args.lora_rank,
        lora_alpha=32,
        lora_dropout=0.1,
    )
    
model = get_peft_model(model, peft_config)

其中TaskType可以设置多种任务

class TaskType(str, enum.Enum):
    SEQ_CLS = "SEQ_CLS"   常规分类任务
    SEQ_2_SEQ_LM = "SEQ_2_SEQ_LM" seq2seq任务
    CAUSAL_LM = "CAUSAL_LM"  LM任务
    TOKEN_CLS = "TOKEN_CLS"  token的分类任务:序列标注之类的

参数解释:

inference_mode = Whether to use the Peft model in inference mode.

根据苏神的介绍,LST的效果应该是优于LoRA的:
在这里插入图片描述
每层当中都有分支,可以理解为LoRA是LST的超简化版本

    def __init__(self, model, config, adapter_name):
        super().__init__()
		...
        self.add_adapter(adapter_name, self.peft_config[adapter_name])

    def add_adapter(self, adapter_name, config=None):
		...
        self._find_and_replace(adapter_name)
		...
        mark_only_lora_as_trainable(self.model, self.peft_config[adapter_name].bias)
        if self.peft_config[adapter_name].inference_mode:
            _freeze_adapter(self.model, adapter_name)

核心类在

 def _find_and_replace(self, adapter_name):
        ...
        # 遍历整个需要训练的模型的名字,这个模型你可以理解为一个字典,拿出所有的key
        key_list = [key for key, _ in self.model.named_modules()]
        for key in key_list:
        	# 找到所有qkv的key
            if isinstance(lora_config.target_modules, str):
                target_module_found = re.fullmatch(lora_config.target_modules, key)
            else:
                target_module_found = any(key.endswith(target_key) for target_key in lora_config.target_modules)
                        ...
                        # 然后对于每一个找到的目标层,创建一个新的lora层
                        # 注意这里的Linear是在该py中新建的类,不是torch的Linear
                        new_module = Linear(adapter_name, in_features, out_features, bias=bias, **kwargs)
                    self._replace_module(parent, target_name, new_module, target)

replace_modul把原来的weight和bias赋给新创建的module,然后再分配到指定的设备上

    def _replace_module(self, parent_module, child_name, new_module, old_module):
        setattr(parent_module, child_name, new_module)
        new_module.weight = old_module.weight
        if old_module.bias is not None:
            new_module.bias = old_module.bias
        if getattr(old_module, "state", None) is not None:
            new_module.state = old_module.state
            new_module.to(old_module.weight.device)

        # dispatch to correct device
        for name, module in new_module.named_modules():
            if "lora_" in name:
                module.to(old_module.weight.device)

merge\ forward部分

    def merge(self):
        if self.active_adapter not in self.lora_A.keys():
            return
        if self.merged:
            warnings.warn("Already merged. Nothing to do.")
            return
        if self.r[self.active_adapter] > 0:
            self.weight.data += (
                transpose(
                    self.lora_B[self.active_adapter].weight @ self.lora_A[self.active_adapter].weight,
                    self.fan_in_fan_out,
                )
                * self.scaling[self.active_adapter]
            )
            self.merged = True
	
	def forward(self, x: torch.Tensor):
        previous_dtype = x.dtype

        if self.active_adapter not in self.lora_A.keys():
            return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
        if self.disable_adapters:
            if self.r[self.active_adapter] > 0 and self.merged:
                self.unmerge()
            result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
        elif self.r[self.active_adapter] > 0 and not self.merged:
            result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

            x = x.to(self.lora_A[self.active_adapter].weight.dtype)

            result += (
                self.lora_B[self.active_adapter](
                    self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x))
                )
                * self.scaling[self.active_adapter]
            )
        else:
            result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

        result = result.to(previous_dtype)

        return result

评估的过程中,需要将lora部分的weight加到linear层原本的weight中,not self.merged是状态的记录,也就是说,如果设置了需要融合,而当前状态没有融合的话,就把lora部分的参数scale之后加上去,并且更新self.merged状态;

训练的过程中,确保linear本身的weights是没有经过融合过的

1.4 实验结果

chatglm-6b loss的下降不是特别多,3epoch效果也不是特别的明显,最近看到很多人反馈,不管是基于lora还是ptuning对原本的模型效果还是影响很大


二、总结

如果要基于大语言模型的FT,至少需要足够的显存,和语料,最好是将新的语料和原本的语料一起进行SFT

  • sft的原理还没有弄明白
  • 显存还需要扩充,使用deepspeed框架进行full FT,有资源谁还回去lora,ptuning呢?
  • 多轮的数据集还没有
  • 这个仓库的数据集还是,单轮的指令数据集,并没有涉及到多轮
  • 即使是官方的仓库也只是构造了多轮的训练脚本,数据集并没有提供
  • llama不跑了,只是换了一个模型而已

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/501894.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Win10任务栏透明,3个超好用解决方法!

案例&#xff1a;win10任务栏透明怎么办&#xff1f; 【我的电脑不知道为什么任务栏突然就变透明了&#xff0c;现在不知道该如何解决&#xff0c;遇到这种情况应该怎么办呀&#xff1f;】 Win10任务栏是Windows 10操作系统的一部分&#xff0c;通常默认为不透明。然而&#…

asp.net+sqlserver企业公司进销存管理系统

基于WEB的进销存管理系统主要企业内部提供服务&#xff0c;系统分为管理员&#xff0c;和员工2部分。 在本基于WEB的进销存管理系统中分为管理员&#xff0c;和普通用户2中模式&#xff0c;其中管理人员主要是对企业内商品类型。商品信息商品的出入库信息&#xff0c;以及员工…

堆栈溢出一般是什么原因?

堆栈是一个在计算机科学中经常使用的抽象数据类型。堆栈中的物体具有一个特性&#xff1a; 最后一个放入堆栈中的物体总是被最先拿出来&#xff0c; 这个特性通常称为后进先出(LIFO)队列。 堆栈中定义了一些操作。 两个最重要的是PUSH和POP。 PUSH操作在堆栈的顶部加入一 个元素…

MySQL深度分页

1. 什么是深度分页 深度分页问题的本质是在 MySQL 数据库中&#xff0c;通过 LIMIT 和 OFFSET 关键字进行分页时&#xff0c;MySQL 需要在每次查询时扫描整张表&#xff0c;直到找到当前页的数据。这种查询方式需要进行大量的磁盘 I/O 和内存操作&#xff0c;导致查询效率非常…

Microsoft Edge新功能测评体验

Microsoft Edge使用体验 Microsoft Edge是一款现代化的浏览器&#xff0c;它拥有众多功能和强大的性能&#xff0c;为用户带来更加流畅的浏览体验。 Edge最近推出了分屏功能&#xff0c;支持一个窗口同时显示两个选项卡&#xff0c;这可以大大提高生产力和多任务处理能力。 一…

什么样的蓝牙耳机佩戴舒适?蓝牙耳机佩戴舒适度排名

越来越多的人开始使用运动蓝牙耳机了&#xff0c;不仅仅是因为蓝牙耳机的它无耳机线的束缚&#xff0c;日常还很便携&#xff0c;市面上的蓝牙耳机质量参差不齐&#xff0c;有些佩戴舒适度也比较差&#xff0c;下面整理了几款评分还不错的几款蓝牙耳机。 一、南卡小音舱Lite2蓝…

第四十四章 Unity 滑动条 (Slider) UI

本章节我们介绍滑动条 (Slider)&#xff0c;它允许用户通过拖动鼠标从预定范围中选择数值。首先&#xff0c;我们点击菜单栏“GameObject”->“UI”->“Slider”&#xff0c;调整其位置&#xff0c;最终效果如下 我们发现滑动条 (Slider)下面有三个子游戏对象Background&…

如何使DocuWare成为所有部门的数据中心

如何使DocuWare成为所有部门的数据中心 自动化流程通常需要多个部门的数据&#xff0c;而各个部门通常使用不同的软件。 DocuWare可帮助您集中管理所有信息&#xff0c;并将信息应用于您的进程和工作流程当中。 您的公司使用不同的系统&#xff0c;但您又想将这些数据整合在一…

手敲Mybatis(十)-完善ORM框架支持增删改查

我们把基本的功能都完成了&#xff0c;解析xml、构建映射代理、执行sql&#xff0c;解析处理结果&#xff0c;目前这些只支持查询&#xff0c;我们还差添加下增删改的功能&#xff0c;本章节就来完善下增删改&#xff0c;其实本章节比较简单&#xff0c;因为之前的每个章节都已…

这一篇LiveData掉不掉价(使用->原理分析->粘性事件解决)

1. 简介 LiveData 是一种可观察的数据存储器类。与常规的可观察类不同&#xff0c;LiveData 具有生命周期感知能力&#xff0c;意指它遵循其他应用组件&#xff08;如 activity、fragment 或 service&#xff09;的生命周期。这种感知能力可确保 LiveData 仅更新处于活跃生命周…

数据备份系列:Rsync 备份详解(二)

一、Rsync Cron 场景使用 在对数据备份要求实时性不高的情况下&#xff0c;可优先考虑该场景&#xff0c;选择一个合适的时间&#xff0c;对数据进行定时远程增量同步。 在《数据备份系列&#xff1a;Rsync 备份详解&#xff08;一&#xff09;》中我们已经对服务搭建以及远程…

【虚幻引擎】UE5数据表格导入

数据表 顾名思义&#xff0c;DataTable是一种表格&#xff0c;里面装着大量游戏相关的数据&#xff0c;这些数据会按照其含义和用途分类&#xff0c; 其中&#xff0c;数据字段可以是UObject的任意有效属性&#xff08;包括资产的引用信息&#xff09;。设计师若要将 CSV文件导…

c++类的静态变量、静态函数 笔记

正文&#xff1a; 1、看下面这个是一个常规的类 #include <iostream> #include <windows.h> using namespace std; class BOX{int callsNum1;public:BOX(){callsNum;};int fun(){return callsNum;}; }; // int BOX::callsNum1;// 程序的主函数 int main() {SetCo…

【某区护网】从外网打点到拿下域控

目录 web打点 反弹shell与权限维持 主机信息收集与反向代理 攻击域控 前端时间刚结束了攻防演练活动&#xff0c;其中一项成果为拿下某集团域控制器权限&#xff0c;直接控制域内主机5000多台。以下为攻击过程的粗略记录&#xff0c;整体来说还是比较容易。 web打点 接到…

N1Book-第一章Web入门-任意文件读取漏洞-afr_2

本题为Nu1L团队编著的《从0到1&#xff1a;CTFer成长之路》配套题目。来源网站&#xff1a;https://book.nu1l.com/ 经过多方查阅资料&#xff0c;发现题目是&#xff0c;由于Nginx配置不当产生了目录穿越漏洞。本题使用的是OpenResty&#xff0c;而OpenResty是基于Nginx与Lua实…

门诊自助打印机可以办理哪些业务呢?

自助打印机可以办理以下业务&#xff1a; 检验报告单打印&#xff1a;患者可以通过医院验单自助打印机自主打印检验报告单&#xff0c;避免了等待时间&#xff0c;提高了医院的服务效率&#xff1b;检验报告查询&#xff1a;患者可以通过医院验单自助打印机查询自己的检验报告…

HHDBCS便捷功能简介

1. 连接管理 使用数据库时&#xff0c;不可避免的要建立很多个连接。 如果单纯用命令执行切换用户的话&#xff0c;实在是一件麻烦事。 那么这种麻烦事就交给HHDECS好了。 点击连接管理&#xff0c;一键切换。 而且能在不同数据库之间随意切换 2. 使用高级模式&#xff…

Linux环境安装iperf3(网络性能测试工具)

[rootlocalhost ]# yum search iperf 已加载插件&#xff1a;fastestmirror Loading mirror speeds from cached hostfile* base: mirrors.tuna.tsinghua.edu.cn* extras: mirrors.huaweicloud.com* updates: mirrors.tuna.tsinghua.edu.cnN/S matched: iperf iperf3-devel.i6…

数据分析示例-python

数据分析示例-python 今天呢&#xff0c;博主把之前做过的一个小课题拿出来展示一下&#xff0c;当然这个课题呢做的工作量很大&#xff0c;也用到了很多可以参考的技术和代码&#xff0c;做数据分析工作的可以尝试学习学习。 这篇博客&#xff0c;我们先从数据集开始介绍。 对…

GSAP - 一款基于 JavaScript 的 web 动画库,简单几行代码就能写出丝滑流畅、高性能的动画效果

使用简单&#xff0c;但做出来的动画非常丝滑&#xff0c;也能实现很多专业的动画效果&#xff0c;推荐给大家。 关于 GSAP GSAP 的全名是 GreenSock Animation Platform&#xff0c;项目诞生非常早&#xff0c;远在 flash 繁荣的时代就存在&#xff0c;一直发展到今天已经是…