LLM-pruner源码解析

news2024/11/28 12:07:24

1.超参数

模型剪枝的超参数

模型
模型检查点和日志的保存地址
剪枝比例,这里默认0.5
剪枝类型,这里模型L2

 模型生成时的超参数

温度
top_p
最大序列长度

逐通道,逐块,逐层,这个逐层我不记得在论文里面提过啊
layer:保留前n层

注意力模块和线性层的开始层和结束层

剪枝的迭代次数
组的计算策略:这里采用的求和

是否使用全局剪枝

泰勒:包括在向量维度上的,在元素维度上的一阶、二阶(混合参数不知道指的是啥)
向量维度:
更细的元素维度:

这几个生成参数也没啥好说的,加载设备

确定torch的版本

官方给的bloom的配置是7B1的模型

我这里用的3B后面还要根据分析的结果改一下

3B有30层,

bloom7b有24层 这个4,20代表从0开始的排序是20还是从1开始的排序是20呀

按照配置获取的超参数

2.程序运行逻辑

第一步先固定下随机种子

设置日志

获得tokenizer

获得模型,这个类是llm-pruner自己写的,

有个问题:为啥要自己重新写一个加载类呢

自定义加载类

这段代码可以不用看直接看对q,k,v重新排序,这里自己写的加载类和transformers自带的没有区别,猜测应该是大佬防止模块名不一致,自己又重新写了一遍

下面套的类比较多这里为了区别,运行到哪一类提前说一下是继承顺序 

BloomForCausallm:BloomForCausallm类先继承BloomPreTrainedModel类

from transformers.models.bloom.configuration_bloom import BloomConfig

BloomForCausallm继承BloomPreTrainedModel:BloomPreTrainedModel类继承PreTrainedModel类

from transformers.modeling_utils import PreTrainedModel
from transformers.models.bloom.configuration_bloom import BloomConfig

BloomForCausallm调用BloomModel:实例化BloomModel类,获取模型结构并初始化权重

BloomForCausallm调用BloomModel,BloomModel继承BloomPreTrainedModel:这个BloomModel类也是继承BloomPreTrainedModel类

BloomForCausallm调用BloomModel,BloomModel调用BloomBlockBloomModel类调用BloomBlock类

BloomForCausallm调用BloomModel,BloomModel调用BloomAttentionBloomAttention

BloomForCausallm调用BloomModel,BloomModel调用BloomMLP:BloomMLP

BloomForCausallm调用BloomModel,BloomModel调用BloomGelu:BloomGelu

BloomForCausallm调用BloomModel,BloomModel继承PreTrainedModel中的self.post_init()实例方法:BloomModel中的post_init,这个方法是PreTrainedModel中的方法

BloomForCausallm调用BloomModel,BloomModel继承PreTrainedModel中的self.post_init()实例方法,self.post_init()调用PreTrainedModel中的self.init_weight()实例方法和self._back_compatibity_gradient_ckeckpointing()实例方法
init_weights:如果需要,修剪并可能初始化权重。如果使用自定义的 `PreTrainedModel`,你需要在 `_init_weights` 中实现任何初始化逻辑。

BloomForCausallm:对于 self.transformer,post_init方法中的判定都是否,所以还是保持原来的参数不变,并没有对self.transformer包含的模块进行重新初始化

BloomForCausallm:self.lm_head 

BloomForCausallm:self.post_init()判断都是否,没有对网络进行重新初始化

BloomForCausallm继承BloomPreTrainedModel: 我还是没明白什么时候调用的_init_weight方法,只对self.lm_head进行标准化

q,k,v重排

从这里开始需要看,将q,k,v的顺序进行重排

torch.view 函数的追踪在某些情况下比较复杂,因此,查询、键、值的索引映射有时会遇到问题。为了避免这些问题,函数通过分离查询、键、值的方式来重新组织权重和偏置。


将模型转化为fp16 

from LLMPruner.templates.prompts import prompts

model.generate也是自己写的,先不看这里直接看pruner

pruner

选择pruner类型,将模块参数全都转化为要求梯度,组依赖关系的计算方式选择求和,给定提示prompt

选择taylor方式,一阶,求和,实例化TaylorImportance类

from LLMPruner.pruner import hf_llama_pruner as Pruner

Pruner这个类是自己定义的,如果模型不一样对应的类也不一样,具体怎么根据自己的模型改还得继续向下看

在baichuan中

组依赖关系求和,不进行归一化,一阶泰勒
import LLMPruner.torch_pruning as tp

逐模块计算

获取参数

从开始层到结束层 将q,k,v的pruner比例平分
"ch_sparsity_dict": { model.transformer.h[i].self_attention.query_key_value: args.pruning_ratio / 3 for i in range(args.block_attention_layer_start, args.block_attention_layer_end) },

"root_instances": [model.transformer.h[i].mlp.dense_h_to_4h for i in range(args.block_mlp_layer_start, args.block_mlp_layer_end)] +
                  [model.transformer.h[i].self_attention.query_key_value for i in range(args.block_attention_layer_start, args.block_attention_layer_end)],

开始pruner

import LLMPruner.torch_pruning as tp 

获取MetaPruner类的实例属性

对DependencyGraph类实例化

from ... import ops, dependency

from . import _helpers, utils, ops

这几个函数都是调用的ops中的类,CUSTOMIZED定制为None

已经注册的编辑器,更新编辑器。定制pruner,忽略的层
from .pruner import function

 

刚开始的编辑器

更新之后的编辑器

 按照编辑器中的值获取pruner输入通道的函数,定制里面是空值

 按照编辑器中的值获取pruner输出通道的函数,定制里面是空值

每个pruner类型获取pruner输入/出通道的函数,这里只是举例子

_op_id为0,记录pruner的历史,现在为空 

调用build_dependency方法,输入参数包括模型,前面设置的提示词,刚才参数中的前向函数model(example_inputs),剩下的三个output_transform,unwrapped_parameters,customized_pruners都是None或空值

build_dependency 的实例属性包括:获取模型,命名的模块
customized_pruners,CUSTOMIZED_PRUNERS之前已经知道了两个字典都是空

self._module2name有命名的模块有337个 

检测没有包装的参数

新建已包装模块的列表,获取注册表中的pruner

获得每个模块的类型,如果操作类型在pruner模块类型中且不是逐元素(后面不用看后面是空),把参数加入到已包装参数的类中

已包装的参数有336个

新建unwrapped_detected和_param_to_name

遍历看看哪个是未包装的,加入unwrapped_detected

 

最后的运行结果,没有不被包装的,同样_param_to_name里面也是空的

从 unwrapped_detected 列表中移除所有出现在 unwrapped_parameters 列表中的元素,最终的结果存回 unwrapped_detected 变量,这个unwrapped_parameters是类属性,为空列表

如果 unwrapped_detected不为空的话相关的处理手段,对 unwrapped_detected 列表中的每个元素进行处理,找出其 最后一个大于1的维度,并将该元素与对应的维度信息存入 unwrapped_parameters 列表。最后,将 unwrapped_parameters 保存到 self.unwrapped_parameters

开始追踪计算图了,输入有模型,提示输入,前向,output_transform为None

model.eval()
初始化gradfn2module和visited

获得之前pruner编辑器中的内容 

还是之前的那18个 

  如果模型模块不在忽略的层中且在注册表内,在模型的每一层上注册一个前向钩子(forward_hook),这些钩子将在模型进行前向传播时触发并调用 _record_grad_fn 函数

visited 是一个字典,键是模块对象,值是该模块被调用的次数。每当某个模块被前向传播执行时,visited[module] 增加 1
当前模块是否是 nn.Linear 层,并且该层的输出张量维度是否为 3(例如,(batch_size, seq_len, hidden_dim))。如果条件成立,设置 self._2d_4d = False,表示模型输出的维度不再是 2D 或 4D,而是 3D。
有些层(如 LSTM, GRU)的输出是一个元组,通常包括输出张量和一些附加信息(如隐藏状态)。这个条件会检查输出是否是一个元组,如果是,则提取元组的第一个元素作为最终的输出
PackedSequence 是 PyTorch 中用于表示 RNN 变长序列的输出格式。此条件用于检查 outputs 是否是一个 PackedSequence 对象,如果是,则将其 .data 提取出来。.data 是实际的张量数据
outputs.grad_fn 是 outputs 张量的梯度计算函数(grad_fn),它记录了张量如何计算出来。这里的 gradfn2module 是一个字典,将梯度计算函数 (grad_fn) 映射到对应的模块 module

之前自定义了前向函数 forward_fn,同时会调用之前注册的hook函数 
前向完成后移除掉hook函数

在前向的过程中填充记录模块调用次数的visited 字典和记录计算梯度函数的gradfn2module列表

针对递归模型或层,找到被调用多次的模块记录到reused列表中

这里没有被条用多次的,是空的列表 

这里output_transform是None,如果有的话对模型的输出结果进行转换

from . import _helpers, utils, ops

对于utils.flatten_as_list()

如果 obj 是一个张量,则将其包装成一个列表并返回

检查 obj 是否是一个列表(list)或元组(tuple)。如果是,创建一个空的 flattened_list 用于存储展开后的元素。然后,递归地调用 flatten_as_list 来展开列表或元组中的每个元素(sub_obj)。使用 extend 方法将每次递归得到的展开结果合并到 flattened_list 中。最终返回这个完全展开的列表

检查 obj 是否是一个字典(dict)。如果是,创建一个空的 flattened_list 用于存储展开后的元素。
然后,递归地调用 flatten_as_list 来展开字典中每个键对应的值(sub_obj)。使用 extend 将每次递归得到的展开结果合并到 flattened_list 中。最终返回这个完全展开的列表

如果 obj 既不是 torch.Tensor、列表、元组或字典,那么直接返回 obj 本身。这部分适用于基本类型(如整数、浮点数、字符串等)。

调用是实力属性_trace_computational_graph追踪计算图
输入包括module2node字典,模块梯度计算函数,记录模块计算函数的gradfn2module字典,和纪律被多次调用的模块的字典

 非递归计算图构建

processing_stack: 用于存储待处理的梯度函数(grad_fn)节点,类似栈(stack)的数据结构。
visited: 用于跟踪已经处理过的 grad_fn,避免重复处理。
visited_as_output_node: 用于追踪作为输出节点的计算图节点。

在每次循环中,弹出栈顶的梯度函数(grad_fn)并开始处理。
如果当前 grad_fn 已经处理过,则跳过(防止重复计算)

 调用create_node_if_not_exists

如果 module 已经存在,并且该 module 已经在 module2node 字典中关联了一个节点(即已存在对应的计算图节点),并且该 module 不在 reused 中(表示该节点没有被标记为“已重用”),那么直接返回现有的节点 module2node[module

如果 module 为空(表示这是一个新模块,之前没有创建过),则会根据 grad_fn 创建一个新的模块并与其关联
如果 grad_fn 没有 name 属性(说明它是一个不常见的或自定义的操作),则将其视为一个 逐元素操作(如加法、减法等),并使用 ops._ElementWiseOp 创建一个新的操作模块 module,并给这个模块分配一个唯一的 op_id。self._op_id 会在每次创建模块后自增。
如果 verbose 为 True,则发出警告,提示遇到了一个未知操作,默认将其视为逐元素操作。
如果 grad_fn.name() 包含特定的字符串(如 "catbackward"、"split"、"view" 等),则根据操作类型创建对应的模块(例如 ops._ConcatOp 表示拼接操作,ops._SplitOp 表示拆分操作,ops._ReshapeOp 表示形状变化操作等)。
如果没有匹配到特定类型的操作,则默认将其视为 逐元素操作。
创建好模块后,会将 grad_fn 与新创建的模块存储到 gradfn2module 字典中,以便以后查找。

 如果 module 还没有在 module2node 字典中找到对应的节点,则创建一个新的节点 Node 对象。
该节点包含以下信息:
module: 关联的操作模块
grad_fn: 关联的梯度计算函数
name: 从 _module2name 字典中获取模块的名称,如果没有,则为 None。
如果该模块是自定义的修剪器(CUSTOMIZED_PRUNERS),则将节点类型设置为 CUSTOMIZED。
将新节点添加到 module2node 字典中,以便后续访问。
如果 module 已经有对应的节点,则直接使用已存在的节点

hasattr() 是 Python 内置的一个函数,用来检查一个对象是否具有指定的属性。
检查当前的 grad_fn(计算图中的节点)是否有 next_functions 属性。
grad_fn.next_functions 是一个可迭代对象,每个元素表示当前梯度函数(操作)依赖的输入(上游节点)。遍历 next_functions 列表中的每个元素,来处理每个输入。

如果 f[0] 为 None,表示该输入没有有效的梯度函数,因此跳过这个输入

这行代码检查 f[0](即当前输入的 grad_fn)是否有 name 属性,并且其名称是否包含 "accumulategrad"(表示该输入是一个叶子变量)。这种叶子变量通常对应于模型参数(如权重或偏置),它们不是由其他操作计算得到的,而是计算图中的输入

如果 f[0] 是叶子变量,进一步检查它是否属于未包装的参数(即 unwrapped_parameters)。
如果找到了匹配的参数,gradfn2module[f[0]] = p 将 grad_fn 映射到该参数(p)。同时,使用 self._module2name 为该参数生成一个名称 "UnwrappedParameter_j (shape)",并将其赋值为 grad_fn 的名称。
如果没有找到匹配的参数,跳过当前输入 
调用 create_node_if_not_exists(f[0]) 为输入 f[0] 创建一个节点

node.add_input(input_node, allow_dumplicated=False) 将当前的 input_node作为输入添加到 node中。allow_dumplicated=False 表示不允许重复连接相同的输入。
input_node.add_output(node, allow_dumplicated=False) 将 ndoe\作为输出添加到 input_node中。

f[0] 被添加到 processing_stack 中,表示该输入已经被处理

visited.add(grad_fn) 将当前的 grad_fn 标记为已访问,表示该节点已经被处理过。
visited_as_output_node.add(node) 将当前的 node 标记为已访问的输出节点,防止后续重复处理

对于没有包装的节点

最后返回模块和节点之间的关系

打个节点,下次再看

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2249042.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Stable Diffusion 3详解

🌺系列文章推荐🌺 扩散模型系列文章正在持续的更新,更新节奏如下,先更新SD模型讲解,再更新相关的微调方法文章,敬请期待!!!(本文及其之前的文章均已更新&…

零基础学安全--shell(8)脚本相互利用

目录 学习连接 脚本相互利用 脚本利用 利用脚本中的变量 重定向 输出重定向 错误输出 输入重定向 学习连接 声明! 学习视频来自B站up主 **泷羽sec** 有兴趣的师傅可以关注一下,如涉及侵权马上删除文章,笔记只是方便各位师傅的学习和探…

图书系统小案例

目前就实现了分页查询,修改,删除功能 这个小案例练习到了很多技能,比如前后端交互、异步请求、三层架构思想、后端连接数据库、配置文件、基础业务crud等等 感兴趣的小伙伴可以去做一个试试 准备工作 1、使用maven构建一个web工程 打开i…

深度理解进程的概念(Linux)

目录 一、冯诺依曼体系 二、操作系统(OS) 设计操作系统的目的 核心功能 系统调用 三、进程的概念与基本操作 简介 查看进程 通过系统调用获取进程标识符 通过系统调用创建进程——fork() 四、进程的状态 操作系统中的运行、阻塞和挂起 理解linux内核链表 Linux的进…

自媒体图文视频自动生成软件|03| 页面和结构介绍

代码获取方式在文本末尾🔚 *代码获取方式在文本末尾🔚 *代码获取方式在文本末尾🔚 *代码获取方式在文本末尾🔚 视频图片生成器 一个基于 Python 和 Web 的工具,用于生成带有文字和语音的视频以及图片。支持多种尺寸、…

STM32的一些知识技巧

STM32的一些知识技巧 STM32命名规则 查看代码编译所占用的flash和SRAM的大小 单位为字节(Byte) 1、使用keil编译结果进行计算 2、查看.map文件 STM32启动模式 主闪存启动地址为0x08000000 查看程序段/函数执行时间 global.prop文件 保存字体配置&…

我们来学mysql -- EXPLAIN之rows(原理篇)

EXPLAIN之rows 题记rows 题记 书接上文《 EXPLAIN之ref》2024美国大选已定,川普剑登上铁王座,在此过程中出谋划策的幕僚很重要,是他们决定了最终的执行计划在《查询成本之索引选择》中提到,explain的输出,就是优化器&…

【AI系统】昇腾 AI 核心单元

昇腾 AI 核心单元 本文将深入介绍昇腾 AI 处理器的核心单元——AI Core,以及其背后的达芬奇架构。昇腾 AI 处理器是华为针对 AI 领域设计的专用处理器,其核心 AI Core 采用了特定域架构(Domain Specific Architecture,DSA&#x…

Hadoop生态圈框架部署(九)- Hive部署

文章目录 前言一、Hive部署(手动部署)下载Hive1. 上传安装包2. 解压Hive安装包2.1 解压2.2 重命名2.3 解决guava冲突 3. 配置Hive3.1 配置Hive环境变量3.2 修改 hive-site.xml 配置文件3.3 配置MySQL驱动包3.3.1 下在MySQL驱动包3.3.2 上传MySQL驱动包3.…

RHCE——SELinux

SELinux 什么是SELinux呢?其实它是【Security-Enhanced Linux】的英文缩写,字母上的意思就是安全强化Linux的意思。 SELinux是由美国国家安全局(NSA)开发的,当初开发的原因是很多企业发现,系统出现问题的原因大部分都在于【内部…

Vue3的通灵之术Teleport

前言 近期Vue3更新了一些新的内容&#xff0c;我都还没有一个一个仔细去看&#xff0c;但是还是有必要去解读一下新内容的。就先从Teleport 开始吧。 官方对 Teleport 的解释是&#xff1a;<Teleport> 是一个内置组件&#xff0c;它可以将一个组件内部的一部分模板“传…

介绍一下atof(arr);(c基础)

hi , I am 36 适合对象c语言初学者 atof(arr)&#xff1b;是返回浮点数(double型)&#xff0c;浮点数数是arr数组中字符中数字 格式 #include<stdio.h> atof(arr); 返回值arr数组中的数 未改变arr数组 #include<stdio.h> //atof(arr) 返 <stdlib> int…

STM32 USART配置库函数

单片机学习&#xff01; 目录 一、USART配置函数 1.1 USART_DeInit函数 1.2 USART_Init函数 1.3 USART_StructInit函数 二、配置同步时钟输出函数 2.1 USART_ClockInit函数 2.2 USART_ClockStructInit函数 三、USART的外设与中断函数 3.1 USART_Cmd函数 3.2 USART_IT…

通俗理解人工智能、机器学习和深度学习的关系

最近几年人工智能成为极其热门的概念和话题&#xff0c;可以说彻底出圈了。但人工智能的概念在1955年就提出来了&#xff0c;可以说非常古老。我在上小学的时候《科学》课本上就有人工智能的概念介绍&#xff0c;至今还有印象&#xff0c;但那些年AI正处于“寒冬”&#xff0c;…

2024数学建模亚太赛【C题】赛题详细解析

目录 &#x1f4d1;一、竞赛时间 &#x1f5dd;️二、奖项设置 ✏️三、选题思路 &#x1f50d;阶段一&#xff1a;【数据预处理与探索性分析】 1.【数据清洗与预处理】 2.【探索性数据分析&#xff08;EDA&#xff09;】 &#x1f50d;阶段二&#xff1a;【时间序列建模…

数据结构 【堆实现】

上文提到堆是一种特殊的二叉树&#xff0c;其中它的父结点均不大于或者不小于其子结点的值。堆总是一棵完全二叉树。其中&#xff0c;堆的父节点全部不小于它的子结点时称为大堆&#xff0c;堆的父结点全部不大于其子结点的堆称为小堆。 堆可以由两种结构来实现&#xff0c;分别…

【AI绘画】Midjourney进阶:色调详解(下)

博客主页&#xff1a; [小ᶻ☡꙳ᵃⁱᵍᶜ꙳] 本文专栏: AI绘画 | Midjourney 文章目录 &#x1f4af;前言&#x1f4af;Midjourney中的色彩控制为什么要控制色彩&#xff1f;为什么要在Midjourney中控制色彩&#xff1f; &#x1f4af;色调纯色调灰色调暗色调 &#x1f4af…

[代码随想录Day24打卡] 93.复原IP地址 78.子集 90.子集II

93.复原IP地址 一个合法的IP地址是什么样的&#xff1a; 有3个’.分割得到4个数&#xff0c;每个数第一个数不能是0&#xff0c;不能含有非法字符&#xff0c;不能大于255。 这个是否属于合法IP相当于一个分割问题&#xff0c;把一串字符串分割成4部分&#xff0c;分别判断每…

“harmony”整合不同平台的单细胞数据之旅

其实在Seurat v3官方网站的Vignettes中就曾见过该算法&#xff0c;但并没有太多关注&#xff0c;直到看了北大张泽民团队在2019年10月31日发表于Cell的《Landscap and Dynamics of Single Immune Cells in Hepatocellular Carcinoma》&#xff0c;为了同时整合两类数据&#xf…

贴代码PasteForm框架之框架核心帮助类PasteFormHelper说明

简介 PasteForm是贴代码推出的 “新一代CRUD” &#xff0c;基于ABPvNext&#xff0c;目的是通过对Dto的特性的标注&#xff0c;从而实现管理端的统一UI&#xff0c;借助于配套的PasteBuilder代码生成器&#xff0c;你可以快速的为自己的项目构建后台管理端&#xff01;目前管…