YOLOv8-对注意力机制模型进行通道剪枝-同时实现涨点和轻量化【附代码】

news2024/9/20 19:58:51

文章目录

  • 前言
  • 视频效果
  • 文章概述
  • 必要环境
  • 一、训练自己的模型
    • 1、 训练命令
    • 2、 训练参数解析
  • 二、模型剪枝
    • 1、 对训练好的模型将进行剪枝
    • 2、 剪枝代码详解
      • 1.解析命令行参数
      • 2. 定义剪枝函数
      • 3. 定义剪枝结构
      • 4. 更新注意力机制
      • 5. 保存更新后的模型
      • 6. 主函数
  • 三、剪枝后的训练
    • 运行命令如下
  • 四、实验指标对比
  • 五、剪枝前后效果对比
  • 六、完整代码获取
  • 总结


前言

在上期博客中,我们实现了对YOLOv10模型的结构化通道剪枝,本篇文章将介绍如何对增加了MCA注意力机制的YOLOv8模型进行通道剪枝,并详细解读每个参数和模块的作用。
上期博客地址:YOLOv10结构化通道剪枝【附代码】


视频效果

b站链接:魔改YOLOv8 在参数量下降51.3%的情况下涨点1% (KITTI验证集)


文章概述

本篇博客将详细介绍如何对yolov8注意力机制模型进行通道剪枝,具体步骤包括参数解析、剪枝代码讲解、fine-tune训练,最后将对比剪枝前后模型在KITTI数据集上的表现,包括MAP、参数量和FPS等指标,以验证剪枝效果


必要环境

  1. 配置yolov10环境 可参考往期博客(v8和v10环境配置方法可通用)
    地址:搭建YOLOv10环境 训练+推理+模型评估

  2. 安装torch-pruning 0.2.7版本,安装命令如下

    pip install torch-pruning==0.2.7
    
  3. 结构化剪枝论文地址
    地址:Pruning Filters for Efficient ConvNets

  4. MCA注意力机制论文地址
    地址:Multidimensional collaborative attention in deep convolutional neural networks for image recognition


一、训练自己的模型

1、 训练命令

python 1_yolov8_train.py --mode train_ch --yaml_path yolov8n.yaml --epoch 200 --batch 32 --model_path ''

运行效果
在这里插入图片描述
可以看到正常训练时会打印模型在yaml文件中定义的网络结构

2、 训练参数解析

# 解析命令行参数
parser = argparse.ArgumentParser(description='Train or validate YOLO model.')
parser.add_argument('--mode', type=str, default='val', choices=['train_ori', 'train_ch', 'val'],
                    help='Mode of operation.')
parser.add_argument('--yaml_path', type=str, default='yolov8n.yaml', help='Path to YAML file.')
parser.add_argument('--model_path', type=str, default=r'runs/kitti_ori/weights/best.pt', help='Path to model file.')
parser.add_argument('--data_path', type=str, default='./data.yaml', help='Path to data file.')
parser.add_argument('--epoch', type=int, default=200, help='Number of epochs.')
parser.add_argument('--batch', type=int, default=16, help='Batch size.')
parser.add_argument('--workers', type=int, default=8, help='Number of workers.')
parser.add_argument('--device', type=str, default='0', help='Device to use.')
parser.add_argument('--name', type=str, default='', help='Name data file.')
args = parser.parse_args()

参数详解:

  1. –mode: 用于指定操作模式
    可选值为train_ori、train_ch和val。train_ori用于训练原始模型,train_ch用于训练改进后的模型(如增加注意力机制或增加检测头),val用于验证模型并计算精度指标

  2. –yaml_path: 指定改进网络结构的YAML文件路径
    当选择训练改进后的模型时,需要提供相应的网络结构文件路径,如训练带有MCA注意力机制的模型时,此处填写相应的YAML文件路径

  3. –model_path: 指定模型文件路径
    当mode不等于val时,该参数为预训练模型的路径(如训练8n模型时,此处填写yolov8n.pt路径)
    当mode等于val时,该参数为训练好的模型路径,用于计算指标,通常保存在runs目录下

  4. –data_path: 指定数据集文件路径
    该参数用于提供数据集的路径,对应一个YAML文件

  5. –epoch: 指定训练的轮数
    默认值为200,表示模型的训练轮次

  6. –batch: 指定批次大小
    默认值为16,表示每次训练迭代中所处理的样本数量

  7. –workers: 指定工作线程数
    默认值为8,表示用于数据加载的工作线程数量,windows系统这里改为0

  8. –device: 指定使用的设备
    默认值为0,表示使用的GPU设备编号

  9. –name: 指定保存模型文件夹的名称

二、模型剪枝

1、 对训练好的模型将进行剪枝

运行命令如下

python 3_yolov8_pruning.py --model_path weights/kitti_baseline/weights/best.pt --prune_type l1 --prune_ratio 0.4

运行效果
在这里插入图片描述

运行成功后会输出剪枝后的网络结构,以及剪枝前后模型的参数量对比

2、 剪枝代码详解

1.解析命令行参数

解析命令行参数的,其方便各位在命令行中指定模型路径、剪枝策略以及剪枝比例等参数

# 解析命令行参数
def parse_args():
    parser = argparse.ArgumentParser(description="Prune YOLOv8 model.")
    parser.add_argument("--model_path", type=str,
                        default=r"weights/kitti_baseline/weights/best.pt",
                        help="Path to the YOLOv8 model.")
    parser.add_argument("--prune_type", type=str, default="l2", choices=["l1", "l2", "random"],
                        help="Pruning strategy to use.")
    parser.add_argument("--prune_ratio", type=float, default=0.4, help="Pruning ratio.")
    args = parser.parse_args()
    return args

参数详解:

  1. –model_path: 指定需要剪枝的模型路径
  2. –prune_type: 指定剪枝策略,可选方案为 l1, l2, random,默认使用 l1策略
  3. –prune_ratio: 指定剪枝比例,默认值为0.4,表示对定义的卷积层减掉40%的通道数

2. 定义剪枝函数

用于根据指定的修剪策略和比例对给定的模型进行修剪

def prune_model(model, prune_type, prune_ratio, input_tensor):
    strategy = {
        'l1': tp.strategy.L1Strategy(),
        'l2': tp.strategy.L2Strategy(),
        'random': tp.strategy.RandomStrategy()
    }.get(prune_type, tp.strategy.RandomStrategy())

    dependency_graph = tp.DependencyGraph().build_dependency(model, example_inputs=input_tensor)
    included_layers = get_included_layers(model)

    original_params = tp.utils.count_params(model)
    pruning_plans = [
        dependency_graph.get_pruning_plan(m, tp.prune_conv, idxs=strategy(m.weight, amount=prune_ratio))
        for m in model.modules() if isinstance(m, nn.Conv2d) and m in included_layers
    ]

关键步骤详解:

  1. 策略选择
    根据 prune_type 参数, 选择对应的剪枝策略,如果 prune_type 不是预定义的值, 则默认使用随机剪枝策略

  2. 构建依赖图
    使用 tp.DependencyGraph().build_dependency 函数构建模型的依赖关系图, 以便后续进行剪枝操作

  3. 获取包含的层
    使用 get_included_layers 函数获取需要进行剪枝的层, 即模型中的 nn.Conv2d 层

  4. 计算原始参数数量
    使用 tp.utils.count_params 函数计算模型的原始参数数量

  5. 制定剪枝计划
    对于每个需要剪枝的 nn.Conv2d 层, 使用对应的剪枝策略计算剪枝的索引, 并生成剪枝计划

3. 定义剪枝结构

从指定模型中, 找出所有可以进行剪枝操作的层, 并将它们添加到 included_layers 列表中

def get_included_layers(model):
    included_layers = []

    for layer in model.model:
        if isinstance(layer, Conv):
            included_layers.append(layer.conv)
                ...
                
        if isinstance(layer, Detect):
                ...
    return included_layers

关键模块详解:

  1. model: 指定yolov8模型,函数将遍历这个模型的层来识别可剪枝的部分
  2. included_layers: 用于存储可以进行剪枝操作的层,函数会将这些层添加到这个列表中
  3. 定义模型中不同类型的层,函数会根据层的类型采取不同的处理方式,将可剪枝的部分添加到 included_layers 列表中

4. 更新注意力机制

由于torch_pruning中的某些bug,剪枝后会使注意力机制中某些模块的通道数变为负数,为了确保剪枝后的网络能够正确工作,我们需要更新这些层的通道数

def replace_conv_macayer(original_layer, new_in_channels, new_out_channels):
    # 获取原始层的参数
    original_weight = original_layer.conv.weight.data
    original_bias = original_layer.conv.bias.data if original_layer.conv.bias is not None else None

    # 创建一个新的卷积层
    new_conv_layer = nn.Conv2d(in_channels=new_in_channels, out_channels=new_out_channels,
                               kernel_size=original_layer...)
    # 复制权重
    ...
    return new_conv_layer

关键模块详解:

  1. original_layer: 原始的卷积层。这是一个包含卷积层的对象,通常是一个网络中的某个层。
  2. new_in_channels: 新的输入通道数。
  3. new_out_channels: 新的输出通道数。
  4. 返回值:最终会返回一个新的卷积层,该层具有更新后的输入和输出通道数,并且尽可能保留了原始层的权重和偏置。

5. 保存更新后的模型

**剪枝操作完成后,我们需要将剪枝后的模型保存,以便后续使用 **

# 保存更新后的模型
def save_pruned_model(model, ckpt, prune_type):
    param_dict = {
        'model': model,
        'ema': ckpt['ema'],
         ...
    }
    torch.save(param_dict, f'prune_model_{prune_type}.pt')

参数详解:

  1. model:剪枝后的模型**
  2. ckpt:模型训练状态和相关参数的字典,需要将必要部分写入到剪枝模型中**
  3. prune_type:剪枝类型,用于命名保存的模型文件**

6. 主函数

定义主函数,整合上述各个步骤,实现完整的剪枝流程

def main():
    args = parse_args()
    # 加载模型
    yolov8 = YOLO(args.model_path)
    # 使模型参数可训练
    for para in model.parameters():
        para.requires_grad = True
        
    pruned_model, original_params = prune_model(model, args.prune_type, args.prune_ratio, input_tensor)
    # 更新模型中的注意力层
    update_model_attention_layers(pruned_model)
    # 保存更新后的模型
    save_pruned_model(pruned_model, ckpt, args.prune_type)
    pruned_params = tp.utils.count_params(model)
    percentage_reduction = ((original_params - pruned_params) / original_params) * 100
    logger.info(
        f"Params: {original_params * 4 / 1024 / 1024:.2f} MB => {pruned_params * 4 / 1024 / 1024:.2f} MB (Reduction: {percentage_reduction:.2f}%)")

关键模块解读:
1. parse_args():解析命令行参数。
2. YOLO(args.model_path):加载YOLOv8模型
3. prune_model():执行剪枝操作
4. save_pruned_model():保存剪枝后的模型
5. 计算剪枝前后参数的变化,并打印模型信息和参数减少的百分比

三、剪枝后的训练

运行命令如下

python 4_yolov8-finetune.py --finetune --epochs 200 --batch_size 16

运行效果
在这里插入图片描述
可以看到剪枝后训练不会打印模型在yaml文件中定义的网络结构

四、实验指标对比

如下表:

模型MAPPRImgszParam
YOLOv8n原模型86.796.874.864011.47
YOLOv8n+MCA88.495.778.564011.47
YOLOv8n+MCA+剪枝87.797.276.56405.59
  1. 由此可见在KITTI验证集上,MCA注意力机制+剪枝可以做到在参数量下降51%的情况下涨点1%
  2. 验证集是通过留出法将训练集按9:1比例进行划分所得

五、剪枝前后效果对比

剪枝前:
在这里插入图片描述
剪枝后:
在这里插入图片描述

实验设备为RTX2060,如上图所示 剪枝后的模型FPS更高,推理速度更快

六、完整代码获取

链接:YOLOv8结合MCA注意力机制+通道剪枝-同时实现涨点和轻量化


总结

本期博客就到这里啦,喜欢的小伙伴们可以点点关注,感谢!

最近经常在b站上更新一些有关目标检测的视频,大家感兴趣可以来看看 https://b23.tv/1upjbcG

学习交流群:995760755

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1885383.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Windows 11 安装 安卓子系统 (WSA)

How to Install Windows Subsystem for Android (WSA) on Windows 11 新手教程:如何安装Windows 11 安卓子系统 说明 Windows Subsystem for Android 或 WSA 是由 Hyper-V 提供支持的虚拟机,可在 Windows 11 操作系统上运行 Android 应用程序。虽然它需…

c++读取文件时出现中文乱码

原因:UTF-8格式不支持汉字编码 解决:改成ANSI,因为ANSI编码支持汉字编码

生成式人工智能将如何改变网络可访问性

作者:Matthew Adams 受 Be My Eyes 和 OpenAI 启发的一项实验,尝试使用 ChatGPT 4o 实现网页无障碍 在 Elastic,我们肩负着一项使命,不仅要构建最佳的搜索驱动型 AI 平台,还要确保尽可能多的人喜欢使用该平台。我们相…

深入剖析vLLM:大模型计算加速系列之调度器策略探索

原文: 图解大模型计算加速系列:vLLM源码解析2,调度器策略(Scheduler) 目录 收起 前期提要与本期导览 一、入口函数 二、SequenceGroup 2.1 原生输入 2.2 SequenceGroup的作用 2.3 SequenceGroup的结构 三、add_request()&#xff1a…

[python][Anaconda]使用jupyter打开F盘或其他盘文件

jupyter有一个非常不好的体验,就是不能在界面切换到其他盘来打开文件。 使用它,比较死板的操作是要先进入文件目录,再运行jupyter。 以Windows的Anaconda安装了jupyter lab或jupyter notebook为例。 1,先运行Anaconda Prompt 2&…

儿童房间灯哪个牌子的好?几款儿童房间灯具品牌分享

对于视力正处于发育阶段的儿童而言,台灯已不仅仅是一个简单的照明工具。它不仅驱散夜幕下的阴霾,还能为儿童的眼部保驾护航。一款优质的护眼台灯更是不可或缺的守护者。然而,面对市场上琳琅满目的选择,怎样选出一款合适的护眼台灯…

​Stable Diffusion史上最全插件,已打包整理,12个常用插件你肯定用得上!

还在于有丰富的第三方插件,即我们在安装部署之后安装汉化插件的界面 插件安装方式可以是“可下载->加载扩展列表”,然后从列表选择或搜索插件下载,或直接选择“从网站安装”,填写插件的git仓库地址。一般我们从扩展列表搜索即可…

【Python】已解决:pymssql._pymssql.OperationalError 关于关键字‘distinct’的语法错误

文章目录 一、分析问题背景二、可能出错的原因三、错误代码示例四、正确代码示例五、注意事项 已解决:pymssql._pymssql.OperationalError 关于关键字‘distinct’的语法错误 一、分析问题背景 在使用pymssql库与SQL Server数据库进行交互时,有时会遇到…

WPF在.NET9中的重大更新:Windows 11 主题

在2023年的2月20日,在WPF的讨论区,WPF团队对路线的优先级发起了一次讨论。 对三个事项发起了投票。 第一个是Windows 11 主题 第二个是更新的控件 第三个是可空性注释 最终Windows 11 主题得票最高,WPF团队2023-2024的工作优先级就是Windows…

UE4_材质_水体的反射与折射制作_Ben教程

在这个教程中,将制作水的反射和折射,上个教程,我们主要讲了制作水涟漪(水面波纹)和水滴法线混合,水深计算,我们首先要谈的是反射和产生折射的问题。我们将所有从干扰从场景中分离出去&#xff0…

微信小程序 canvas 处理图片的缩放移动旋转问题

这里使用到了一个插件&#xff0c;canvas-drag&#xff0c;来实现大部分功能的 上效果 直接上代码吧~ wxml <div class"container"><canvas-drag id"canvas-drag" graph"{{graph}}" width"700" height"750" ena…

页面加载503 Service Temporarily Unavailable异常

最近发现网页刷新经常503&#xff0c;加载卡主&#xff0c;刷新页面就正常了。 研究之后发现是页面需要的js文件等加载失败了。 再研究之后发现是nginx配置的问题。 我之前为了解决一个漏洞检测到目标主机可能存在缓慢的HTTP拒绝服务攻击 把nginx的连接设置了很多限制&#…

JSONpath语法怎么用?

JSONPath 可以看作定位目标对象位置的语言&#xff0c;适用于 JSON 文档。 JSONPath 与 JSON 的 关系相当于 XPath 与 XML 的关系&#xff0c; JSONPath 参照 XPath 的路径表达式&#xff0c;提供了描述 JSON 文档层次结构的表达式&#xff0c;通过表达式对目标…

点云处理实操 点云平面拟合

目录 一、什么是平拟合 二、拟合步骤 三、数学原理 1、平面拟合 2、PCA过程 四、代码 一、什么是平拟合 平面拟合是指在三维空间中找到一个平面,使其尽可能接近给定的点云。最小二乘法是一种常用的拟合方法,通过最小化误差平方和来找到最优的拟合平面。 二、拟合步骤…

【Python】已解决:ERROR: No matching distribution found for JPype

文章目录 一、分析问题背景二、可能出错的原因三、错误代码示例四、正确代码示例五、注意事项 已解决&#xff1a;ERROR: No matching distribution found for JPype 一、分析问题背景 在Python开发中&#xff0c;有时我们需要使用Java库来扩展功能或实现某些特定任务。JPype…

有手就行,轻松本地部署 Llama、Qwen 大模型,无需 GPU

用 CPU 也能部署私有化大模型&#xff1f; 对&#xff0c;没错&#xff0c;只要你的电脑有个 8G 内存&#xff0c;你就可以轻松部署 Llama、Gemma、Qwen 等多种开源大模型。 非技术人员&#xff0c;安装 Docker、Docker-compose 很费劲&#xff1f; 不用&#xff0c;这些都不…

方法重载与重写的区别

1.方法重载和重写都是实现多态的方式&#xff0c;区别在于重载是编译时多态&#xff0c;重写是运行时多态。 2.重载是在同一个类中&#xff0c;两个方法的方法名相同&#xff0c;参数列表不同&#xff08;参数类型、顺序、个数&#xff09;&#xff0c;与方法返回值无关&#x…

springboot种草好物app-计算机毕业设计源码13151

摘要 随着电子商务的快速发展和智能手机的普及&#xff0c;越来越多的用户选择通过移动应用程序进行商品浏览、购买和分享体验。种草好物App作为一个专注于商品推荐和购物体验的平台&#xff0c;具有广泛的应用前景和商业价值。本研究旨在构建一个功能丰富、性能稳定的种草好物…

(vue)el-tabs选中最后一项后更新数据后无法展开

(vue)el-tabs选中最后一项后更新数据后无法展开 效果&#xff1a; 原因&#xff1a;选中时绑定的值在数据更新后找不到 思路&#xff1a;更新数据时把选中的v-model的属性赋为初始值 写法&#xff1a; <el-form-item label"字段选择"><el-tabsv-model&qu…

【计算机网络】传输层(作业)

1、OSI参考模型中&#xff0c;提供端到端的透明数据传输服务、差错控制和流量控制的层是&#xff08;C&#xff09;。 A. 物理层B. 网络层C. 运输层D. 会话层 2、运输层为&#xff08;B&#xff09;之间提供逻辑通信。 A. 主机B. 进程C. 路由器D. 操作系统 3、运输层面向连接…