YOLOv5改进 | 添加ECA注意力机制 + 更换主干网络之ShuffleNetV2

news2025/1/18 9:10:41

前言:Hello大家好,我是小哥谈。本文给大家介绍一种轻量化部署改进方式,即在主干网络中添加ECA注意力机制和更换主干网络之ShuffleNetV2,希望大家学习之后,能够彻底理解其改进流程及方法~!🌈 

     目录

🚀1.基础概念

🚀2.添加位置

🚀3.添加步骤

🚀4.改进方法

💥💥步骤1:common.py文件修改

💥💥步骤2:yolo.py文件修改

💥💥步骤3:创建自定义yaml文件

💥💥步骤4:修改自定义yaml文件

💥💥步骤5:验证是否加入成功

💥💥步骤6:修改默认参数

🚀1.基础概念

ECA注意力机制:

ECA注意力机制是一种用于提升卷积神经网络特征表示能力的方法。它通过嵌入式通道注意力模块,在保持高效性的同时,引入了通道注意力机制。具体来说,ECA注意力机制在通道维度上增加了注意力机制,以提升特征表示的能力。与SE注意力机制不同的是,ECA注意力机制只包含一个操作——excitation,而没有squeeze操作。这使得ECA注意力机制更加轻量级,适用于计算资源有限的场景。

ECA的结构主要分为两个部分:通道注意力模块和嵌入式通道注意力模块。

🍀(1)通道注意力模块

通道注意力模块是ECA的核心组成部分,它的目标是根据通道之间的关系,自适应地调整通道特征的权重。该模块的输入是一个特征图(Feature Map),通过全局平均池化得到每个通道的全局平均值,然后通过一组全连接层来生成通道注意力权重。这些权重被应用于输入特征图的每个通道,从而实现特征图中不同通道的加权组合。最后,通过一个缩放因子对调整后的特征进行归一化,以保持特征的范围。

🍀(2)嵌入式通道注意力模块

嵌入式通道注意力模块是ECA的扩展部分,它将通道注意力机制嵌入到卷积层中,从而在卷积操作中引入通道关系。这种嵌入式设计能够在卷积操作的同时,进行通道注意力的计算,减少了计算成本。具体而言,在卷积操作中,将输入特征图划分为多个子特征图,然后分别对每个子特征图进行卷积操作,并在卷积操作的过程中引入通道注意力。最后,将这些卷积得到的子特征图进行合并,得到最终的输出特征图。

ShuffleNetV2网络:

ShuffleNetV2是一种轻量级的神经网络模型,它是ShuffleNetV1的改进版本。ShuffleNetV2主要采用了两种技术通道分离组卷积。通道分离是指将输入的通道分成两个部分,分别进行不同的计算,然后再将它们合并在一起。这种方法可以减少计算量,提高模型的效率。组卷积是指将卷积操作分成多个小组,每个小组只处理一部分通道,然后再将它们合并在一起。这种方法可以减少参数量,提高模型的泛化能力。


🚀2.添加位置

本文的改进是基于YOLOv5-6.0版本,关于其网络结构具体如下图所示:

本文的改进是在主干网络中添加ECA注意力机制更换主干网络之ShuffleNetV2,具体添加位置如下图所示:

所以,本节课改进后的网络结构图具体如下图所示:


🚀3.添加步骤

针对本文的改进,具体步骤如下所示:👇

步骤1:common.py文件修改

步骤2:yolo.py文件修改

步骤3:创建自定义yaml文件

步骤4:修改自定义yaml文件

步骤5:验证是否加入成功

步骤6:修改默认参数


🚀4.改进方法

💥💥步骤1:common.py文件修改

common.py中添加ECA注意力机制模块ShuffleNetV2模块,所要添加模块的代码如下所示,将其复制粘贴到common.py文件末尾的位置。

ECA注意力机制代码:

# ECA
class ECA(nn.Module):
    """Constructs a ECA module.
    Args:
        channel: Number of channels of the input feature map
        k_size: Adaptive selection of kernel size
    """
    def __init__(self, c1, c2, k_size=3):
        super(ECA, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # feature descriptor on the global spatial information
        y = self.avg_pool(x)
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        # Multi-scale information fusion
        y = self.sigmoid(y)
        return x * y.expand_as(x)

ShuffleNetV2模块代码:

# 更换主干网络之shuffleNetV2
def channel_shuffle(x, groups):
    batchsize, num_channels, height, width = x.data.size()
    channels_per_group = num_channels // groups
    # reshape
    x = x.view(batchsize, groups,
               channels_per_group, height, width)
    x = torch.transpose(x, 1, 2).contiguous()
    # flatten
    x = x.view(batchsize, -1, height, width)
    return x
class CBRM(nn.Module):
    def __init__(self, c1, c2):
        super(CBRM, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(c1, c2, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(c2),
            nn.ReLU(inplace=True),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    def forward(self, x):
        return self.maxpool(self.conv(x))

class ShuffleNetV2(nn.Module):
    def __init__(self, ch_in, ch_out, stride):
        super(ShuffleNetV2, self).__init__()
        if not (1 <= stride <= 2):
            raise ValueError('illegal stride value')
        self.stride = stride
        branch_features = ch_out // 2
        assert (self.stride != 1) or (ch_in == branch_features << 1)
        if self.stride > 1:
            self.branch1 = nn.Sequential(
                self.depthwise_conv(ch_in, ch_in, kernel_size=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(ch_in),

                nn.Conv2d(ch_in, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            )
        self.branch2 = nn.Sequential(
            nn.Conv2d(ch_in if (self.stride > 1) else branch_features,
                      branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        )
    @staticmethod
    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)
    def forward(self, x):
        if self.stride == 1:
            x1, x2 = x.chunk(2, dim=1)  # 按照维度1进行split
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
        out = channel_shuffle(out, 2)
        return out

💥💥步骤2:yolo.py文件修改

首先在yolo.py文件中找到parse_model函数这一行,加入ECACBAMShuffleNetV2。具体如下图所示:

💥💥步骤3:创建自定义yaml文件

models文件夹中复制yolov5s.yaml,粘贴并重命名为yolov5s_ECA_ShuffleNetV2.yaml具体如下图所示:

💥💥步骤4:修改自定义yaml文件

本步骤是修改yolov5s_ECA_ShuffleNetV2.yaml,根据改进后的网络结构图进行修改。

由下面这张图可知,当添加ECA注意力机制和更换主干网络之ShuffleNetV2之后,后面的层数会发生相应的变化,需要修改相关参数。

备注:层数从0开始计算,比如第0层、第1层、第2层......🍉 🍓 🍑 🍈 🍌 🍐  

综上所述,修改后的完整yaml文件如下所示:

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50 # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  # Shuffle_Block: [out, stride]
  [[ -1, 1, CBRM, [ 32 ] ], # 0-P2/4
   [ -1, 1, ShuffleNetV2, [ 128, 2 ] ],  # 1-P3/8
   [ -1, 3, ShuffleNetV2, [ 128, 1 ] ],  # 2
   [ -1, 1, ShuffleNetV2, [ 256, 2 ] ],  # 3-P4/16
   [ -1, 7, ShuffleNetV2, [ 256, 1 ] ],  # 4
   [ -1, 1, ShuffleNetV2, [ 512, 2 ] ],  # 5-P5/32
   [ -1, 3, ShuffleNetV2, [ 512, 1 ] ],  # 6
   [-1, 1, ECA, [512]],  # 7
   [-1, 1, SPPF, [1024, 5]],  # 8
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 12

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 2], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 16 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 13], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 19 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 9], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 22 (P5/32-large)

   [[16, 19, 22], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

💥💥步骤5:验证是否加入成功

yolo.py文件里,将配置改为我们刚才自定义的yolov5s_ECA_ShuffleNetV2.yaml

修改1,位置位于yolo.py文件165行左右,具体如图所示:

修改2,位置位于yolo.py文件363行左右,具体如下图所示:

配置完毕之后,点击“运行”,结果如下图所示:

由运行结果可知,与我们前面更改后的网络结构图相一致,证明添加成功了!✅ 

说明:由上图可以看出,添加ECA注意力机制和更换主干网络之ShuffleNetV2之后,参数量大大减少,所以,该种改进方式适合于轻量化部署。

💥💥步骤6:修改默认参数

train.py文件中找到parse_opt函数,然后将第二行 '--cfg的default改为 'models / yolov5s_ECA_ShuffleNetV2.yaml',然后就可以开始进行训练了。🎈🎈🎈 

结束语:关于更多YOLOv5学习知识,可参考专栏:《YOLOv5:从入门到实战》🍉 🍓 🍑 🍈 🍌 🍐

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1280701.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

分享77个焦点幻灯JS特效,总有一款适合您

分享77个焦点幻灯JS特效&#xff0c;总有一款适合您 77个焦点幻灯JS特效下载链接&#xff1a;百度网盘 请输入提取码 提取码&#xff1a;6666 Python采集代码下载链接&#xff1a;采集代码.zip - 蓝奏云 学习知识费力气&#xff0c;收集整理更不易。知识付费甚欢喜&…

sourceTree的下载和安装

sourceTree的下载和安装 一、概述 SourceTree 是一款免费的 Git 和 Hg 客户端管理工具&#xff0c;支持 Git 项目的创建、克隆、提交、push、pull 和合并等操作。它拥有一个精美简洁的界面&#xff0c;大大简化了开发者与代码库之间的 Git 操作方式&#xff0c;这对于不熟悉 …

WebGL笔记:矩阵缩放的数学原理和实现

矩阵缩放的数学原理 和平移一样&#xff0c;以同样的原理&#xff0c;也可以理解缩放矩阵让向量OA基于原点进行缩放 x方向上缩放&#xff1a;sxy方向上缩放&#xff1a;syz方向上缩放&#xff1a;sz 最终得到向量OB 矩阵缩放的应用 比如我要让顶点在x轴向缩放2&#xff0c;y轴…

SCAU:分期还款(加强版)

分期还款(加强版) Time Limit:1000MS Memory Limit:65535K 题型: 编程题 语言: G;GCC;VC 描述 从银行贷款金额为d&#xff0c;准备每月还款额为p&#xff0c;月利率为r。请编写程序输入这三个数值&#xff0c;计算并输出多少个月能够还清贷款&#xff0c;输出时保留1位小…

java学习part32StringBuffer和StringBuilder

Java中的值传递和引用传递&#xff08;详解&#xff09; - 知乎 (zhihu.com) 146-常用类与基础API-StringBuffer与StringBuilder的源码分析、常用方法_哔哩哔哩_bilibili 1. 2.扩容机制 不够用&#xff1a;长度为 原长度*22&#xff1b;如果还不够&#xff0c;那么就扩容到目…

C++笔试训练day_1

文章目录 选择题编程题 选择题 编程题 #include <iostream> #include <algorithm> #include <vector>using namespace std;int main() {int n 0;cin >> n;vector<int> v;v.resize(3 * n);int x 0;for(int i 0; i < v.size(); i){cin >&…

ES-ELSER 如何在内网中离线导入ES官方的稀疏向量模型(国内网络环境下操作方法)

ES官方训练了稀疏向量模型&#xff0c;用来支持语义检索。&#xff08;目前该模型只支持英文&#xff09; 最好是以离线的方式安装。在线的方式&#xff0c;在国内下载也麻烦&#xff0c;下载速度也慢。还不如用离线的方式。对于一般的生产环境&#xff0c;基本上也是网络隔离的…

初识Linux:保姆级教学,让你一秒记住Linux中的常用指令!

文章目录 前言一、LInux的背景及发展史二、Linux下的基本指令1、ls指令2、pwd指令3、cd指令4、touch指令5、mkdir指令&#xff08;重要&#xff09;6、tree指令7、rmdir指令和rm指令&#xff08;重要&#xff09;8、man指令&#xff08;重要&#xff09;9、cp指令&#xff08;重…

[进程控制]模拟实现命令行解释器shell

文章目录 1.字符串切割函数2.chdir()接口3.模拟实现shell 1.字符串切割函数 2.chdir()接口 3.模拟实现shell 模拟实现的shell下删除: ctrlbackspace模拟实现下table/上下左右箭头无法使用[demo] #include <stdio.h> #include <stdlib.h> #include <string.h&g…

nodejs介绍

nodejs官网支持的各种库api https://nodejs.org/docs/latest-v21.x/api/http.html nodejs包括vp8引擎和内置的基本库如fs,path,http,querystring等&#xff0c;也可以用npm按转第三方库 npm是nodejs环境的包管理工具&#xff0c;可以为这个环境安装卸载各种包。 npm install pk…

Git Bash环境下用perl脚本获取uuid值

在Linux环境下&#xff0c;比如在ubuntu就直接有uuidgen命令直接获取uuid值。在Windows环境下常用的git bash中没有对应的命令&#xff0c;略有不便。这里用脚本写一个uuidgen&#xff0c;模拟Linux环境下的uuidgen命令。 #! /usr/bin/perl use v5.14; use Win32;sub uuidGen {…

frp 配置内网访问

frp介绍 frp 是一个开源、简洁易用、高性能的内网穿透软件&#xff0c;支持 tcp, udp, http, https 等协议。frp 项目官网是 https://github.com/fatedier/frp 下载地址&#xff1a; https://github.com/fatedier/frp/releases frp工作原理 服务端运行&#xff0c;监听一个…

坚鹏:中国工商银行数字化时代银行厅堂客户体验设计培训

中国工商银行围绕“数字生态、数字资产、数字技术、数字基建、数字基因”五维布局&#xff0c;深入推进数字化转型&#xff0c;加快形成体系化、生态化实施路径&#xff0c;促进科技与业务加速融合&#xff0c;以“数字工行”建设推动“GBC”&#xff08;政务、企业、个人&…

目标检测算法改进系列之添加SCConv空间和通道重构卷积

SCConv-空间和通道重构卷积 SCConv&#xff08;空间和通道重构卷积&#xff09;的高效卷积模块&#xff0c;以减少卷积神经网络&#xff08;CNN&#xff09;中的空间和通道冗余。SCConv旨在通过优化特征提取过程&#xff0c;减少计算资源消耗并提高网络性能。该模块包括两个单…

【小沐学Python】Python实现Web服务器(Flask+celery,生产者-消费者)

文章目录 1、简介2、安装和下载2.1 flask2.2 celery2.3 redis 3、功能开发3.1 创建异步任务的方法3.1.1 使用默认的参数3.1.2 指定相关参数3.1.3 自定义Task基类 3.2 调用异步任务的方法3.2.1 app.send_task3.2.2 Task.delay3.2.3 Task.apply_async 3.3 获取任务结果和状态 4、…

Spring cloud - gateway

什么是Spring Cloud Gateway 先去看下官网的解释&#xff1a; This project provides an API Gateway built on top of the Spring Ecosystem, including: Spring 6, Spring Boot 3 and Project Reactor. Spring Cloud Gateway aims to provide a simple, yet effective way t…

javaweb mybatis(手动jar包)

基础&#xff1a;https://blog.csdn.net/qq_67832732/article/details/134764134 条件查询 在映射文件的SQL配置中配置参数 使用parameterType来指定参数类型 使用#{参数名}来接收参数的值 parameterType"string" 表示sql语句需要一个参数&#xff0c;类型为字符…

计算UDP报文CRC校验的总结

概述 因公司项目需求&#xff0c;遇到需要发送带UDP/IP头数据包的功能&#xff0c;经过多次尝试顺利完成&#xff0c;博文记录以备忘。 环境信息 操作系统 ARM64平台的中标麒麟Kylin V10 工具 tcpdump、wireshark、vscode 原理 请查看大佬的博文 UDP伪包头定义&#x…

2023年【G1工业锅炉司炉】考试试题及G1工业锅炉司炉模拟考试题库

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 2023年【G1工业锅炉司炉】考试试题及G1工业锅炉司炉模拟考试题库&#xff0c;包含G1工业锅炉司炉考试试题答案和解析及G1工业锅炉司炉模拟考试题库练习。安全生产模拟考试一点通结合国家G1工业锅炉司炉考试最新大纲及…

【开源】基于Vue和SpringBoot的校园二手交易系统

项目编号&#xff1a; S 009 &#xff0c;文末获取源码。 \color{red}{项目编号&#xff1a;S009&#xff0c;文末获取源码。} 项目编号&#xff1a;S009&#xff0c;文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 数据中心模块2.2 二手商品档案管理模…