YOLOv5改进 | 注意力机制 | 添加全局注意力机制 GcNet【附代码+小白必备】

news2025/1/11 2:40:49

💡💡💡本专栏所有程序均经过测试,可成功执行💡💡💡

非局部网络通过将特定于查询的全局上下文聚合到每个查询位置,为捕获长距离依赖关系提供了一种开创性的方法。然而,通过实验证明,非局部网络建模的全局上下文在图像内不同的查询位置几乎是相同的。因此,利用这一发现创建了一个简化的网络。在本文中,给大家带来的教程是将原来的的网络替换添加GcNet。文章在介绍主要的原理后,将手把手教学如何进行模块的代码添加和修改,并将修改后的完整代码放在文章的最后,方便大家一键运行小白也可轻松上手实践。以帮助您更好地学习深度学习目标检测YOLO系列的挑战。

专栏地址 YOLOv5改进+入门——持续更新各种有效涨点方法 点击即可跳转

目录

1.原理

2. GcNet代码实现

2.1 将GcNet添加到YOLOv5中

2.2 新增yaml文件

2.3 注册模块

2.4 执行程序

3. 完整代码分享

4.GFLOPs

5. 进阶

6. 总结


1.原理

论文地址:GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond点击即可跳转

官方代码:官方代码仓库——点击即可跳转

GCNet(Global Context Network)是一种用于计算机视觉任务的注意力机制模型。该模型旨在提高深度神经网络对全局上下文信息的理解和利用,以改善其在各种视觉任务中的性能。

GCNet 主要解决的问题是,传统的卷积神经网络(CNN)在处理图像时,往往只关注局部信息,而忽略了图像的全局上下文信息。GCNet 引入了全局上下文注意力机制,允许网络在处理图像时动态地调整不同位置的权重,以更好地捕捉全局信息。

GCNet 的主要组成部分:

  1. 全局上下文模块:GCNet 包括一个全局上下文模块,用于从整个图像中提取全局信息。这个模块通常由全局平均池化层(Global Average Pooling)组成,用于将整个特征图压缩成一个全局特征向量。

  2. 通道注意力机制:GCNet 使用通道注意力机制来动态地调整特征图中每个通道的权重,以更好地捕获图像中不同通道的重要性。这有助于网络更加聚焦于对解决当前任务最重要的特征。

  3. 空间注意力机制:除了通道注意力,GCNet 还引入了空间注意力机制,以考虑不同位置之间的关系。这个机制通常通过使用卷积操作来实现,以便网络可以学习到图像中不同位置之间的依赖关系。

  4. 反馈机制:GCNet 通常具有反馈机制,允许网络根据任务的需求动态地调整注意力权重。这种反馈机制通常通过在训练过程中引入反向传播来实现,网络可以根据任务的损失来调整注意力权重。

总的来说,GCNet 通过引入全局上下文注意力机制,允许网络更好地理解和利用图像的全局信息,从而提高了在各种计算机视觉任务中的性能。它的核心思想是通过动态地调整注意力权重来关注图像中最相关的信息,从而提高了网络的表现。

2. GcNet代码实现

2.1 将GcNet添加到YOLOv5中

关键步骤一: 将下面代码粘贴到/projects/yolov5-6.1/models/common.py文件中

img

import torch
from torch import nn

class ContextBlock(nn.Module):
    def __init__(self,inplanes,ratio,pooling_type='att',
                 fusion_types=('channel_add', )):
        super(ContextBlock, self).__init__()
        valid_fusion_types = ['channel_add', 'channel_mul']

        assert pooling_type in ['avg', 'att']
        assert isinstance(fusion_types, (list, tuple))
        assert all([f in valid_fusion_types for f in fusion_types])
        assert len(fusion_types) > 0, 'at least one fusion should be used'

        self.inplanes = inplanes
        self.ratio = ratio
        self.planes = int(inplanes * ratio)
        self.pooling_type = pooling_type
        self.fusion_types = fusion_types

        if pooling_type == 'att':
            self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
            self.softmax = nn.Softmax(dim=2)
        else:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
        if 'channel_add' in fusion_types:
            self.channel_add_conv = nn.Sequential(
                nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),  # yapf: disable
                nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
        else:
            self.channel_add_conv = None
        if 'channel_mul' in fusion_types:
            self.channel_mul_conv = nn.Sequential(
                nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),  # yapf: disable
                nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
        else:
            self.channel_mul_conv = None


    def spatial_pool(self, x):
        batch, channel, height, width = x.size()
        if self.pooling_type == 'att':
            input_x = x
            # [N, C, H * W]
            input_x = input_x.view(batch, channel, height * width)
            # [N, 1, C, H * W]
            input_x = input_x.unsqueeze(1)
            # [N, 1, H, W]
            context_mask = self.conv_mask(x)
            # [N, 1, H * W]
            context_mask = context_mask.view(batch, 1, height * width)
            # [N, 1, H * W]
            context_mask = self.softmax(context_mask)
            # [N, 1, H * W, 1]
            context_mask = context_mask.unsqueeze(-1)
            # [N, 1, C, 1]
            context = torch.matmul(input_x, context_mask)
            # [N, C, 1, 1]
            context = context.view(batch, channel, 1, 1)
        else:
            # [N, C, 1, 1]
            context = self.avg_pool(x)
        return context

    def forward(self, x):
        # [N, C, 1, 1]
        context = self.spatial_pool(x)
        out = x
        if self.channel_mul_conv is not None:
            # [N, C, 1, 1]
            channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
            out = out * channel_mul_term
        if self.channel_add_conv is not None:
            # [N, C, 1, 1]
            channel_add_term = self.channel_add_conv(context)
            out = out + channel_add_term
        return out

GCNet的主要流程可以概括为以下几个步骤:

  1. 输入图像的特征提取

    • 首先,将输入图像通过一个预训练的卷积神经网络(如ResNet、VGG等)进行特征提取,得到一个特征图。

  2. 全局上下文模块

    • 将特征图输入到全局上下文模块中。

    • 在全局上下文模块中,使用全局平均池化(Global Average Pooling)对整个特征图进行操作,得到一个全局上下文特征向量。

  3. 通道注意力机制

    • 将全局上下文特征向量与原始特征图结合起来。

    • 对结合后的特征进行通道注意力机制的处理,动态地调整特征图中每个通道的权重,以便网络能够自适应地关注对解决当前任务最重要的特征通道。

  4. 空间注意力机制

    • 在通道注意力机制之后,还可以引入空间注意力机制。

    • 通过卷积操作,在特征图的空间维度上学习到不同位置之间的依赖关系,从而更好地捕捉图像中的结构信息。

  5. 特征融合和输出

    • 将经过通道注意力和空间注意力机制处理后的特征图输入到后续的网络层中,可能是全连接层或者其他类型的层。

    • 最终,通过输出层生成网络的最终预测结果,如分类标签、图像分割等。

在整个流程中,GCNet通过全局上下文模块和注意力机制,使网络能够更好地理解和利用图像中的全局信息,并自适应地关注对解决当前任务最重要的特征通道和位置信息,从而提高了网络在各种图像处理任务中的性能和效果。

2.2 新增yaml文件

关键步骤二在下/projects/yolov5-6.1/models下新建文件 yolov5_GcNet.yaml并将下面代码复制进去

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, ContextBlock, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 7], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 13

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 5], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 15], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 11], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)

   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

温馨提示:本文只是对yolov5l基础上添加swin模块,如果要对yolov8n/l/m/x进行添加则只需要指定对应的depth_multiple 和 width_multiple。


# YOLOv5n
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.25  # layer channel multiple
 
# YOLOv5s
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
 
# YOLOv5l 
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple
 
# YOLOv5m
depth_multiple: 0.67  # model depth multiple
width_multiple: 0.75  # layer channel multiple
 
# YOLOv5x
depth_multiple: 1.33  # model depth multiple
width_multiple: 1.25  # layer channel multiple

2.3 注册模块

关键步骤:在yolo.py中注册, 大概在260行左右添加 ‘ContextBlock’

2.4 执行程序

在train.py中,将cfg的参数路径设置为yolov5_GcNet.yaml的路径

建议大家写绝对路径,确保一定能找到

 🚀运行程序,如果出现下面的内容则说明添加成功🚀

3. 完整代码分享

https://pan.baidu.com/s/1Dl__hrt7bBe_UnELz7_o4g?pwd=d1ym

提取码: d1ym 

4.GFLOPs

关于GFLOPs的计算方式可以查看百面算法工程师 | 卷积基础知识——Convolution

未改进的GFLOPs

改进后的GFLOPs  

5. 进阶

你能在不同的位置添加全局注意力机制吗?但这可能会产生非常大的计算量

6. 总结

GCNet(Global Context Network)是一种引入全局上下文信息和注意力机制的图像处理模型。其流程首先通过全局上下文模块提取整个图像的全局信息,然后通过通道注意力机制动态调整特征图中每个通道的权重,使网络能够自适应地关注对解决当前任务最重要的特征通道。接着,GCNet引入空间注意力机制,学习不同位置之间的依赖关系,以更好地捕捉图像中的结构信息。通过全局上下文模块和注意力机制的综合利用,GCNet使网络能够更充分地理解和利用图像中的全局信息,从而提高了在各种图像处理任务中的性能和效果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1707504.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Android 13 高通设备热点低功耗模式

需求: Android设备开启热点,使Iphone设备连接,自动开启低数据模式 低数据模式: 低数据模式是一种在移动网络或Wi-Fi环境下,通过限制应用程序的数据使用、降低数据传输速率或禁用某些后台操作来减少数据流量消耗的优化模式。 这种模式主要用于节省数据流量费用,特别是…

Github Page 部署失败

添加 .gitmodules 文件 [submodule "themes/ayer"]path themes/ayerurl https://github.com/Shen-Yu/hexo-theme-ayer.git 添加 .nojekyll 文件

使用 Orange Pi AIpro开发板基于 YOLOv8 进行USB 摄像头实时目标检测

文章大纲 简介算力指标与概念香橙派 AIpro NPU 纸面算力直观了解 手把手教你开机与基本配置开机存储挂载设置风扇设置 使用 Orange Pi AIpro进行YOLOv8 目标检测Pytorch pt 格式直接推理NCNN 格式推理 是否可以使用Orange Pi AIpro 的 NPU 进行推理 呢?模型开发流程…

vue 微信公众号定时发送模版消息

目录 第一步:公众号设置 网页授权第二步:引导用户去授权页面并获取code第三步:通过code换取网页授权access_token&openid第四步:后端处理绑定用户和发送消息 相关文档链接: 1、微信开发文档 2、订阅号/服务号/企业…

AI生成视频解决方案,降低成本,提高效率

传统的视频制作方式往往受限于高昂的成本、复杂的拍摄流程以及硬件设备的限制,为了解决这些问题,美摄科技凭借领先的AI技术,推出了全新的AI生成视频解决方案,为企业带来前所未有的视觉创新体验。 一、超越想象的AI视频生成 美摄…

【计算机视觉(4)】

基于Python的OpenCV基础入门——色彩空间转换 色彩空间简介HSV色彩空间GRAY色彩空间色彩空间转换 色彩空间转换代码实现: 色彩空间简介 色彩空间是人们为了表示不同频率的光线的色彩而建立的多种色彩模型。常见的色彩空间有RGB、HSV、HIS、YCrCb、YUV、GRAY,其中最…

Sora,数据驱动的物理引擎

文生视频技术 Text-to-Video 近日,Open AI发布文生视频模型Sora,能够生成一分钟高保真视频。人们惊呼:“真实世界将不再存在。” Open AI自称Sora是“世界模拟器”,让“一句话生成视频”的AI技术向上突破了一大截,引…

数据恢复与取证软件: WinHex 与 X-Ways Forensics 不同许可证功能区别

天津鸿萌科贸发展有限公司从事数据安全业务20余年,在数据恢复、数据取证、数据备份等领域有丰富的案例经验、专业技术及良好的行业口碑。同时,公司面向取证机构及数据恢复公司,提供数据恢复实验室建设方案,包含数据恢复硬件设备及…

外贸仓库管理软件:海外仓效率大幅度提升、避免劳动力积压

随着外贸业务的不断发展,如何高效管理外贸仓库,确保货物顺利流转,订单顺利处理,就变得非常重要。 现在通常的解决方案都是通过引入外贸仓库管理软件,也就是我们常说的海外仓WMS系统来解决。 今天我们就系统的探讨一下…

闲话 .NET(3):.NET Framework 的缺点

前言 2016 年,微软正式推出 .NET Core 1.0,并在 2019 年全面停止 .NET Framework 的更新。 .NET Core 并不是 .NET Framework 的升级版,而是一个从头开始开发的全新平台,一个跟 .NET Framework 截然不同的开源技术框架。 微软为…

Paddle 稀疏计算 使用指南

Paddle 稀疏计算 使用指南 1. 稀疏格式介绍 1.1 稀疏格式介绍 稀疏矩阵是一种特殊的矩阵,其中绝大多数元素为0。与密集矩阵相比,稀疏矩阵可以节省大量存储空间,并提高计算效率。 例如,一个5x5的矩阵中只有3个非零元素: impor…

CTFHUB技能树——SSRF(一)

目录 一、SSRF(服务器端请求伪造) 漏洞产生原理: 漏洞一般存在于 产生SSRF漏洞的函数(PHP): 发现SSRF漏洞时: SSRF危害: SSRF漏洞利用手段: SSRF绕过方法: 二、CTFHUB技能树 SSRF 1.Ht…

线上服务突然变慢,卡了很久都出不来

文章目录 0、架构1、现象2、查看服务器指标2.1 cpu负载不高2.2 内存指标2.3 硬盘指标2.4 错误日志2.5 大量的tcp连接为TIME_WAIT 3、总结 0、架构 nginx—>httpd—>postgres 单体服务 1、现象 进入页面非常慢。。。 2、查看服务器指标 2.1 cpu负载不高 如下图&…

vue3学习(三)

前言 继续接上一篇笔记,继续学习的vue的组件化知识,可能需要分2个小节记录。前端大佬请忽略,也可以留下大家的鼓励,感恩! 一、理解组件化 二、组件化知识 1、先上知识点: 2、示例代码 App.vue (主页面) …

Captura完全免费的电脑录屏软件

一、简介 1、Captura 是一款免费开源的电脑录屏软件,允许用户捕捉电脑屏幕上的任意区域、窗口、甚至是全屏画面,并将这些画面录制为视频文件。这款软件具有多种功能,例如可以设置是否显示鼠标、记录鼠标点击、键盘按键、计时器以及声音等。此…

BookxNote Pro 宝藏 PDF 笔记软件

一、简介 1、BookxNote Pro 是一款专为电子书阅读和学习笔记设计的软件,支持多种电子书格式,如PDF和EPUB,能够帮助用户高效地管理和阅读电子书籍,同时具备强大的笔记功能,允许用户对书籍内容进行标注、摘录和思维导图绘…

解锁数据奥秘,SPSS for Mac/WIN助您智赢未来

在信息爆炸的时代,数据已成为推动社会进步和企业发展的核心动力。但如何将这些海量数据转化为有价值的洞见,却是摆在每一位决策者面前的难题。IBM SPSS Statistics,一款专业的统计分析软件,凭借其强大的功能和易用的界面&#xff…

机械产品3d模型网站让您的展示内容更加易于分享和传播

为助力企业3D产品演示网站获得更多曝光和展示特效,华锐视点3D云展平台提供强大的3D编辑引擎,以逼真的渲染效果,让您的模型展示更加生动逼真。让客户也能轻松操作的3D产品演示网站搭建编辑器,引领3D展示的新潮流。 3D产品演示网站搭…

总结 HTTP 协议的基本格式

一、HTTP 是什么 HTTP ( 全称为 " 超文本传输协议 ") 是一种应用非常广泛的 应用层协议 . HTTP 诞生与 1991 年 . 目前已经发展为最主流使用的一种应用层协议 . HTTP 协议目前有三个大版本: HTTP / 1 和 HTTP / 2 都是基于TCP 传输控制协议传输数据。最新版本的…

14.Redis之JAVASpring客户端

1.引入依赖 此时就会引入操作 redis 的依赖了~~ 2.yml配置 spring:redis:host: 127.0.0.1port: 8888 3.准备 前面使用 jedis,是通过 Jedis 对象里的各种方法来操作 redis 的.此处Spring 中则是通过 StringRedisTemplate 来操作 redis .最原始提供的类是 RedisTemplateStrin…