Yolov5 中添加Network Slimming剪枝--稀疏训练部分

news2024/11/24 6:41:38

前言:Network Slimming剪枝过程让如下

1. 稀疏化

2. 剪枝

3. 反复迭代这个过程

 一、稀疏化:

通过Network Slimming 的核心思想是:添加L1正则来约束BN层系数,从而剪掉那些贡献比较小的通道channel

原理如下:BN层的计算是这样的:

上边介绍了,Network Slimming的核心思想是剪掉那些贡献比较小的通道channel它的做法是从BN层下手。BN层的计算公式如下:

通过BN层的计算公式可以看出每个channe的Zout的大小和系数γ正相关,因此我们可以拿掉哪些γ-->0的channel,但是由于正则化,我们训练一个网络后,bn层的系数是正态分布的。这样的话,0附近的值则很少,那剪枝的作用就很小了。因此要先给BN层加上L1正则化进行一步稀疏训练(为什么要用L1正则化可以看该博客:l1正则与l2正则的特点是什么,各有什么优势? - 知乎)。

为BN层加入L1正则化后,损失函数公式为:

上面第一项是正常训练的loss函数,第二项是约束对于L1正则化,g(s)=|s|,λ是正则系数,引入L1正则来控制γ, 要把稀疏表达加在γ 上, 得到每个特征的重要性 λ

- 每个通道的特征对应的权重是 γ 
- 稀疏表达也是对 γ 来说的, 所以正则化系数 λ 也是针对  γ, 而不是 W
-  稀疏化后, 做γ 值的筛选

因此在进行反向传播时候:𝐿′=∑𝑙′+𝜆∑𝑔′(𝛾)=∑𝑙′+𝜆∑|𝛾|′=∑𝑙′+𝜆∑𝑠𝑖𝑔𝑛(𝛾)

那如何把程序加到yolov5呢?

在yolov5 train.py的程序中找到反向传播部分程序:

1.1 稀疏训练核心代码

将scaler.scale(loss).backward()注释,并添加下方代码:

  代码如下:

 # Backward
            # scaler.scale(loss).backward()
            loss.backward()
            # # ============================= sparsity training ========================== #
            srtmp = opt.sr*(1 - 0.9*epoch/epochs)  # opt.sr=0.0001 随着epoch增多,把srtmp减小
            if opt.st:  # '默认是true  train with L1 sparsity normalization  
                ignore_bn_list = []
                for k, m in model.named_modules():
                    # print('name: {}, module: {}'.format(k,m))
                    if isinstance(m, Bottleneck):
                        if m.add:
                            ignore_bn_list.append(k.rsplit(".", 2)[0] + ".cv1.bn")
                            ignore_bn_list.append(k + '.cv1.bn')
                            ignore_bn_list.append(k + '.cv2.bn')
                    if isinstance(m, nn.BatchNorm2d) and (k not in ignore_bn_list):
                        # L1 regulation formulate: λΣ|γ|
                        # |x|' = {-1,1}
                        # L1 grad: (λΣ|γ|)'=λ * Σsign(γ)
                        # BN(γ,β)
                        m.weight.grad.data.add_(srtmp * torch.sign(m.weight.data))  # L1
                        m.bias.grad.data.add_(opt.sr*10 * torch.sign(m.bias.data))  # L1
            # # ============================= sparsity training ========================== 
            # if ni - last_opt_step >= accumulate:
            #     scaler.step(optimizer)  # optimizer.step
            #     scaler.update()
            #     optimizer.zero_grad()
            #     if ema:
            #         ema.update(model)
            #     last_opt_step = ni
            optimizer.step()
                # scaler.step(optimizer)  # optimizer.step
                # scaler.update()
            optimizer.zero_grad()
            if ema:
                ema.update(model)

其中sr,st 需要添加参数

parser.add_argument('--st', action='store_true',default=True, help='train with L1 sparsity normalization')
parser.add_argument('--sr', type=float, default=0.0002, help='L1 normal sparse rate')

其中需要注意的点1:

红框处程序是因为这里并没有选择所有的bn层进行裁剪,这里选择去除那些有shortcut的Bottleneck层(对应代码中m.add = True),主要是为了保证shortcut和残差层channel一样可以add。

--------------------在这里我曾经有这样的疑问:(该部分可以不看) -----------------------------------------

这两个if 我能理解最终目的是:去除那些有shortcut的Bottleneck层,但是为什么要有 +cv1.bn等等那三步呢?不能直接把k添加到 ignore_bn_list吗?

再说了添加了之后,加入ignore_bn_list的名字就变了呀,此时再运行下一个if的时候k是不在ignore_bn_list

为什么不能改成指令:

if isinstance(m, Bottleneck):

                        if m.add:

                            ignore_bn_list.append(k)

if isinstance(m, nn.BatchNorm2d) and (k in ignore_bn_list):

 后来我明白了,这里是为了不对Bottlenack中的BatchNorm2d加正则化,因上述改名字的那个步骤其实是找的该Bottleneck下面BatchNorm2d 的名字。比如我断点调试了一下:

其中名为Module.model.2.m.0的模型

 其下的BatchNorm2d的名字分别如下:

那再回头看看那部分程序,就理解了。
-------------------------------------------------------结束--------------------------------------------------------------------

注意的点2:

yolov5会采用自动混合精度训练,因此需要把改成fp32,方法如下:
修改train.py
1. 注释掉    # scaler = amp.GradScaler(enabled=cuda)
2. 把train.py中的.half 都去掉
具体为:
  # Anchors
            if not opt.noautoanchor:
                check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
            # model.half().float()
            model.float()  # pre-reduce anchor precision
# Save model
            if (not nosave) or (final_epoch and not evolve):  # if save
                ckpt = {'epoch': epoch,
                        'best_fitness': best_fitness,
                        # 'model': deepcopy(de_parallel(model)).half(),
                        'model': deepcopy(de_parallel(model)),
                        # 'ema': deepcopy(ema.ema).half(),
                        'ema': deepcopy(ema.ema),
                        'updates': ema.updates,
                        'optimizer': optimizer.state_dict(),
                        'wandb_id': loggers.wandb.wandb_run.id if loggers.wandb else None,
                        'date': datetime.now().isoformat()}
 if RANK in [-1, 0]:
        LOGGER.info(f'\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.')
        for f in last, best:
            if f.exists():
                strip_optimizer(f)  # strip optimizers
                if f is best:
                    LOGGER.info(f'\nValidating {f}...')
                    results, _, _ = val.run(data_dict,
                                            batch_size=batch_size // WORLD_SIZE * 2,
                                            imgsz=imgsz,
                                            # model=attempt_load(f, device).half(),
                                            model=attempt_load(f, device),
                                            iou_thres=0.65 if is_coco else 0.60,  # best pycocotools results at 0.65
                                            single_cls=single_cls,
                                            dataloader=val_loader,
                                            save_dir=save_dir,
                                            save_json=is_coco,
                                            verbose=True,
                                            plots=True,
                                            callbacks=callbacks,
                                            compute_loss=compute_loss)  # val best model with plots

以上就将对BN层添加L1正则化的程序加好了,核心思想就是修改反向传播的梯度。

1.2 查看稀疏训练效果

如果想查看系数训练的效果,可加入下方程序:

  # =============== show bn weights ===================== #
        module_list = []
        module_bias_list = []
        for i, layer in model.named_modules():
            if isinstance(layer, nn.BatchNorm2d) and i not in ignore_bn_list:
                bnw = layer.state_dict()['weight']
                bnb = layer.state_dict()['bias']
                module_list.append(bnw)
                module_bias_list.append(bnb)
                # bnw = bnw.sort()
                # print(f"{i} : {bnw} : ")
        size_list = [idx.data.shape[0] for idx in module_list]

        bn_weights = torch.zeros(sum(size_list))
        bnb_weights = torch.zeros(sum(size_list))
        index = 0
        for idx, size in enumerate(size_list):
            bn_weights[index:(index + size)] = module_list[idx].data.abs().clone()
            bnb_weights[index:(index + size)] = module_bias_list[idx].data.abs().clone()
            index += size

        # print("bn_weights:", torch.sort(bn_weights))
        # print("bn_bias:", torch.sort(bnb_weights))
        # tb_writer.add_histogram('bn_weights/hist', bn_weights.numpy(), epoch, bins='doane')
        # tb_writer.add_histogram('bn_bias/hist', bnb_weights.numpy(), epoch, bins='doane')
        loggers.tb.add_histogram('bn_weights/hist', bn_weights.numpy(), epoch, bins='doane')
        loggers.tb.add_histogram('bn_bias/hist', bnb_weights.numpy(), epoch, bins='doane')

将其加到:一个batch训练结束之后的程序后边就好

我在tensorboard中的图:

 纵轴是epoch,横轴是权重,可以看到我一共进行了100轮稀疏训练,我的项目中非bottoltneck中的bn层加起来参数大概有8000多个,那可以看到在49个epoch的时候,0附近的权重已经有5000多个了。那接下来我可以设置60%的剪枝率,把它们都剪掉

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/66523.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何理解UML2.5.1(02篇)

为了避免使大家产生畏难情绪,本节先讲一个相对简单又相对普遍的问题。先看UML2.5.1中第13.2.3.5的如下内容: A Behavior shall be the method for no more than one BehavioralFeature, called its specification. 翻译过来就是: 一个行为应该…

[附源码]Python计算机毕业设计SSM家政服务系统(程序+LW)

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

2093197-94-3,DBCO-BODIPY FL,二苯并环辛炔-BODIPY FL点击化学染料环辛炔

【中文名称】氟化硼二吡咯二苯并环辛炔,二苯并环辛炔-BODIPY FL 【英文名称】 DBCO-BODIPY FL,BDP FL DBCO 【结 构 式】 【CAS号】2093197-94-3 【分子式】C32H29BF2N4O2 【分子量】550.42 【基团部分】DBCO部分 【纯度标准】95% 【包装规格】5g&#x…

【校招VIP】线上实习 推推 书籍详情模块 产品脑图周最佳

【推推】主要是为校招设计的小说一更新就通知的项目,每个模块都具有亮点和难点,项目表现为手机网站应用,可嵌入小程序或APP中。 恭喜来自 太原理工大学 的 星晚🌟 同学获得本周线上实习【推推】第一期 书籍详情模块 产品脑图设计…

kubernetes编排文件示例

kubernetes编排文件示例 编排文件生成网址:https://www.kubebiz.com/ mysql单机 需要一个配置文件,内容不会就用默认的即可 my.cnf [mysqld]pid-file /var/run/mysqld/mysqld.pid socket /var/run/mysqld/mysqld.sock datadir /var…

Python制作GUI学生管理系统,不会的看这里

前言 嗨喽~大家好呀,这里是魔王呐 ❤ ~! 欢迎观看本篇文章呀~不管你是学生还是工作人 我相信你进来了你就是想实现这个案例的 学会以后,还可以去接一些小小的外包,又是挣钱的一天~ 那么就开始实现吧!python制作GUI 学生管理系…

Curve 块存储应用实践 -- iSCSI

Curve 是云原生计算基金会 (CNCF) Sandbox 项目,是网易数帆发起开源的高性能、易运维、云原生的分布式存储系统。 为了让大家更容易使用以及了解 Curve,我们期望接下来通过系列应用实践文章,以专题的形式向大家展示 Curve。 本篇文章是Curv…

activiti框架搭建及问题记录

activiti应用什么是activitiactiviti配置首先创建项目配置pom依赖配置文件那么审批(流程)怎么创建呢?流程启动任务处理activiti问题分享数据库创建问题activiti事件监听器没有对象的问题什么是activiti activiti是一个业务流程管理的框架&am…

LeetCode中等题之使括号有效的最少添加

题目 只有满足下面几点之一,括号字符串才是有效的: 它是一个空字符串,或者 它可以被写成 AB (A 与 B 连接), 其中 A 和 B 都是有效字符串,或者 它可以被写作 (A),其中 A 是有效字符串。 给定一…

开发工具系列IDEA:配置注释自动生成

一、类、接口、枚举配置&#xff0c;进入idea后&#xff0c;依次打开 File -> Settings -> Editor -> File and Code Templates -> Files /*** FileName: ${NAME}* Author: ${USER}* Date: ${DATE} ${TIME}* Description: ${DESCRIPTION}* History:* <aut…

中国电信移动物联网发展成果与创新实践 ,干货满满

近日&#xff0c;首届移动物联网大会&#xff08;2022&#xff09;&#xff08;以下简称“大会”&#xff09;在江苏省无锡市举办。本次大会由工信部指导&#xff0c;中国信息通信研究院&#xff08;以下简称“中国信通院”&#xff09;、中国通信学会、无锡市人民政府、人民邮…

产品工作流| 项目评估

一、什么是项目评估 根据已有的公开招标书要求&#xff0c;销售侧拿到招标要求&#xff0c;让研发评估项目。 1、需求满足度评估。 2、需求开发项&#xff0c;以及成本评估。 3、总结项目评估。 二、项目评估流程 材料依据&#xff1a; 1、投标材料。 2、项目评估表&#x…

路由器,集线器,交换机,网桥,光猫有啥区别?

网络分层 网线替代了上面的灰色部分&#xff0c;实现物理层互联。 如果想要两台电脑互联成功&#xff0c;还需要确保每一层所需要的步骤都要做到位&#xff0c;这样数据才能确保正确投送并返回。 从数据链路层到物理层&#xff0c;数据会被转为01比特流。 此时需要把比特流传…

【软件测试】小陈她的测试追梦之路,实习开端到测试第一人......

目录&#xff1a;导读前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09;前言 小陈&#xff1a;我…

中华黄金·金生态合伙人颁奖典礼在珠海站开幕完美收官!!

11月18-19日在广东珠海举行&#xff0c;近百位合伙人亲临现场&#xff0c;强者能人共聚天下&#xff0c;中华黄金合伙人&#xff0c;强强联手引爆市场&#xff0c;汇聚一堂。 本次活动以“金生态”为主题。CNG金生态是中华黄金集团旗下平台&#xff0c;运用WEB3.0核心技术聚合了…

Flink CDC入门实践--基于 Flink CDC 构建 MySQL 和 Postgres 的 Streaming ETL

文章目录前言1.环境准备2.准备数据2.1 MySQL2.2 postgres3.启动flink和flink sql client3.1启动flink3.2启动flink SQL client4.在flink SQL CLI中使用flink DDL创建表4.1开启checkpoint4.2对于数据库中的表 products, orders, shipments&#xff0c; 使用 Flink SQL CLI 创建对…

iOS开发之iOS15.6之后拉流LFLiveKit,画面模糊及16.1马赛克问题

更新了iOS15.6系统后&#xff0c;发现拉取LFLiveKit进行直播的流&#xff0c;竟然是这样的&#xff1a; 模糊不清&#xff0c;于是思考是什么原因导致的。 1、是不是拉流端出现的问题&#xff1f; 使用安卓拉取iOS的直播流&#xff0c;是同样的效果&#xff0c;又考虑到两端使…

【DL】Windows 10系统下安装TensorRT教程

Windows 10系统下安装TensorRT教程(手把手教程): Windows 10系统下安装TensorRT教程: 1.下载 https://developer.nvidia.com/nvidia-tensorrt-download EA 版本代表抢先体验(在正式发布之前)。 GA 代表通用性。表示稳定版,经过全面测试。 TensorRT、cuda、cudnn各版本…

与目前主流的消费返利平台对比,共享购模式有什么优势呢?

大家好&#xff0c;我是林工&#xff0c;之前几期内容都有介绍过共享购的商业模式&#xff0c;同时大家应该都对消费返利这方面有所了解。今天给大家分享一下整套模式的优劣势。 什么是消费返利&#xff1f;消费返利是互联网常见的一个商业模式&#xff0c;是指互联网平台将自…

毕业设计 基于STM32与wifi的天气预报网时钟系统 - 物联网 单片机

文章目录0 前言1 设计内容2 软件设计3 关键代码4 最后0 前言 &#x1f525; 这两年开始毕业设计和毕业答辩的要求和难度不断提升&#xff0c;传统的毕设题目缺少创新和亮点&#xff0c;往往达不到毕业答辩的要求&#xff0c;这两年不断有学弟学妹告诉学长自己做的项目系统达不…