如何训练一个大模型:LoRA篇

news2025/1/10 21:33:44

目录

写在前面

一、LoRA算法原理

1.设计思想

2.具体实现

二、peft库

三、完整的训练代码

四、总结


写在前面

        现在有很多开源的大模型,他们一般都是通用的,这就意味着这些开源大模型在特定任务上可能力不从心。为了适应我们的下游任务,就需要对预训练模型进行微调。

        全参数微调有两个问题:在新的数据集上训练,会破坏大模型原来的能力,使其泛化能力急剧下降;而且现在的模型参数动辄几十亿上百亿,要执行全参数微调的话,他贵啊!!

        于是LoRA出现了, LoRA(Low-Rank Adaptation)是微软提出的一种参数有效的微调方法,可以降低微调占用的显存以及更轻量化的迁移。同时解决了上述两个问题,那它凭什么这么厉害?往下看吧。

一、LoRA算法原理

1.设计思想

        论文地址:https://arxiv.org/pdf/2106.09685

        模型是过参数化的,它们有更小的内在维度,模型主要依赖于这个低的内在维度(low intrinsic dimension)去做任务适配。假设模型在适配任务时参数的改变量是低秩的,由此引出低秩自适应方法lora,通过低秩分解来模拟参数的改变量,从而以极小的参数量来实现大模型的间接训练。

       上面那段话也许有点难以理解。简单来讲,LoRA是大模型的低秩适配器,或者就简单的理解为适配器,在图像生成中可以将lora理解为某种图像风格(比如SD社区中的各种漂亮妹子的lora,可插拔式应用,甚至组合式应用实现风格的融合)的适配器,在NLP中可以将其理解为某个任务的适配器(比如基于通用大模型训练的各个领域的专家大模型)。

2.具体实现

        LoRA的实现方式是在基础模型的线性变换模块(全连接、Embedding、卷积)旁边增加一个旁路,这个旁路是由两个小矩阵做内积得来的,两个小矩阵的中间维度,就是秩!!

        通过低秩分解(先降维再升维)来模拟参数的更新量。

        下面是LoRA的公式:

h = W_0x +\Delta Wx = W_0x + ((A \bigotimes B) * \alpha / r)x

       上面公式中x是这一层的输入,h是这一层的输出,W_0是基础模型的权重参数;A和B是两个小矩阵,A的输入和B的输出形状跟W_0一样,A的输出和B的输入一样,称为秩,秩一般很小,微调的所有“新知识”都保存在A和B里面\alpha /r是一个缩放系数,这个数越大,LoRA权重的影响就越大。

        下面就是经典的LoRA运算流程图:

        我们以ChatGLM的attention模块的query_key_value(是一个linear(4096, 12288))为例,描述一下流程,其中输入4096、输出12288,LoRA的秩是8:

        初始化时,lora_A采用高斯分布初始化,lora_B初始化为全0,保证训练开始时旁路为0矩阵;        

        训练时,原模型固定,只训练降维矩阵A和升维矩阵B;

        推理时需要做参数合并,就是将AB的内积(一个与基础模型形状一样的低秩矩阵)加到原参数上,这样不引入额外的推理延迟。对于上图的例子,lora_A与lora_B做内积,得到4096x1228的参数矩阵,然后与基础模型W相加就可以了。

        我们来算算需要训练多少参数,如果是全参数需要训练4096*12288=50331648个参数,LoRA需要训练4096*8+8*12288=131072,参数可是数量级的减少啊。

二、peft库

        Pytorch中peft库实现了LoRA算法,而且使用非常方便,我们以ChatGLM代码为例,看一下LoRA对ChatGLM模型做了什么,直接上代码:

from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModel, HfArgumentParser, TrainingArguments

from finetune import CastOutputToFloat, FinetuneArguments


def count_params(model):
    for name, param in model.named_parameters():
        print(name, param.shape)



def make_peft_model():
    # 初始化原模型
    model = AutoModel.from_pretrained(
        "THUDM/chatglm-6b", load_in_8bit=False, trust_remote_code=True, device_map="auto", local_files_only=True
    ).float()
    

    # 给原模型施加LoRA
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=True,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
        target_modules=['query_key_value'],
    )
    model = get_peft_model(model, peft_config).float()
    count_params(model)



if __name__ == '__main__':
    make_peft_model()

        输出如下:       

base_model.model.transformer.word_embeddings.weight torch.Size([130528, 4096])
base_model.model.transformer.layers.0.input_layernorm.weight torch.Size([4096])
base_model.model.transformer.layers.0.input_layernorm.bias torch.Size([4096])
base_model.model.transformer.layers.0.attention.query_key_value.base_layer.weight torch.Size([12288, 4096])
base_model.model.transformer.layers.0.attention.query_key_value.base_layer.bias torch.Size([12288])
base_model.model.transformer.layers.0.attention.query_key_value.lora_A.default.weight torch.Size([8, 4096])
base_model.model.transformer.layers.0.attention.query_key_value.lora_B.default.weight torch.Size([12288, 8])

base_model.model.transformer.layers.0.attention.dense.weight torch.Size([4096, 4096])
base_model.model.transformer.layers.0.attention.dense.bias torch.Size([4096])
base_model.model.transformer.layers.0.post_attention_layernorm.weight torch.Size([4096])
base_model.model.transformer.layers.0.post_attention_layernorm.bias torch.Size([4096])
base_model.model.transformer.layers.0.mlp.dense_h_to_4h.weight torch.Size([16384, 4096])
base_model.model.transformer.layers.0.mlp.dense_h_to_4h.bias torch.Size([16384])
base_model.model.transformer.layers.0.mlp.dense_4h_to_h.weight torch.Size([4096, 16384])
base_model.model.transformer.layers.0.mlp.dense_4h_to_h.bias torch.Size([4096])
base_model.model.transformer.layers.1.input_layernorm.weight torch.Size([4096])
base_model.model.transformer.layers.1.input_layernorm.bias torch.Size([4096])

......

        可以看到模型中被添加了LoRA模块(红色部分),是根据全连接“query_key_value”生成的。因为query_key_value层输入是4096,输出是12288,而配置中LoRA的秩是8,所以两个LoRA块是(8,4096)和(12288, 8)

        代码也很好理解,get_peft_model方法将原模型参数冻结并且根据配置向模型中添加LoRA模块。

        解释一下配置LoraConfig,下面是这个对象的主要参数:

 1.task_type:

        SEQ_CLS:序列分类(Sequence Classification)任务。这种任务涉及对输入序列整体进行分类,例如情感分析、文本分类等。

        SEQ_2_SEQ_LM:序列到序列语言建模(Sequence-to-Sequence Language Modeling)任务。这种任务能够将一个输入序列映射到另一个输出序列,例如机器翻译、文本摘要等。

        CAUSAL_LM:因果语言建模(Causal Language Modeling)任务。这种任务涉及训练一个模型,使其能够预测给定先前上下文的下一个标记,例如自动补全、语言生成等。

        TOKEN_CLS:标记分类(Token Classification)任务。这种任务涉及对输入序列中的每个标记进行分类,例如命名实体识别、词性标注等。

        QUESTION_ANS:问答(Question Answering)任务。这种任务涉及根据给定的问题和相关的上下文文本来预测答案。输入是Prompt+问题。

        FEATURE_EXTRACTION:特征提取(Feature Extraction)任务。这种任务涉及从文本或序列中提取有用的特征,以供其他任务或模型使用。

2.r:LoRA秩的维度,这数越大,微调带来的“影响”越强,但是需要训练的参数量会增加。

3.lora_alpha:LoRA在前向传播的过程中引入一个额外的扩展系数(scaling coefficient),用于将LoRA权重应用于预训练权重。这个数越大,LoRA权重的影响就越大。

4.target_modules:要施加LoRA的模块名称,需要注意的是,参数是字符串数组,模块类型必须是`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `transformers.pytorch_utils.Conv1D`中的一个。比如这个例子中还可以填写"word_embeddings"和"dense"。

三、完整的训练代码

        现在给出一个完整的基于LoRA的ChatGLM训练代码,peft库在原模型基础上添加LoRA非常方便,对代码的侵入也很小。下面的代码我添加了注释,流程还是很清楚的:

from transformers.integrations import TensorBoardCallback
from torch.utils.tensorboard import SummaryWriter
from transformers import TrainingArguments
from transformers import Trainer, HfArgumentParser
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn as nn
from peft import get_peft_model, LoraConfig, TaskType
from dataclasses import dataclass, field
import datasets
import os


tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)


@dataclass
class FinetuneArguments:
    dataset_path: str = field(default="data/alpaca")
    model_path: str = field(default="output")
    lora_rank: int = field(default=8)


class CastOutputToFloat(nn.Sequential):
    def forward(self, x):
        return super().forward(x).to(torch.float32)


def data_collator(features: list) -> dict:
    len_ids = [len(feature["input_ids"]) for feature in features]
    longest = max(len_ids)
    input_ids = []
    labels_list = []
    for ids_l, feature in sorted(zip(len_ids, features), key=lambda x: -x[0]):
        ids = feature["input_ids"]
        seq_len = feature["seq_len"]
        labels = (
            [-100] * (seq_len - 1) + ids[(seq_len - 1) :] + [-100] * (longest - ids_l)
        )
        ids = ids + [tokenizer.pad_token_id] * (longest - ids_l)
        _ids = torch.LongTensor(ids)
        labels_list.append(torch.LongTensor(labels))
        input_ids.append(_ids)
    input_ids = torch.stack(input_ids)
    labels = torch.stack(labels_list)
    return {
        "input_ids": input_ids,
        "labels": labels,
    }



class ModifiedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        return model(
            input_ids=inputs["input_ids"],
            labels=inputs["labels"],
        ).loss

    def save_model(self, output_dir=None, _internal_call=False):
        self.model.save_pretrained(output_dir)


def main():
    writer = SummaryWriter()
    # 组织训练参数
    finetune_args, training_args = HfArgumentParser(
        (FinetuneArguments, TrainingArguments)
    ).parse_args_into_dataclasses()

    # init model
    model = AutoModel.from_pretrained(
        "THUDM/chatglm-6b", load_in_8bit=False, trust_remote_code=True, device_map="auto", local_files_only=True
    ).float()
    model.gradient_checkpointing_enable()
    model.enable_input_require_grads()
    # 模型是可以并行化的。
    model.is_parallelizable = True
    # 启用模型的并行化。
    model.model_parallel = True
    # 将模型的 lm_head(语言模型头)的输出转换为浮点数类型。
    model.lm_head = CastOutputToFloat(model.lm_head)
    # 禁用模型配置中的缓存,用于禁止缓存中间结果,可以减少显存占用,但是训练时间会变长
    model.config.use_cache = (
        False  # silence the warnings. Please re-enable for inference!
    )

    # LoRA配置
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=finetune_args.lora_rank,
        lora_alpha=32,
        lora_dropout=0.1,
    )
    # 对模型使用LoRA
    model = get_peft_model(model, peft_config).float()

    # 使用alpaca数据集
    dataset = datasets.load_from_disk(finetune_args.dataset_path)
    print(f"\n{len(dataset)=}\n")

    # for d in dataset.iter(batch_size=1):
    #     print("d:", d)

    # start train
    trainer = ModifiedTrainer(
        model=model,
        train_dataset=dataset,
        args=training_args,
        callbacks=[TensorBoardCallback(writer)],
        data_collator=data_collator,
    )
    trainer.train()
    writer.close()
    # 存训练后的参数
    model.save_pretrained(training_args.output_dir)


if __name__ == "__main__":
    main()

        训练之后模型文件会保存在output_dir目录中。到这里我们发现一个问题,毕竟LoRA在原模型的基础上加了分支,这会带来推理效率的降低,其实我们调用merge_and_unload方法就能将LoRA的分支模块合并到基础模型,推理代码如下:

from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModel, AutoModelForSeq2SeqLM
import torch
from transformers import AutoTokenizer

# 加载基础模型
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)

# 配置LoRA
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, inference_mode=True,
    target_modules=['query_key_value'],
    r=8, lora_alpha=32, lora_dropout=0.1
)
# 对模型使用LoRA
model = get_peft_model(model, peft_config).half()
# 加载LoRA参数
model.load_state_dict(torch.load("output/checkpoint-1000/adapter_model.bin", map_location=torch.device("cuda")), strict=False)
# 将LoRA的分支模块合并到基础模型
model.merge_and_unload()

while True:
    prompt = input("Prompt: ")
    inputs = tokenizer(prompt, return_tensors="pt")
    model.params_dtype = torch.float32
    response = model.generate(input_ids=inputs["input_ids"],
                              max_length=inputs["input_ids"].shape[-1] + 128)
    response = response[0, inputs["input_ids"].shape[-1]:]
    print("responseL", response)
    for r in response:
        print(r, ":", tokenizer.decode([r], skip_special_tokens=False))
    print("Response:", tokenizer.decode(response, skip_special_tokens=True))

四、总结

1.LoRA的实现方式是在原模型的线性变换模块(全连接、Embedding、卷积)旁边增加一个旁路,通过低秩分解(先降维再升维)来模拟参数的更新量。

2.LoRA模块由两个小矩阵组成,这两个矩阵内积的输入输出形状与原模型一致,大模型需要的“新知识”就存在这个模块中;

3.秩可以很小,有实验表明,就算秩=1,效果也不是很差;

4.尽量多的对模型中的线性变换模块使用秩很小LoRA;而不是对一个模块使用秩很大的LoRA;

5.推理时需要做参数合并,就是将AB的内积加到原参数上,从而不引入额外的推理延迟;

5.LoRA智能一定程度提升模型在某个领域的能力,并不能使模型发生根本性的能力提升。

LoRA就介绍到这里,关注不迷路(#^.^#)

关注订阅号了解更多精品文章

交流探讨、商务合作请加微信

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1669623.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

高效快速 推荐这款服务器同步软件

服务器数据同步是为了确保在不同的服务器或数据中心之间能够保持数据的一致性和可用性,选择一款合适的服务器同步软件,可确保数据完整性、提高服务质量和满足业务需求的重要手段。 服务器数据同步的痛点主要包括: 1、数据一致性:…

SQL-递归查询

运行环境: Mysql8以上,递归查询功能在8以上版本被正式引入 一、SQL递归查询的概念 递归指的是通过调用函数或过程或自身来解决问题的方法,常用于一些具有规律性循环的操作。SQL递归查询是基于一组初始数据,通过递归查询&#xf…

Redis继续(黑马)

Redis持久化 RDB与AOF RDB记录是二进制数据,Redis停机时会触发保存,名称: dump.rdb 缺点:间歇式复制可能存在宕机数据更新丢失 AOF 记录的写操作命令,每秒记录一下,也存在数据更新丢失的可能,相…

【class6】人工智能初步(选择一个合适的监督学习算法。)

【昨日内容复习】 进行监督学习时,第一个步骤是提取数据集的文本特征和对应的标签。 提取文本特征的具体步骤如下: STEP1. 构造词袋模型,提取数据集中的文本特征 STEP2. 使用toarray()函数,将X转换为一个NumPy数组,方…

【5月13日】YesPMP众包平台最新项目

YesPMP众包平台5月13日最新项目,有感兴趣的用户查看项目接单,甲乙方无障碍沟通。 1.查看项目:分析一款PC端登录协议及收发消息 2.查看项目:《中华历史漫画》 3.查看项目:图像算法 …

什么是CCRC?做什么用的?

CCRC(中国网络安全审查认证和市场监管大数据中心)原名为中国网络安全审查技术与认证中心,也被称为中国信息安全认证中心(ISCCC)。 该中心是经中央机构编制委员会办公室批准成立的,其主要职责是依据国家法律…

设计模式 六大原则之开放封闭原则

文章目录 定义理解 小结 定义 开闭原则规定软件中的对象、类、模块和函数对扩展应该是开放的,但对于修改是封闭的。这意味着应该用抽象定义结构,用具体实现扩展细节,以此确保软件系统开发和维护过程的可靠性。 理解 怎么理解这个呢&#x…

【IMX6ULL项目】IMX6ULL上Linux系统实现产测工具框架

电子产品量产测试与烧写工具。这是一套软件,用在我们的实际生产中, 有如下特点: 1.简单易用: 把这套软件烧写在 SD 卡上,插到 IMX6ULL 板子里并启动,它就会自动测试各个模块、烧写 EMMC 系统。 工人只要按…

【C语言】深度解析:动态内存管理的机制与实践

🔥引言 本篇将深度解析:动态内存管理的机制。为了更加灵活分配内存中的空间,库中为了我们提供了一些的函数,去动态开辟和释放堆上的空间。 🌈个人主页:是店小二呀 🌈C语言笔记专栏:C语言笔记 &a…

询问贴:这要怎么设置捏,寻思着总不该一个一个挖空吧????

这要怎么设置捏,寻思着总不该一个一个挖空吧????

Hadoop3.4.0 完全分布式集群 运行环境搭建 VMware Workstation 虚拟机 大数据系列 一

一 生产环境集群模式部署,需要多台主机,主机之间通过密钥相互访问. 1 配置如图 节点名字节点IP系统版本master11192.168.50.11centos 8.5slave12192.168.50.12centos 8.5slave13192.168.50.13centos 8.5 2 安装服务器 #先安装一台master11&#xff…

google test 使用指南

目录 测试项目 calculator.h calculator.cpp test01.cpp 创建新项目 选择Google Test 选择要测试的项目 pch.cpp 加入依赖 设为启动项目 ​编辑 运行 ​编辑 关键点 测试项目 calculator.h #ifndef __CALCULATOR_H__ #define __CALCULATOR_H__#include <i…

创vite项目时报错【文件名、目录名或卷标语法不正确】

错误提示 错误原因 yarn的安装包默认是在C盘的而我电脑上yarn安装在D盘&#xff0c;所以就会报这样的错误。 可以使用如下命令查看当前yarn的安装包位置 yarn global dir 解决办法 1、将yarn的全局路径改到D盘就可以了&#xff0c;在D盘创建yarn文件夹&#xff0c;然后再其…

8.微服务项目结合SpringSecurity项目结构

项目结构 acl_parent:创建父工程用来管理依赖版本 common service_base&#xff1a;工具类 spring_security: Spring Security相关配置 infrastructure api_gateway: 网关 service service_acl: 实现权限管理功能代码 acl_parent的pom.xml <?xml version"1.0" …

2万字干货:如何从0到1搭建一套会员体系(4)

开始本节前还是一样来个灵魂发问&#xff1a;为什么产品需要用户标签&#xff0c;或者用户标签有什么意义/价值&#xff1f; 某些业务场景下使用会员等级无法满足业务需要。比如新用户激活、老用户福利以及沉默客户唤醒等等。 用户等级划分的逻辑和维度有些局限性&#xff0c;…

小区物业管理系统

文章目录 小区物业管理系统一、项目演示二、项目介绍三、部分功能截图四、部分代码展示五、底部获取项目源码&#xff08;9.9&#xffe5;带走&#xff09; 小区物业管理系统 一、项目演示 小区物业管理系统 二、项目介绍 基于springbootvue的前后端分离物业管理系统 系统角…

遨游 JavaScript 对象星际:探索面向对象编程的深邃世界

个人主页&#xff1a;学习前端的小z 个人专栏&#xff1a;JavaScript 精粹 本专栏旨在分享记录每日学习的前端知识和学习笔记的归纳总结&#xff0c;欢迎大家在评论区交流讨论&#xff01; 文章目录 &#x1f4af;面向对象编程&#x1f517;1 什么是对象&#x1f517;2 什么是…

互联网引流艺术:精准获客的黄金法则

在如今这个信息爆炸的时代&#xff0c;互联网引流不再是简单地发布广告和等待潜在客户的到来。它变成了一门需要策略、技巧和持续创新的艺术。作为一位资深的互联网营销从业者&#xff0c;我深知精准推广的重要性&#xff0c;以及它在帮助企业获得理想客户中的关键作用。以下是…

(七)SQL基础知识练习题(选择题)(上)#CDA学习打卡

本文整理了SQL基础知识相关的练习题&#xff0c;共133道&#xff0c;可作为CDA一级的补充习题&#xff0c;也适用于刚入门初级SQL想巩固基础的同学。来源&#xff1a;如荷学数据科学题库&#xff08;技术专项-SQL&#xff09;。暂时按照原题库顺序present&#xff0c;如有需要之…

JINGWHALE 数字认证体系 · 进阶知识库

JINGWHALE 数字认证体系 是 JINGWHALE 数字科学艺术创新中心 的数字认证服务。 ◢◤ 宗旨 致力于数字化知行合一的知识赋能&#xff01; ◥ 数字化人才培养 培养数字化思维&#xff0c;传播数字化知识&#xff0c;赋能各行业数字化。 ◥ 职业人才发展 无缝衔接学校高等…