以bert为例,了解Lora是如何添加到模型中的

news2024/10/5 18:31:02

以bert为例,了解Lora是如何添加到模型中的

  • 一.效果图
    • 1.torch.fx可视化
      • A.添加前
      • B.添加后
    • 2.onnx可视化
      • A.添加前
      • B.添加后
    • 3.tensorboard可视化
      • A.添加前
      • B.添加后
  • 二.复现步骤
    • 1.生成配置文件(num_hidden_layers=1)
    • 2.运行测试脚本

本文以bert为例,对比了添加Lora模块前后的网络结构图
说明:

  • 1.为了加快速度,将bert修改为一层
  • 2.lora只加到intermediate.dense,方便对比
  • 3.使用了几种不同的可视化方式(onnx可视化,torchviz图,torch.fx可视化,tensorboard可视化)

可参考的点:

  • 1.peft使用
  • 2.几种不同的pytorch模型可视化方法

一.效果图

1.torch.fx可视化

A.添加前

在这里插入图片描述

B.添加后

在这里插入图片描述

2.onnx可视化

A.添加前

在这里插入图片描述

B.添加后

在这里插入图片描述

3.tensorboard可视化

A.添加前

在这里插入图片描述

B.添加后

在这里插入图片描述

二.复现步骤

1.生成配置文件(num_hidden_layers=1)

tee ./config.json <<-'EOF'
{
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 1,
  "pad_token_id": 0,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "type_vocab_size": 2,
  "vocab_size": 21128
}
EOF

2.运行测试脚本

tee bert_lora.py <<-'EOF'
import time
import os
import torch
import torchvision.models as models
import torch.nn as nn
import torch.nn.init as init
import time
import numpy as np
from peft import get_peft_config, get_peft_model, get_peft_model_state_dict, LoraConfig, TaskType
from torchviz import make_dot
from torch.utils.tensorboard import SummaryWriter
from torch._functorch.partitioners import draw_graph

def onnx_infer_shape(onnx_path):
    import onnx
    onnx_model  = onnx.load_model(onnx_path)
    new_onnx= onnx.shape_inference.infer_shapes(onnx_model)
    onnx.save_model(new_onnx, onnx_path)

def get_model():
    torch.manual_seed(1)
    from transformers import AutoModelForMaskedLM,BertConfig
    config=BertConfig.from_pretrained("./config.json")
    model = AutoModelForMaskedLM.from_config(config)
    return model,config

def my_compiler(fx_module: torch.fx.GraphModule, _):
    draw_graph(fx_module, f"bert.{time.time()}.svg")
    return fx_module.forward

if __name__ == "__main__":

    model,config=get_model()
    model.eval()
    input_tokens=torch.randint(0,config.vocab_size,(1,128))
    
    # 一.原始模型
    # 1.onnx可视化
    torch.onnx.export(model,input_tokens,
                  "bert_base.onnx",
                  export_params=False,
                  opset_version=11,
                  do_constant_folding=True)
    onnx_infer_shape("bert_base.onnx")
    
    # 2.torchviz图
    output = model(input_tokens)
    logits = output.logits
    viz = make_dot(logits, params=dict(model.named_parameters()))
    viz.render("bert_base", view=False)
    
    # 3.torch.fx可视化
    compiled_model = torch.compile(model, backend=my_compiler)
    output = compiled_model(input_tokens)

    # 4.tensorboard可视化
    writer = SummaryWriter('./runs')
    writer.add_graph(model, input_to_model = input_tokens,use_strict_trace=False)
    writer.close()
    
    # 二.Lora模型
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=True,
        r=8,
        lora_alpha=32,
        target_modules=['intermediate.dense'],
        lora_dropout=0.1,
    )
    lora_model = get_peft_model(model, peft_config)
    lora_model.eval()
    torch.onnx.export(lora_model,input_tokens,
                      "bert_base_lora_inference_mode.onnx",
                      export_params=False,
                      opset_version=11,
                      do_constant_folding=True)
    onnx_infer_shape("bert_base_lora_inference_mode.onnx")

    compiled_model = torch.compile(lora_model, backend=my_compiler)
    output = compiled_model(input_tokens)

    writer = SummaryWriter('./runs_lora')
    writer.add_graph(lora_model, input_to_model = input_tokens,use_strict_trace=False)
    writer.close()
EOF

# 安装依赖
apt install graphviz -y
pip install torchviz
pip install pydot

# 运行测试程序
python bert_lora.py

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1819499.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Spine学习06】之IK约束绑定,制作人物待机动画,图表塞贝尔曲线优化动作

引入IK约束的概念&#xff1a; 约束目标父级 被约束骨骼子集 这样理解更好&#xff0c;约束目标可以控制被约束的两个骨骼运作 IK约束绑定过程中呢&#xff0c;如果直接绑定最下面的脚掌骨骼会发生偏移&#xff0c;所以在开始处理IK之前&#xff0c;需要先设置一个ROOT结点下的…

签到的二维码怎么制作?快速实现制作二维码签到的方法

现在很多活动会采用二维码的方式来做登记、报名、签到等&#xff0c;通过二维码可以快速获取用户信息&#xff0c;并且对于用户填写内容也提升了便利性&#xff0c;而且还能够节约成本&#xff0c;通过后台就可以查看用户登记的数据&#xff0c;方便后期的分析和信息管理。 想…

自监督分类网络:创新的端到端学习方法

现代人工智能的快速发展中&#xff0c;分类任务的高效解决方案一直备受关注。今天&#xff0c;我们向大家介绍一种名为Self-Classifier的全新自监督端到端分类学习方法。由Elad Amrani、Leonid Karlinsky和Alex Bronstein团队开发&#xff0c;Self-Classifier通过优化同一样本的…

【机器学习】QLoRA:基于PEFT亲手微调你的第一个AI大模型

目录 一、引言 二、量化与微调—原理剖析 2.1 为什么要量化微调? 2.2 量化&#xff08;Quantization&#xff09; 2.2.1 量化原理 2.2.2 量化代码 2.3 微调&#xff08;Fine-Tuning&#xff09; 2.3.1 LoRA 2.3.2 QLoRA 三、量化与微调—实战演练&#xff1a;以Qwen…

Photoshop 2024 mac/win版:探索图像处理的全新境界

Photoshop 2024是Adobe推出的最新图像处理与设计软件&#xff0c;它在继承了前作所有优秀特性的基础上&#xff0c;实现了多个方面的质的飞跃。这款软件凭借其卓越的图像处理性能、丰富的创意工具以及精确的选区编辑功能&#xff0c;成为了图像处理领域的佼佼者。 Photoshop 2…

Golang免杀-分离式加载器(传参)AES加密

目录 enc.go 生成: dec.go --执行dec.go...--上线 cs生成个c语言的shellcode. enc.go go run .\enc.go shellcode 生成: --key为公钥. --code为AES加密后的数据, ----此脚本每次运行key和code都会变化. package mainimport ("bytes""crypto/aes"&…

redis 08 慢查询日志

1.什么是慢查询日志 2.慢查询和两个参数有关 2.1 2.2 3.例子&#xff1a; 4 参数详细介绍&#xff1a;

共模信号与差模信号

差模信号又称串模信号&#xff0c;指的是两根线之间的信号差值&#xff1b;而共模信号又称对地信号&#xff0c;指的是两根线分别对地的信号。 差模信号&#xff1a;大小相等&#xff0c;方向相反的信号。共模信号&#xff1a;大小相等&#xff0c;方向相同的信号。 对于两输…

集合查询-并(UNION)集运算、交(INTERSECT)集运算、差(EXCEPT)集运算

一、概述 集合查询是对两个SELECT语句的查询结果进行再进行处理的查询 二、条件 1、两个SELECT语句的查询结果必须是属性列数目相同 2、两个SELECT语句的查询结果必须是对应位置上的属性列必须是相同的数据类型 三、并(UNION)运算 1、语法格式&#xff1a; SELECT 语句1…

4090显卡 安装cuda 11.3 版本

文章目录 cuda 安装安装过程中会要求选择安装的内容更改cuda地址到你安装的地方 cuda 安装 cuda官网寻找cuda11.3 版本 https://developer.nvidia.com/cuda-11.3.0-download-archive?target_osLinux&target_archx86_64&DistributionUbuntu&target_version20.04&…

和利时DCS数据采集对接安监平台

在工业互联网日益繁荣的今天&#xff0c;工业数据的采集、传输与利用变得至关重要。特别是在工业自动化领域&#xff0c;数据的实时性和准确性直接关系到生产效率和安全性。和利时DCS&#xff08;分布式控制系统&#xff09;以其卓越的稳定性和可靠性&#xff0c;在工业自动化领…

被封号后,我终于明白免费代理的危害

在数字时代&#xff0c;网络已经成为人们日常生活和商业活动中不可或缺的一部分。为了实现更广阔的业务拓展和更畅通的网络体验&#xff0c;许多人开始考虑使用代理服务器。然而&#xff0c;虽然免费代理可能听起来像是个经济实惠的选择&#xff0c;但事实上&#xff0c;它可能…

SSH协议

SSH协议简介 SSH&#xff08;Secure Shell&#xff09;是一种网络安全协议&#xff0c;用于在不安全的网络环境中提供加密的远程登录和其他网络服务。它通过加密和认证机制实现安全的访问和文件传输等业务&#xff0c;是Telnet、FTP等不安全远程shell协议的安全替代方案。 SSH协…

数据挖掘丨轻松应用RapidMiner机器学习内置数据分析案例模板详解(下篇)

RapidMiner 案例模板 RapidMiner 机器学习平台提供了一个可视化的操作界面&#xff0c;允许用户通过拖放的方式构建数据分析流程。RapidMiner目前内置了 13 种案例模板&#xff0c;这些模板是预定义的数据分析流程&#xff0c;可以帮助用户快速启动和执行常见的数据分析任务。 …

大模型微调出错的解决方案(持续更新)

大家好,我是herosunly。985院校硕士毕业,现担任算法研究员一职,热衷于机器学习算法研究与应用。曾获得阿里云天池比赛第一名,CCF比赛第二名,科大讯飞比赛第三名。拥有多项发明专利。对机器学习和深度学习拥有自己独到的见解。曾经辅导过若干个非计算机专业的学生进入到算法…

关于python下安装selenium以及使用

&#x1f4d1;打牌 &#xff1a; da pai ge的个人主页 &#x1f324;️个人专栏 &#xff1a; da pai ge的博客专栏 ☁️宝剑锋从磨砺出&#xff0c;梅花香自苦寒来 目录 1、win10安装python环境 2、…

【第6章】Vue生命周期

文章目录 前言一、生命周期1. 两大类2. 生命周期 二、选项式生命周期1. 代码2. 效果 三、组合式生命周期1. 代码2. 效果2.1 挂载和更新2.2 卸载和挂载 总结 前言 每个 Vue 组件实例在创建时都需要经历一系列的初始化步骤&#xff0c;比如设置好数据侦听&#xff0c;编译模板&a…

【MySQL】MySQL45讲-读书笔记

1、基础架构&#xff1a;一条SQL查询语句是如何执行的&#xff1f; 1.1 连接器 连接器负责跟客户端建立连接、获取权限、维持和管理连接。 mysql -h$ip -P$port -u$user -p输完命令之后&#xff0c;输入密码。 1.2 查询缓存 MySQL 拿到一个查询请求后&#xff0c;会先到查询缓…

AlmaLinux 8.10 x86_64 OVF (sysin) - VMware 虚拟机模板

AlmaLinux 8.10 x86_64 OVF (sysin) - VMware 虚拟机模板 AlmaLinux release 8.10 请访问原文链接&#xff1a;https://sysin.org/blog/almalinux-8-ovf/&#xff0c;查看最新版。原创作品&#xff0c;转载请保留出处。 作者主页&#xff1a;sysin.org 2023.03.08 更新&…

自动控制原理【期末复习】(二)

无人机上桨之后可以在调试架上先调试&#xff1a; 1.根轨迹的绘制 /// 前面针对的是时域分析&#xff0c;下面针对频域分析&#xff1a; 2.波特图 3.奈维斯特图绘制 1.奈氏稳定判据 2.对数稳定判据 3.相位裕度和幅值裕度