YOLOv8 | 有效涨点,添加GAM注意力机制,使用Wise-IoU有效提升目标检测效果(附报错解决技巧,全网独家)

news2025/1/15 6:51:28

 目录

摘要

基本原理

通道注意力机制

空间注意力机制

GAM代码实现 

Wise-IoU 

WIoU代码实现

yaml文件编写

完整代码分享(含多种注意力机制)


摘要

人们已经研究了各种注意力机制来提高各种计算机视觉任务的性能。然而,现有方法忽视了保留通道和空间方面的信息以增强跨维度交互的重要性。因此,我们提出了一种全局注意力机制,通过减少信息减少和放大全局交互表示来提高深度神经网络的性能。引入了具有多层感知器的 3D 排列,用于通道注意以及卷积空间注意子模块。在 CIFAR-100 和 ImageNet-1K 上对所提出的图像分类任务机制的评估表明,我们的方法稳定优于最近使用 ResNet 和轻量级 MobileNet 的几种注意力机制。

基本原理

目标的设计是一种减少信息缩减并放大全局维度交互特征的机制。我们采用 CBAM 的顺序通道空间注意力机制并重新设计子模块。整个过程如图 所示。

GAM结构图
通道注意力机制

通道注意力子模块使用 3D 排列来保留三个维度的信息。然后,它使用两层 MLP(多层感知器)放大跨维度通道空间依赖性。 (MLP是一种编码器-解码器结构,其缩减比为r,与BAM相同。)通道注意子模块如图所示。 

通道注意力子模块
空间注意力机制

在空间注意力子模块中,为了关注空间信息,我们使用两个卷积层进行空间信息融合。我们还使用与 BAM 相同的通道注意子模块的缩减率 r。同时,最大池化会减少信息并产生负面影响。我们删除池化以进一步保留特征图。因此,空间注意力模块有时会显着增加参数的数量。为了防止参数显着增加,我们在 ResNet50 中采用带有通道洗牌的组卷积。没有组卷积的空间注意力子模块如图所示。 

空间注意力子模块
GAM代码实现 
class GAM_Attention(nn.Module):
    def __init__(self, c1, c2, group=True, rate=4):
        super(GAM_Attention, self).__init__()

        self.channel_attention = nn.Sequential(
            nn.Linear(c1, int(c1 / rate)),
            nn.ReLU(inplace=True),
            nn.Linear(int(c1 / rate), c1)
        )

        self.spatial_attention = nn.Sequential(

            nn.Conv2d(c1, c1 // rate, kernel_size=7, padding=3, groups=rate) if group else nn.Conv2d(c1, int(c1 / rate),
                                                                                                     kernel_size=7,
                                                                                                     padding=3),
            nn.BatchNorm2d(int(c1 / rate)),
            nn.ReLU(inplace=True),
            nn.Conv2d(c1 // rate, c2, kernel_size=7, padding=3, groups=rate) if group else nn.Conv2d(int(c1 / rate), c2,
                                                                                                     kernel_size=7,
                                                                                                     padding=3),
            nn.BatchNorm2d(c2)
        )

    def forward(self, x):
        b, c, h, w = x.shape
        x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
        x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
        x_channel_att = x_att_permute.permute(0, 3, 1, 2)
        # x_channel_att=channel_shuffle(x_channel_att,4) #last shuffle
        x = x * x_channel_att

        x_spatial_att = self.spatial_attention(x).sigmoid()
        x_spatial_att = channel_shuffle(x_spatial_att, 4)  # last shuffle
        out = x * x_spatial_att
        # out=channel_shuffle(out,4) #last shuffle
        return out

以上代码添加在 ./ultralytics/nn/modules/conv.py 中

Wise-IoU 

Yolov7提出的损失函数是GIoU(Generalized Intersection over Union),能在更广义的层面上计算IoU(Intersection over Union),但是当两个预测框完全重合时,不能反映出实际情况,此时GIoU就要退化为IoU,并且GIoU对每个预测框与真实框均要计算最小外接框,故损失函数计算及收敛速度受到限制。
为了弥补这种遗憾,改进的网络中使用了WIoU(Wise-IoU)作为损失函数。WIoU v3作为边界框回归损失,包含一种动态非单调机制,并设计了一种合理的梯度增益分配,该策略减少了极端样本中出现的大梯度或有害梯度。该损失方法计算更多地关注普通质量的样本,进而提高网络模型的泛化能力和整体性能。

虽然几种主流损失函数都采用静态聚焦机制,但WIoU不仅考虑了方位角、质心距离和重叠面积,还引入了动态非单调聚焦机制。 WIoU应用合理的梯度增益分配策略来评估锚框的质量。WIoU有三个版本。 WIoU v1 设计了基于注意力的预测框损失,WIoU v2 和 WIoU v3 添加了聚焦系数。

wiou原理图

最小的包围盒(绿色)和中心点的连接(红色),其中并集的面积为 Su = wh + wgthgt − WiHi .

WIoU代码实现
def WIoU(cls, pred, target, self=None):
        self = self if self else cls(pred, target)
        dist = torch.exp(self.l2_center / self.l2_box.detach())
        return self._scaled_loss(dist * self.iou)

 下面的代码替换loss.py的class BboxLoss

class BboxLoss(nn.Module):

    def __init__(self, reg_max, use_dfl=False):
        """Initialize the BboxLoss module with regularization maximum and DFL settings."""
        super().__init__()
        self.reg_max = reg_max
        self.use_dfl = use_dfl

    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
        """IoU loss."""
        weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
        loss,iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False,type_='WIoU')
        loss_iou=loss.sum()/target_scores_sum

        # DFL loss
        if self.use_dfl:
            target_ltrb = bbox2dist(anchor_points, target_bboxes, self.reg_max)
            loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weight
            loss_dfl = loss_dfl.sum() / target_scores_sum
        else:
            loss_dfl = torch.tensor(0.0).to(pred_dist.device)

        return loss_iou, loss_dfl
yaml文件编写
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 1  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 3, GAM_Attention, [1024]]
  - [-1, 1, SPPF, [1024, 5]]  # 10

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 13
  #- [-1, 1, GAM_Attention, [512,512]]

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 16 (P3/8-small)
  #- [-1, 1, GAM_Attention, [256,256]]

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 19 (P4/16-medium)
  #- [-1, 1, GAM_Attention, [512,512]]

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 22 (P5/32-large)
  #- [-1, 1, GAM_Attention, [1024,1024]]

  - [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)
完整代码分享(含多种注意力机制)

内涵SA,CBAM,GAM,ECA等多种注意力机制

链接: https://pan.baidu.com/s/1T9bVifTPCRMv2t7eREsuEw?pwd=nbrt 提取码: nbrt 

报错解决办法

YOLOv8 | 添加注意力机制报错KeyError:已解决,详细步骤-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1519486.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【SQL Server】实验六 数据安全性

1 实验目的 掌握用户管理的基本方法,包括创建用户、删除用户和设置用户密码。掌握用户授权和回收权限的基本方法。掌握系统级权限和对象级权限的授权和回收方法掌握角色的使用方法 2 实验内容 2.1 掌握用户管理的基本使用方法 创建用户(带密码&#…

vue3项目随笔1

1,Eslint Prettier 报错情况: 解决办法: (1)下载Prettier - code formatter (2)配置setting.json文件 文件 -> 首选项 -> 设置 -> 用户 -> Eslint "editor.defaultFormatter":…

C语言函数—递归理解和练习

练习: 编写函数不允许创建临时变量,求字符串的长度。 我们看到这道题,第一个想到的是不是strlen int main() {char[] "bit";//[b][i][t][\0]//里面一共4个字符(包括结尾的、0)但是我们的strlen函数并不会计…

谷歌网络营销方案有几种?​

谷歌作为海外的头部工具,本身其实就有多种工具可以供你使用,在这里说说谷歌那些工具 Google My Business,对于小企业或者本地服务来说,把自己的业务信息优化并完善在Google My Business上是个不错的选择。这样当人们在附近搜索相…

可视化场景(4):财务场景,公司经营的晴雨表。

在财务场景中,可视化大屏具有以下8个应用价值: 销售和收入分析 可视化大屏可以展示销售额、收入来源、销售渠道等数据,帮助财务团队分析销售趋势和收入结构,发现潜在的增长机会和问题。 成本和费用管理 可视化大屏可以显示成本…

蓝桥杯2022年第十三届省赛真题-裁纸刀

443 对于m行n列 次数 4 m - 1 (n-1)*m 其中4是裁掉边缘;行需要裁m-1次;每个小长条需要裁n-1次,一共有m个小长条

MFMailComposeViewController 发送邮件

通过 MFMailComposeViewController 发送邮件,需预先登录邮箱账号的情况下; 具体实现与配置参数请参考如下: 首先,引入 MFMailComposeViewController 库 #import <MessageUI/MessageUI.h> 其次,实现相关 api 方法 if ([MFMailComposeViewController canSendMail]) {MFM…

通过spring boot/redis/aspect 防止表单重复提交【防抖】

一、啥是防抖 所谓防抖&#xff0c;一是防用户手抖&#xff0c;二是防网络抖动。在Web系统中&#xff0c;表单提交是一个非常常见的功能&#xff0c;如果不加控制&#xff0c;容易因为用户的误操作或网络延迟导致同一请求被发送多次&#xff0c;进而生成重复的数据记录。要针…

解决ubuntu 22.04新内核6.5.0-15无法编译NVIDIA显卡驱动

这里的新内核应该包括6.5.*系列的 文章目录 遇到的问题&#xff1a; 遇到的问题&#xff1a; 今天我在安装NVIDIA显卡驱动发现了一个问题&#xff0c;主要日志如下所示&#xff1a; make[3]: *** [scripts/Makefile.build:251: /tmp/selfgz1310041/NVIDIA-Linux-x86_64-550.5…

综合利用Cisco Packet Tracer模拟器配置园区网

1. 内容 1.在课室交换机中创建各个课室的VLAN&#xff0c;并将1-20端口平均分配给各个课室。 2.使用课室交换机的每个端口只能接入一台计算机&#xff0c;发现违规就丢弃未定义地址的包。3.网络内部使用DHCP分配各课室的IP地址&#xff0c;在课室交换机按照第一题划分的VLAN地…

蜡烛图K线图采用PictureBox控件绘制是实现量化交易的第一步非python量化

用vb6.0开发的量化交易软件 VB6量化交易软件的演示视频演示如上 股票软件中的蜡烛图是非常重要的一个东西&#xff0c;这里用VB6.0自带的Picture1控件的Line方法就可以实现绘制。 关于PictureBox 中的line 用法 msdn 上的说明为如下所示 object.Line [Step] …

C#使用迭代算法计算斐波那契数列通项

目录 1.斐波纳契数列 2.迭代一次产生1个新的通项 3.迭代一次产生2个新的通项 1.斐波纳契数列 斐波纳契数列的定义是&#xff0c;它的第一项和第二项均为1&#xff0c;以后各项都为前两项之和。 公式如下&#xff1a; F(n) F(n-1) F(n-2) 其中&#xff0c;F(1) 0,…

CTP-API开发系列之十:v6.7.0-Python版封装(Windows/Linux)(附源码)

CTP-API开发系列之十&#xff1a;v6.7.0-Python版封装&#xff08;Windows/Linux&#xff09;&#xff08;附源码&#xff09; CTP-API开发系列之十&#xff1a;v6.7.0-Python版封装&#xff08;Windows/Linux&#xff09;&#xff08;附源码&#xff09;资源获取准备工作Windo…

实验2 芯片测试算法设计

一、【实验目的】 &#xff08;1&#xff09;理解分治策略的设计思想&#xff1b; &#xff08;2&#xff09;熟悉将伪码转换为可运行的程序的方法&#xff1b; &#xff08;3&#xff09;能够根据算法的要求设计具体的实例。 二、【实验内容】 有n片芯片&#xff0c;其中好芯片…

蓝桥杯每日一题:血色先锋队

今天浅浅复习巩固一下bfs 答案&#xff1a; #include<iostream> #include<algorithm> #include<cstring>using namespace std; typedef pair<int,int> PII;const int N510; int n,m,a,b; int dist[N][N]; PII q[N*N]; int hh0,tt-1;int dx[]{1,0,-1,…

蓝桥杯[OJ 1621]挑选子串-CPP-双指针

目录 一、题目描述&#xff1a; 二、整体思路&#xff1a; 三、代码&#xff1a; 一、题目描述&#xff1a; 二、整体思路&#xff1a; 要找子串&#xff0c;则必须找头找尾&#xff0c;找头可以遍历连续字串&#xff0c;找尾则是要从头的基础上往后遍历&#xff0c;可以设头…

【spring】@Import 注解学习

Import 介绍 Import 是 Spring 框架中的一个注解&#xff0c;用于导入配置类或组件。它可以将一个或多个配置类或组件导入到当前的配置类或组件中&#xff0c;从而实现配置的复用和组合。 在Spring Boot应用中&#xff0c;Import注解可以帮助我们更加灵活地组织和管理配置类。…

(学习日记)2024.03.09:UCOSIII第十一节:就绪列表

写在前面&#xff1a; 由于时间的不足与学习的碎片化&#xff0c;写博客变得有些奢侈。 但是对于记录学习&#xff08;忘了以后能快速复习&#xff09;的渴望一天天变得强烈。 既然如此 不如以天为单位&#xff0c;以时间为顺序&#xff0c;仅仅将博客当做一个知识学习的目录&a…

正点原子精英版TFTLCD代码移植

&#xff08;1&#xff09;将lcd.c和lcd.h加入到HEADWARE文件中 &#xff08;2&#xff09;将lcd.c加入到环境中 选择lcd.c即可。 &#xff08;3&#xff09;在FWLib中添加stm32f10x_fsmc.c

Spring Boot整合canal实现数据一致性解决方案解析-部署+实战

&#x1f3f7;️个人主页&#xff1a;牵着猫散步的鼠鼠 &#x1f3f7;️系列专栏&#xff1a;Java全栈-专栏 &#x1f3f7;️个人学习笔记&#xff0c;若有缺误&#xff0c;欢迎评论区指正 目录 1.前言 2.canal部署安装 3.Spring Boot整合canal 3.1数据库与缓存一致性问题…