LoRA微调大语言模型Bert

news2024/11/24 22:30:52

LoRA是一种流行的微调大语言模型的手段,这是因为LoRA仅需在预训练模型需要微调的地方添加旁路矩阵。LoRA 的作者们还提供了一个易于使用的库 loralib,它极大地简化了使用 LoRA 微调模型的过程。这个库允许用户轻松地将 LoRA 层添加到现有的模型架构中,而无需深入了解其底层实现细节。这使得 LoRA 成为了一种非常实用的技术,既适合研究者也适合开发人员。下面给出了一个LoRA微调Bert模型的具体例子。
下图给出了一个LoRA微调Bert中自注意力矩阵 W Q W^Q WQ的例子。如图所示,通过冻结矩阵 W Q W^Q WQ,并且添加旁路低秩矩阵 A , B A,B A,B来进行微调。同理,使用LoRA来微调 W K W^K WK也是如此。
image.png
我们给出了通过LoRA来微调Bert模型中自注意力矩阵的具体代码。代码是基于huggingface中Bert开源模型进行改造。Bert开源项目链接如下:
https://huggingface.co/transformers/v4.3.3/_modules/transformers/models/bert/modeling_bert.html

基于LoRA微调的代码如下:
# 环境配置
# pip install loralib
# 或者
# pip install git+https://github.com/microsoft/LoRA
import loralib as lora

class LoraBertSelfAttention(BertSelfAttention):
    """
    继承BertSelfAttention模块
    对Query,Value用LoRA进行微调
    
    参数:
    - r (int): LoRA秩的大小
    - config: Bert模型的参数配置
    """
    def __init__(self, r=8, *config):
        super().__init__(*config)
        # 获得所有的注意力的头数
        d = self.all_head_size 
        # 使用LoRA提供的库loralib
        self.lora_query = lora.Linear(d, d, r)
        self.lora_value = lora.Linear(d, d, r)
        
    def lora_query(self, x):
        """
        对Query矩阵执行Wx + BAx操作
        """
        return self.query(x) + F.linear(x, self.lora_query)
    
    def lora_value(self, x):
        """
        对Value矩阵执行Wx + BAx操作
        """
        return self.value(x) + F.linear(x, self.lora_value)
    
    
    def forward(self, hidden_states, *config):
        """
        更新涉及到Query矩阵和Value矩阵的操作
        """
        # 通过LoRA微调Query矩阵
        mixed_query_layer = self.lora_query(hidden_states)
        is_cross_attention = encoder_hidden_states is not None
        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            # 通过LoRA微调Value矩阵
            value_layer = self.transpose_for_scores(self.lora_value(hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            # 通过LoRA微调Value矩阵
            value_layer = self.transpose_for_scores(self.lora_value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            # 通过LoRA微调Value矩阵
            value_layer = self.transpose_for_scores(self.lora_value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        if self.is_decoder:
            past_key_value = (key_layer, value_layer)
        # Query矩阵与Key矩阵算点积得到注意力分数
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        attention_probs = self.dropout(attention_probs)
        if head_mask is not None:
            attention_probs = attention_probs * head_mask
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs

class LoraBert(nn.Module):
    def __init__(self, task_type, num_classes=None, dropout_rate=0.1, model_id="bert-base-cased",
                 lora_rank=8, train_biases=True, train_embedding=False, train_layer_norms=True):
        """
        - task_type: 设计任务的类型,如:'glue', 'squad_v1', 'squad_v2'.
        - num_classes: 分类类别的数量.
        - model_id: 预训练好的Bert的ID,如:"bert-base-uncased","bert-large-uncased".
        - lora_rank: LoRA秩的大小.
        - train_biases, train_embedding, train_layer_norms: 这是参数是否需要训练    
        """
        super().__init__()
        # 1.加载权重
        self.model_id = model_id
        self.tokenizer = BertTokenizer.from_pretrained(model_id)
        self.model = BertForPreTraining.from_pretrained(model_id)
        self.model_config = self.model.config
        # 2.添加模块
        d_model = self.model_config.hidden_size
        self.finetune_head_norm = nn.LayerNorm(d_model)
        self.finetune_head_dropout = nn.Dropout(dropout_rate)
        self.finetune_head_classifier = nn.Linear(d_model, num_classes)
        # 3.通过LoRA微调模型
        self.replace_multihead_attention()
        self.freeze_parameters()
        
    def replace_self_attention(self, model):
        """
        把预训练模型中的自注意力换成自己定义的LoraBertSelfAttention
        """
        for name, module in model.named_children():
            if isinstance(module, RobertaSelfAttention):
                layer = LoraBertSelfAttention(r=self.lora_rank, config=self.model_config)
                layer.load_state_dict(module.state_dict(), strict=False)
                setattr(model, name, layer)
            else:
                self.replace_self_attention(module)
                
                
    def freeze_parameters(self):
        """
        将除了涉及LoRA微调模块的其他参数进行冻结
        LoRA微调影响到的模块: the finetune head, bias parameters, embeddings, and layer norms 
        """
        for name, param in self.model.named_parameters():
            is_trainable = (
                "lora_" in name or
                "finetune_head_" in name or
                (self.train_biases and "bias" in name) or
                (self.train_embeddings and "embeddings" in name) or
                (self.train_layer_norms and "LayerNorm" in name)
            )
            param.requires_grad = is_trainable
	peft库中包含了LoRA在内的许多大模型高效微调方法,并且与transformer库兼容。使用peft库对大模型flan-T5-xxl进行LoRA微调的代码例子如下:


# 通过LoRA微调flan-T5-xxl
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, TaskType
# 模型介绍:https://huggingface.co/google/flan-t5-xxl
model_name_or_path = "google/flan-t5-xxl"

model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, load_in_8bit=True, device_map="auto")
peft_config = LoraConfig(
 r=8,
 lora_alpha=16, 
 target_modules=["q", "v"], # 仅对Query,Value矩阵进行微调
 lora_dropout=0.1,
 bias="none", 
 task_type=TaskType.SEQ_2_SEQ_LM
)
model = get_peft_model(model, peft_config)
# 打印可训练的参数
model.print_trainable_parameters()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2036091.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

springBoot整合xxl-job开箱即用

一、搭建xxl-job任务调用中心 1. 下载地址: xxl-job: 一个分布式任务调度平台,其核心设计目标是开发迅速、学习简单、轻量级、易扩展。现已开放源代码并接入多家公司线上产品线,开箱即用。 git拉取后,本地打开,并进…

haproxy总结与实验

一、负载均衡 1.1 简述负载均衡 在高并发的业务场景下,解决单个节点压力过大,导致Web服务响应过慢,特别是严重的情况下导致服务瘫痪,无法正常提供服务的问题,而负载均衡的目的就是为了维护系统稳定可靠。负载均衡&…

汽车补光照明实验太阳光模拟器光源

汽车补光照明实验概览 汽车补光照明实验是汽车照明领域的一个重要环节,它涉及到汽车照明系统的性能测试和优化。实验的目的在于确保汽车在各种光照条件下都能提供良好的照明效果,以提高行车安全。实验内容通常包括但不限于灯光的亮度、色温、均匀性、响应…

奥运科技观察:AI PC,如何成为当代体育精神的数字捍卫者?

作者 | 曾响铃 文 | 响铃说 数字孪生帮助体育馆建设、超高清直播……这届奥运会科技感拉满,几乎所有前沿技术都能在奥运的赛事运营中发现。 而AI大时代,AI如何帮助帮助奥运会顺利举办、如何帮助运动员拥有更好的表现,同样值得业界关注&…

洛谷P3919 【模板】可持久化线段树 1(可持久化数组)

目录 tags中文题面思路代码 tags 线段树 主席树 中文题面 如题,你需要维护这样的一个长度为 N 的数组,支持如下几种操作 在某个历史版本上修改某一个位置上的值访问某个历史版本上的某一位置的值此外,每进行一次操作(对于操作…

Mybatis PLUS代码生成器generate

Mybatis PLUS代码生成器generate 一、2.3版本二、生成代码三、3.5.1版本四、生成代码 一、2.3版本 <dependency><groupId>com.baomidou</groupId><artifactId>mybatis-plus-generator</artifactId><version>2.3</version> </dep…

Java 随机生成密码包含大写字母、数字、特殊字符且指定长度

一、写在前面 现在网络环境越来越复杂&#xff0c;对密码安全要求也越来越严格&#xff0c;在生产环境种&#xff0c;对密码要求是一个不少于16位的随机密码&#xff0c;要求含有大写字母、小写字母、数字、特殊字符中的三种。我们使用java代码直接来可控的生成这种密码。 二…

数字县域+乡村振兴解决方案

1. 国家大数据战略与乡村振兴 国家大数据战略的核心内容包括加快建设数字中国&#xff0c;推动数据资源整合和开放共享&#xff0c;以大数据助力产业转型升级和社会治理创新&#xff0c;构建数字经济&#xff0c;提升国家治理现代化水平。 2. 乡村振兴战略的重大意义 乡村振…

【C++】特殊类设计 — 不能被拷贝的类 , 只能在堆/栈上创建对象的类 ,不能被继承的类

苟活者在淡红的血色中&#xff0c;会依稀看见微茫的希望&#xff1b; 真的猛士&#xff0c;将更奋然而前行。 --- 鲁迅 --- toc 1 特殊类 在实践中&#xff0c;常常会有一些比较有意思的特殊场景&#xff1a; 不能被拷贝的类 - 独一无二的魔法宝物&#xff1a; 在一个角色…

『大模型笔记』虚拟机(Virtual Machine,VM)与Docker对比!

『大模型笔记』虚拟机(Virtual Machine,VM)与Docker对比! 文章目录 一. 虚拟机(Virtual Machine,VM)与Docker对比!1. 定义这两种技术2. 工作原理3. 关于如何选择适合工作负载的技术的指导二. 参考文献Docker 只是一个轻量级的虚拟机吗?虽然二者确实有一个共同点,即 虚…

【RISC-V设计-13】- RISC-V处理器设计K0A之指令测试

【RISC-V设计-13】- RISC-V处理器设计K0A之指令测试 文章目录 【RISC-V设计-13】- RISC-V处理器设计K0A之指令测试1.简介2.验证用例3.指令代码4.链接脚本5.编译脚本6.仿真结果6.1 复位结束6.2 运行成功6.3 终端打印 7.总结 1.简介 借助上一篇文章所提及的验证环境&#xff0c;…

对象引用对于非静态的字段、方法或属性是必需的

CS0120 对象引用对于非静态的字段、方法或属性“Person.FirstName”是必需的 类Person internal class Person{// public static string FirstName { get;set; }"sss";public string FirstName { get; set; } "sss";public static string MiddleName …

k8s挂载nginx配置文件

文章目录 步骤一&#xff1a;启动指定服务的工作负载时&#xff0c;指定需要挂载的配置文件&#xff0c;替换工作负载内置的配置文件步骤二: 在配置字典中新增配置文件步骤三&#xff1a;自定义挂载的配置文件 步骤一&#xff1a;启动指定服务的工作负载时&#xff0c;指定需要…

深度学习 —— 个人学习笔记17(锚框、多尺度锚框)

声明 本文章为个人学习使用&#xff0c;版面观感若有不适请谅解&#xff0c;文中知识仅代表个人观点&#xff0c;若出现错误&#xff0c;欢迎各位批评指正。 三十四、锚框 import torch import matplotlib.pyplot as plt from matplotlib_inline import backend_inlinetorch.…

Python 绘图进阶之箱线图:揭示数据的分布和异常值

Python 绘图进阶之箱线图&#xff1a;揭示数据的分布和异常值 引言 在数据分析中&#xff0c;理解数据的分布情况和识别异常值是非常重要的任务。箱线图&#xff08;Box Plot&#xff09;作为一种简洁有效的统计图表&#xff0c;能够直观地展示数据的中位数、四分位数、极值以…

除了画图,你还需要透视平面设计师的日常工作

平面设计师是市场上较为稀缺且需求旺盛的职业&#xff0c;许多企业都在争相聘请优秀的设计师。平面设计在日常生活中无处不在&#xff0c;应用领域广泛&#xff0c;如广告设计、logo设计和名片设计等。因此&#xff0c;本篇文章将为你详细介绍平面设计。 1、什么是平面设计&am…

YOLT论文精读

引言 很早之前&#xff0c;在本校老师的带领下接触到了目标检测领域。在卫星遥感图像方面有一篇经典的论文《You Only Look Twice: Rapid Multi-Scale Object Detection In Satellite Imagery》。科研小白一开始反复看了几遍也没弄懂&#xff0c;决定写博客来加深自己的理解。…

Vue 3+Vite+Eectron从入门到实战系列之(五)一后台管理登录页

前面已经讲了不少基础知识&#xff0c;这篇开始&#xff0c;我们进行实操&#xff0c;做个后台管理系统&#xff0c;打包成多端的,可安装的桌面app!!其中&#xff0c;登录&#xff0c;退出的提示信息用系统的提示&#xff0c;不使用elemengplus的弹窗提示&#xff01;&#xff…

Java生成图形验证码

1、加依赖 <dependency><groupId>cn.hutool</groupId><artifactId>hutool-all</artifactId><version>5.8.16</version></dependency> 2、写接口&#xff0c;这块不需要登录成功才能操作的&#xff0c;所以写controller就行了…

基于Hadoop的网购笔记本电脑大数据分析与可视化系统

文章目录 有需要本项目的代码或文档以及全部资源&#xff0c;或者部署调试可以私信博主项目介绍数据采集过程数据预处理Hadoop大数据分析可视化展示每文一语 有需要本项目的代码或文档以及全部资源&#xff0c;或者部署调试可以私信博主 项目介绍 本项目首先通过爬虫获取京东…