YOLO算法改进5【中阶改进篇】:添加SENet注意力机制

news2025/1/11 1:52:44

在这里插入图片描述

SE-Net 是 ImageNet 2017(ImageNet 收官赛)的冠军模型,是由WMW团队发布。具有复杂度低,参数少和计算量小的优点。且SENet 思路很简单,很容易扩展到已有网络结构如 Inception 和 ResNet 中。
我们可以看到,已经有很多工作在空间维度上来提升网络的性能。那么很自然想到,网络是否可以从其他层面来考虑去提升性能,比如考虑特征通道之间的关系?作者基于这一点并提出了Squeeze-and-Excitation Networks(简称SE-Net)。在该结构中,SqueezeExcitation是两个非常关键的操作,所以以此来命名。作者出发点是希望建立特征通道之间的相互依赖关系。并未引入一个新的空间维度来进行特征通道间的融合,而是采用了一种全新的“特征重标定”策略。具体来说,就是通过学习的方式来自动获取到每个特征通道的重要程度,然后依照这个重要程度去提升有用的特征并抑制对当前任务用处不大的特征。

一、不改变原网络深度的改进方法

在这里插入图片描述
首先是打开models/yolov5s.yaml文件,我们在backbone中的SPPF之前增加SENet。增添位置如下,是将backbone中第4个C3模块替换为SE_Block,如上图。需要注意的是通道数要匹配,SENet并不改变通道数,由于原C3的输出通道数为1024*0.5=512,所以我们这里的写的是1024,这里的1024是传入到上面我们定义的Class SE_Block(nn.Moudel)中的c2参数,c1参数是由上一层的输出通道数控制的。参考链接

1.添加SENet.yaml文件
添加至/models/文件中

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2  conv1(3,32,k=6,s=2,p=2)
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4  conv2(32,64,k=3,s=2,p=1)
   [-1, 3, C3, [128]],  # C3_1 有Bottleneck
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8  conv3(64,128,k=3,s=2,p=1)
   [-1, 6, C3, [256]], # C3_2 Bottleneck重复两次
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16  conv4(128,256,k=3,s=2,p=1)
   [-1, 9, C3, [512]], # C3_3 Bottleneck重复三次 输出256通道
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32   Conv5(256,512,k=3,s=2,p=1)
   #[-1, 3, C3, [1024]],  # C3_4 Bottleneck重复1次  输出512通道
   [-1, 1, SE_Block, [1024]],  # 增加通道注意力机制 输出为512通道
   [-1, 1, SPPF, [1024, 5]],  # 9  每个都是K为5的池化
  ]

# YOLOAir v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 13

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
   [-1, 1, ShuffleAttention, [1024]], # 修改

   [[17, 20, 24], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

2.common配置
在models/common.py文件中增加以下代码
在这里插入图片描述

  • 上图是作者提出的SE模块的示意图。给定一个输入 x x x,其特征通道数为 c 1 c_1 c1,通过一系列卷积变换后得到一个特征通道数为 c 2 c_2 c2的特征。与传统的CNN不一样的是,接下来将通过三个操作来重标定前面得到的特征。
  • 首先是Squeeze操作,顺着空间维度来进行特征压缩,将每个二维的特征通道变成一个实数,这个实数某种程度上具有全局的感受野,并且输出的维度和输入的特征通道数相匹配。它表征着在特征通道上响应的全局分布,而且使得靠近输入的层也可以获得全局的感受野,这一点在很多任务中都是非常有用的。
  • 其次是Excitation操作,它是一个类似于循环神经网络中门的机制。通过参数来为每个特征通道生成权重,其中参数被学习用来显式地建模特征通道间的相关性。
  • 最后是一个Reweight的操作,我们将Excitation的输出的权重看做是进过特征选择后的每个特征通道的重要性,然后通过乘法逐通道加权到先前的特征上,完成在通道维度上的对原始特征的重标定。
    ——————————————————————————————————————————
    图2 SE模块应用举例
  • 这里的注意力机制想法非常简单,即针对每一个 channel 进行池化处理,就得到了 channel
    个元素,通过两个全连接层,得到输出的这个向量。值得注意的是,第一个全连接层的节点个数等于 channel 个数的 1 4 \frac{1}{4} 41论文作者发现如果将第一个全连接层的节点个数替换成原来的 1 4 \frac{1}{4} 41,可以在参数数量适度增加的情况下提高准确性,而且并没有明显的延迟。),然后第二个全连接层的节点就和channel 保持一致。这个得到的输出就相当于对原始的特征矩阵的每个 channel 分析其重要程度,越重要的赋予越大的权重,越不重要的就赋予越小的权重。
  • 就拿上图来说,首先对四个通道进行平均池化得到四个值,然后经过两个全连接层之后得到通道权重的输出。等权重输出以后,则将对应通道的权重乘以原来的特征矩阵就得到了新的特征矩阵,以上便是SE模块的详细实现过程。
class SE_Block(nn.Module):
    def __init__(self, c1, c2):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # 平均池化
        self.fc = nn.Sequential(
            nn.Linear(c1, c2 // 16, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(c2 // 16, c2, bias=False),
            nn.Sigmoid()
        )
 
    def forward(self, x):
        # 添加注意力模块
        b, c, _, _ = x.size()  # 分别获取batch_size,channel
        y = self.avg_pool(x).view(b, c)  # y的shape为【batch_size, channels】
        y = self.fc(y).view(b, c, 1, 1)  # shape为【batch_size, channels, 1, 1】
        out = x * y.expand_as(x)  # shape 为【batch, channels,feature_w, feature_h】
        return out

3.yolo.py配置
找到 models/yolo.py 文件中 parse_model() 类,在列表中添加SE_Block,这样可以获得我们要传入的参数。

if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
                 BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, SE_Block]:

二、改变原网络深度的改进方法

在这里插入图片描述
在这里插入图片描述
比如我要在第一个C3后面加一个SE。yaml的修改如下。接下来稍微麻烦一点了【需要你了解v5的每层结构】,由于我们在backbone中加入了一层,也就是相当于后面的网络与之前相比都往后移动了一层,那么在后面的Concat部分中融合的特征层的索引也会收到影响,因此我们需要的是修改Concat层的from参数。参考链接

1.添加SENet.yaml文件
添加至/models/文件中

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2  conv1(3,32,k=6,s=2,p=2)
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4  conv2(32,64,k=3,s=2,p=1)
   [-1, 3, C3, [128]],  # C3_1 有Bottleneck
   [-1, 1, SE_Block, [128]],  # 增加通道注意力机制 输出为512通道
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8  conv3(64,128,k=3,s=2,p=1)
   [-1, 6, C3, [256]], # C3_2 Bottleneck重复两次
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16  conv4(128,256,k=3,s=2,p=1)
   [-1, 9, C3, [512]], # C3_3 Bottleneck重复三次 输出256通道
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32   Conv5(256,512,k=3,s=2,p=1)
   [-1, 3, C3, [1024]],  # C3_4 Bottleneck重复1次  输出512通道
   [-1, 1, SPPF, [1024, 5]],  # 9  每个都是K为5的池化
  ]
"""可以看到实际就是每个Concat也后面移动一层,因此yaml修改为一下。最终的Detect的from也需要修改。""
# YOLOAir v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],  # conv1(512,256,1,1)
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 7], 1, Concat, [1]],  # cat backbone P4  将C3_3与SPPF出来后的上采样拼接 拼接后的通道为512
   [-1, 3, C3, [512, False]],  # 13  conv(256,256,k=1,s=1)  没有残差边
 
   [-1, 1, Conv, [256, 1, 1]], # conv2(256,128,1,1)
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 5], 1, Concat, [1]],  # cat backbone P3  与C3_2拼接,输出256通道
   [-1, 3, C3, [256, False]],  # 17 (P3/8-small) conv3(128,128,1,1)
 
   [-1, 1, Conv, [256, 3, 2]],# conv4(128,128,3,2,1)
   [[-1, 15], 1, Concat, [1]],  # cat head P4  拼接后256通道
   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)  conv5(256,256,1,1)
 
   [-1, 1, Conv, [512, 3, 2]],# conv6(256,256,3,2,1)
   [[-1, 11], 1, Concat, [1]],  # cat head P5  拼接后是512
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
 
   [[18, 21, 24], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

2.common配置
在models/common.py文件中增加以下代码

class SE_Block(nn.Module):
    def __init__(self, c1, c2):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # 平均池化
        self.fc = nn.Sequential(
            nn.Linear(c1, c2 // 16, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(c2 // 16, c2, bias=False),
            nn.Sigmoid()
        )
 
    def forward(self, x):
        # 添加注意力模块
        b, c, _, _ = x.size()  # 分别获取batch_size,channel
        y = self.avg_pool(x).view(b, c)  # y的shape为【batch_size, channels】
        y = self.fc(y).view(b, c, 1, 1)  # shape为【batch_size, channels, 1, 1】
        out = x * y.expand_as(x)  # shape 为【batch, channels,feature_w, feature_h】
        return out

3.yolo.py配置
找到 models/yolo.py 文件中 parse_model() 类,在列表中添加SE_Block,这样可以获得我们要传入的参数。

if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
                 BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, SE_Block]:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1169003.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

NLP 模型中的偏差和公平性检测

一、说明 近年来,自然语言处理 (NLP) 模型广受欢迎,彻底改变了我们与文本数据交互和分析的方式。这些基于深度学习技术的模型在广泛的应用中表现出了卓越的能力,从聊天机器人和语言翻译到情感分析和文本生成。然而&…

IntelliJ IDEA2023旗舰版和社区版下载安装教程(图解)

🌷🍁 博主猫头虎 带您 Go to New World.✨🍁 🦄 博客首页——猫头虎的博客🎐 🐳《面试题大全专栏》 文章图文并茂🦕生动形象🦖简单易学!欢迎大家来踩踩~🌺 &a…

【C语法学习】14 - 流和文件

文章目录 1 流1.1 概念1.2 优点1.3 分类1.3.1 输入流和输出流1.3.2 标准流和文件流1.3.2.1 标准流1.3.2.2 文件流 1.3.3 文本流和二进制流 4 文件4.1 分类4.2 区别 1 流 1.1 概念 "流"的概念比较抽象,经典C语言书籍《C Primer Plus》中是这样描述的&…

两分钟搞定MySQL安装——极速mysql5.7安装教程

一、下载mysql mysql官网传送带: MySQL :: Download MySQL Community Server 选择好版本后直接下载即可,版本格式为zip格式。 二、安装mysql 1、解压zip安装包 ps:解压缩的路径里面不要出现中文哦! 2、设置配置文件 新建data…

字节10年经验,花3个月熬夜整理的超全面试真题(附答案)

软件测试工程师,和开发工程师相比起来,虽然前期可能不会太深,但是涉及的面还是比较广的。前期面试实习生或者一年左右的岗位,问的也主要是一些基础性的问题比较多。涉及的知识主要有MySQL数据库的使用、Linux操作系统的使用、软件…

一、修改Ubuntu的IP

1、配置虚拟机 三台虚拟机,分别为node1、node2、node3,内存分别为4G、2G、2G,现存最好为(>40G),如下: 2、修改主机名 分别打开三台虚拟机,root用户输入一下命令: no…

too many open files(打开的文件过多)解决办法

我们java服务报java.net.SocketException: 打开的文件过多 由于我们文件服务用的是minio 所以排查思路应该是进入文件服务器查看minio的pid ps -ef |grep minio lsof -p 11956 | wc -l 由此可以看出已经打开数为1950了 所以我们要提升打开文件数(openfile&#…

matlab矩阵的输入

一段时间不操作感觉有些模糊;下面复习一下; 矩阵的数字之间用空格分开,每一行以分号结尾; 空格多几个也是可以识别的; 数字也可以用逗号隔开,只要一行的结尾是分号即可; 空格多输入几个是可…

Rust编程基础核心之所有权(上)

1.什么是所有权? Rust 的核心功能(之一)是 所有权(ownership)。虽然该功能很容易解释,但它对语言的其他部分有着深刻的影响。 所有程序都必须管理其运行时使用计算机内存的方式。一些语言中具有垃圾回收机制&#x…

【李群李代数】【manif 】基于固定信标的2D机器人定位 (Error State Kalman Filter)...

demo演示 运行结果 我们考虑一个机器人在平面上被少量的准时地标或_信标 包围。 机器人以轴向速度和角速度的形式接收控制动作,并且能够测量信标相对于其自身参考系的位置。 机器人位姿 X 在 SE(2) 中,信标位置 b_k 在 R^2 中, | cos th -si…

配置Raspberry自动连接WIFI,在无法查看路由器的校园网情况下使用自己电脑热点

1、开启电脑热点,并共享电脑WLAN2 打开控制面板->网络和Internet->网络连接 选择自己的校园网,我这里是WLAN2,右键属性,如下操作: 如果没有看到 本地连接*10类似的图标 则按如下操作:winx键&#x…

国家统计局教育部各级各类学历教育学生情况数据爬取

教育部数据爬取 1、数据来源2、爬取目标3、网页分析4、爬取与解析5、如何使用Excel打开CSV1、数据来源 国家统计局:http://www.stats.gov.cn/sj/ 教育部:http://www.moe.gov.cn/jyb_sjzl/ 数据来源:国家统计局教育部文献教育统计数据2021年全国基本情况(各级各类学历教育学…

网络协议的基本概念

网络协议的基本概念 随处可见的协议 在计算机网络与信息通信领域里,人们经常提及“协议”一词。互联网中常用的具有代表性的协议有IP、TCP、HTTP等。 “计算机网络体系结构”将这些网络协议进行了系统归纳。TCP/IP就是IP、TCP、HTTP等协议的集合。现在&#xff0…

【MATLAB源码-第67期】基于麻雀搜索算法(SSA)的无人机三维地图路径规划,输出最短路径和适应度曲线。

操作环境: MATLAB 2022a 1、算法描述 ​麻雀搜索算法(Sparrow Search Algorithm, SSA)是一种新颖的元启发式优化算法,它受到麻雀社会行为的启发。这种算法通过模拟麻雀的食物搜索行为和逃避天敌的策略来解决优化问题。SSA通过模…

常用的网站扫描工具

破壳扫目录 7KB扫目录 safe3扫漏洞

怎样做好金融投资翻译

我们知道, 金融投资翻译所需的译文往往是会议文献、年终报表、信贷审批等重要企业金融资料,其准确性事关整个企业在今后一段时期内的发展战略与经营成效。尤其像年报,对于上市公司来说更是至关重要的。那么,怎样做好金融投资翻译&…

【C语言初学者周冲刺计划】5.1C语言知识点小总结

目录 1知识点一: 2知识点二: 3知识点三: 4代码: 5总结: 1知识点一: 1 C语言中要求对变量作强制定义的主要理由是( )。 便于确定类型和分配空间 2 【单选题】若有定义:int m7; float x…

利用稳定扩散快速修复图像

推荐Stable Diffusion自动纹理工具: DreamTexture.js自动纹理化开发包 什么是InPainting? 图像修复是人工智能研究的一个活跃领域,人工智能已经能够提出比大多数艺术家更好的修复效果。 这是一种生成图像的方式,其中缺失的部分已…

【音视频 | Ogg】libogg库详细介绍以及使用——附带libogg库解析.opus文件的C源码

😁博客主页😁:🚀https://blog.csdn.net/wkd_007🚀 🤑博客内容🤑:🍭嵌入式开发、Linux、C语言、C、数据结构、音视频🍭 🤣本文内容🤣&a…

RIP路由配置

RIP路由配置步骤与命令: 1.启用RIP路由:router rip 2.通告直连网络:network 直连网络 3.启用RIPv2版本:version 2 4.禁用自动汇总:no auto-summary 注意:静态路由通告远程网络,动态路由通告…