YOLOv5改进 | 主干网络 | 在backbone添加Swin-Transformer层

news2024/10/6 6:51:18

尽管Ultralytics 推出了最新版本的 YOLOv8 模型。但YOLOv5作为一个anchor base的目标检测的算法,YOLOv5可能比YOLOv8的效果更好。注意力机制是提高模型性能最热门的方法之一,本文给大家带来的教程是添加Swin-Transformer到backbone中。文章在介绍主要的原理后,将手把手教学如何进行模块的代码添加和修改,并将修改后的完整代码放在文章的最后,方便大家一键运行,小白也可轻松上手实践。以帮助您更好地学习深度学习目标检测YOLO系列的挑战。


专栏地址YOLOv5改进+入门——持续更新各种有效涨点方法 

目录

 1.原理

 2.Swin-Transformer代码

2.1 添加Swin-Transformer代码

2.2 新增yaml文件

2.3 注册模块

2.4 执行程序

 3.总结


 1.原理

论文地址:Swin-Transformer点击即可跳转

官方代码:Swin-Transformer官方代码仓库点击即可跳转

Swin-Transformer是MRA的作品,而MRA撑起了深度学习的半边天。

Swin-Transformer是2021年微软研究院发表在ICCV上的一篇文章,并且已经获得ICCV 2021 best paper 的荣誉称号。Swin Transformer网络是Transformer模型在视觉领域的又一次碰撞。该论文一经发表就已在多项视觉任务中霸榜。

Swin Transformer使用了类似卷积神经网络中的层次化构建方法(Hierarchical feature maps),比如特征图尺寸中有对图像下采样4倍的,8倍的以及16倍的,这样的backbone有助于在此基础上构建目标检测,实例分割等任务。而在之前的Vision Transformer中是一开始就直接下采样16倍,后面的特征图也是维持这个下采样率不变。

在Swin Transformer中使用了Windows Multi-Head Self-Attention(W-MSA)的概念,比如在下图的4倍下采样和8倍下采样中,将特征图划分成了多个不相交的区域(Window),并且Multi-Head Self-Attention只在每个窗口(Window)内进行。相对于Vision Transformer中直接对整个特征图进行Multi-Head Self-Attention,这样做的目的是能够减少计算量的,尤其是在浅层特征图很大的时候。这样做虽然减少了计算量但也会隔绝不同窗口之间的信息传递,所以在论文中作者又提出了 Shifted Windows Multi-Head Self-Attention(SW-MSA)的概念,通过此方法能够让信息在相邻的窗口中进行传递。

因此,Swin-Transformer在计算效率上相对于传统的 Transformer 架构具有优势。Swin Transformer 是一种高效且性能优越的深度学习模型,适用于图像分类、目标检测等视觉任务,并且在处理大规模图像数据时表现出色。

注意:因为涉及代码较多,比较冗长,因此在此处不在放置完整代码,只放关键代码,完整代码可以查看文章末尾的内容。

2.Swin-Transformer代码

2.1 添加Swin-Transformer代码

关键步骤一:在\yolov5-6.1\models\common.py中添加下面代码

class SwinStage(nn.Module):

    def __init__(self, dim, c2, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, use_checkpoint=False):
        super().__init__()
        assert dim==c2, r"no. in/out channel should be same"
        self.dim = dim
        self.depth = depth
        self.window_size = window_size
        self.use_checkpoint = use_checkpoint
        self.shift_size = window_size // 2

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else self.shift_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer)
            for i in range(depth)])


    def create_mask(self, x, H, W):
        # calculate attention mask for SW-MSA
        # 保证Hp和Wp是window_size的整数倍
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        # 拥有和feature map一样的通道排列顺序,方便后续window_partition
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # [1, Hp, Wp, 1]
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # [nW, Mh, Mw, 1]
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)  # [nW, Mh*Mw]
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1]
        # [nW, Mh*Mw, Mh*Mw]
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def forward(self, x):
        B, C, H, W = x.shape
        x = x.permute(0, 2, 3, 1).contiguous().view(B, H*W, C)
        attn_mask = self.create_mask(x, H, W)  # [nW, Mh*Mw, Mh*Mw]
        for blk in self.blocks:
            blk.H, blk.W = H, W
            if not torch.jit.is_scripting() and self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask)

        x = x.view(B, H, W, C)
        x = x.permute(0, 3, 1, 2).contiguous()

        return x

Swin-Transformer模型处理流程主要包括以下几个步骤:

1. 图像分割和块划分:将输入图像分割成多个块,每个块都会被送入模型进行处理。

2. 块级特征提取:每个块经过块级特征提取阶段,使用局部注意力机制和全局注意力机制提取块级特征。

3. 块级特征整合:整合来自不同块的特征,通过跨块注意力机制实现特征的交互和整合。

4. 多层交叉处理:通过多层的交叉处理,增强模型对图像特征的表示能力。

5. 特征重组:将整合后的特征重新组合成全局特征表示。

6. 分类/回归:利用全局特征表示进行图像分类或其他任务,如对象检测或语义分割。

整个流程利用了局部和全局的位置信息,有效地捕获了图像中的不同位置和层次的信息,从而提高了模型在图像处理任务上的性能。

2.2 新增yaml文件

关键步骤二:在 /yolov5/models/ 下新建文件 yolov5_swin.yaml并将下面代码复制进去

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license

# Parameters
nc: 1  # number of classes
# ch: 1   # no. input channel

depth_multiple: 0.33  # model depth multiple
width_multiple: 0.25  # layer channel multiple

anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  # input [b, 1, 640, 640]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2 [b, 64, 320, 320]
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4 [b, 128, 160, 160]
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8 [b, 256, 80, 80]
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16 [b, 512, 40, 40]
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32 [b, 1024, 20, 20]
   [-1, 3, C3, [1024]],
   [-1, 1, SwinStage, [1024, 2, 8, 4]], # [outputChannel, blockDepth, numHeaders, windowSize]
   [-1, 1, SPPF, [1024, 5]],  # 10
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 14

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 18 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 15], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 21 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 11], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 24 (P5/32-large)

   [[18, 21, 24], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

温馨提示:因为本文只是对yolov5n基础上添加swin模块,如果要对yolov5n/l/m/x进行添加则只需要修改对应的depth_multiple 和 width_multiple。


yolov5n/l/m/x对应的depth_multiple 和 width_multiple如下:

# YOLOv5n
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.25  # layer channel multiple

# YOLOv5s
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple

# YOLOv5l 
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple
 
# YOLOv5m
depth_multiple: 0.67  # model depth multiple
width_multiple: 0.75  # layer channel multiple
 
# YOLOv5x
depth_multiple: 1.33  # model depth multiple
width_multiple: 1.25  # layer channel multiple
2.3 注册模块

关键步骤三:在yolov5/models/yolo.py中注册,大概在250行左右添加 ‘SwinStage’

2.4 执行程序

在train.py中,将cfg的参数路径设置为yolov5_swin.yaml的路径,如下图所示

建议大家写绝对路径,确保一定能找到

 🚀运行程序,如果出现下面的内容则说明添加成功🚀

我修改后的代码:链接: https://pan.baidu.com/s/1SNAxfFWQwPgAk4Vucor_RA?pwd=afsy 提取码: afsy 

 3.总结

Swin Transformer 是一种基于分区注意力机制和层次化结构的先进深度学习模型,通过在局部区域内进行自注意力计算以及使用窗口式注意力机制,实现了在图像分类和目标检测等任务上优异的性能表现。其Transformer缩放技术提高了模型的可扩展性和效率,使其能够处理大规模图像数据,并在训练和推理过程中保持高效率。综合而言,Swin Transformer以其创新性的设计和卓越的性能,成为处理图像数据的一种领先模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1677079.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C# OpenCvSharp Demo - 最大内接圆

C# OpenCvSharp Demo - 最大内接圆 目录 效果 项目 代码 下载 效果 项目 代码 using OpenCvSharp; using System; using System.Diagnostics; using System.Drawing; using System.Drawing.Imaging; using System.Linq; using System.Windows.Forms; namespace OpenCvSh…

YOLOv5独家改进:backbone改进 | 微软新作StarNet:超强轻量级Backbone | CVPR 2024

💡💡💡创新点:star operation(元素乘法)在无需加宽网络下,将输入映射到高维非线性特征空间的能力,这就是StarNet的核心创新,在紧凑的网络结构和较低的能耗下展示了令人印象深刻的性能和低延迟 💡💡💡如何跟YOLOv5结合:替代YOLOv5的backbone 收录 YOL…

分享一个基于Qt的Ymodem的上位机(GitHub开源)

文章目录 1.项目地址2.Ymodem 协议介绍3.文件传输过程4.使用5.SecureCRT 软件也支持Ymodem6.基于PyQt5的Ymodem界面实现案例 1.项目地址 https://github.com/XinLiGH/SerialPortYmodem 基于VS2019 Qt5.15.2 编译,Linux下编译也可以,这里不做说明。 2.…

ROS2+TurtleBot3+Cartographer+Nav2实现slam建图和导航

0 引言 入门机器人最常见的应用就是slam建图和导航,本文将详细介绍这一流程, 便于初学这快速上手。 首先对需要用到的软件包就行简单介绍。 turtlebot3: 是一个小型的,基于ros的移动机器人。 学习机器人的很多示例程序都是基于turtlebot3。 …

51 单片机[2-1]:点亮一个LED

一、在 Keil5 中新建项目 打开 Keil5 ,点击 Project —— new μVision Project 新建文件夹 KeilProject ,以后的项目都在这个文件夹下,再建一个文件夹 2-1 点亮一个LED。在该文件夹下创建名为 Project 的文件,并保存。推荐起这…

金万维动态域名小助手怎么用?

金万维动态域名小助手是一个域名检测工具,使用此工具可以进行检测域名解析是否正确、清除DNS缓存、修改DNS服务器地址及寻找在线客服(仅支持付费用户)等操作。对不懂网络的用户是一个很好的检测域名的工具,下面我就讲解一下金万维…

TimesFM: 预训练的时间序列基础模型

大模型技术论文不断,每个月总会新增上千篇。本专栏精选论文重点解读,主题还是围绕着行业实践和工程量产。若在阅读过程中有些知识点存在盲区,可以回到如何优雅的谈论大模型重新阅读。另外斯坦福2024人工智能报告解读为通识性读物。若对于如果…

根据Word文档用剪映批量自动生成视频发布抖音

手头有大量word文档,想通过剪映的AI图文成片功能批量生成视频,发布到抖音平台,简单3步即可: 第一步:把word文档或者PDF等文档转成txt文本,可以用一些软件,也可以用AI工具,具体常见文…

Windows下编译RTTR

虽然C11引入了RTTI、Metaprogramming 等技术,但C在Reflection编程方面依旧功能有限。在社区上,RTTR则提供了一套C编写的反射库,补充了C在Reflection方面的缺陷。 零、环境 操作系统Windows 11Visual StudioVisual Studio Community 2022 CMa…

Qt---Socket通信

一、TCP/IP通信 在Qt中实现TCP/IP服务器端通信的流程: 1. 创建套接字 2. 将套接字设置为监听模式 3. 等待并接受客户端请求 可以通过QTcpServer提供的void newConnection()信号来检测是否有连接请求,如果有可以在对应的槽函数中调用nextPendingCon…

【现代C++】范围库的应用

C20引入了范围库(Ranges library),它是标准模板库(STL)的一个扩展,提供了一种新的方式来处理序列和范围。这个库允许开发者以更声明式的方式编写代码,使得操作序列变得更简洁、更易读。以下是C范…

【web网页开发制作】Html+Css+Js游戏主题特效及轮播效果网页作业天涯明月刀(7页面附源码)

HTMLCSSJS游戏主题轮播效果 🍔涉及知识🥤写在前面✨特效展示特效1、轮播幻灯效果特效2和3、鼠标悬浮及点击效果 🍧一、网页主题🌳二、网页效果Page1、首页Page2、游戏简介Page3、新闻中心Page4、互动专区Page5、视听盛宴Page6、用…

Kotlin协程实战指南:解锁Android开发高效能新时代

前言 在移动互联网的狂飙突进之中,Android开发领域如同站在风口的勇士,不断接受技术迭代与创新的双重洗礼。在这个快速变化的市场里,用户对应用性能和体验的期待水涨船高,开发者们面临的挑战也越来越大:如何在功能的丰…

Dart 3.4 发布:Wasm Native Macros(宏)

Google I/O 的结束,除了 Flutter 3.22 的发布 ,Dart 3.4 也迎来了它是「史诗级」的更新,之所以这么说,就是因为 Wasm Native 的落地和 Macros 的实验性展示。 在此之前,其实我也提前整理过一些对应的内容,…

运维别卷系列 - 云原生监控平台 之 06.prometheus pushgateway 实践

文章目录 [toc]Pushgateway 简介Pushgateway 部署创建 svc创建 deployment Pushgateway 测试删除 Pushgateway 上对应 lable 的数据 Pushgateway 简介 WHEN TO USE THE PUSHGATEWAY Pushgateway 是一种中介服务,允许您从无法抓取的作业中推送指标。 The Pushgateway…

深入理解 npm、cnpm、npx、yarn 和 pnpm:JavaScript 包管理器的对比

在 JavaScript 的世界中,包管理器是一个重要的工具,它帮助我们管理、安装和升级项目的依赖。在这篇文章中,我们将深入探讨三个最流行的 JavaScript 包管理器:npm、yarn 和 pnpm。 npm(Node Package Manager&#xff0…

未来IT行业的模块化、学习与跨界融合

随着技术的快速发展,IT行业已成为推动全球经济和社会发展的核心动力。从云计算和大数据到人工智能(AI)和物联网,这些创新技术正在彻底改变我们的生活方式和工作模式。而在AI领域,尤其是人工智能生成内容(AI…

怎么识别数学公式?分享简单识别方法

怎么识别数学公式?在学术研究和日常工作中,数学公式无疑是一个常见且重要的元素。然而,手动输入复杂的数学公式往往既耗时又容易出错。幸运的是,随着科技的发展,现在我们有了一些高效的软件工具,可以帮助我…

奥维地图下载高清影像的两种方式!以及ArcGIS、QGIS、GlobalMapper、自编工具下载高清影像的方法推荐!

今天来介绍一下奥维互动地图是如何下载高清影像的,也不是多了不起的功能!有朋友问,加上这个软件确实用的人多。 下载的高清数据在ArcGIS中打开的效果! 开始介绍奥维之前我们也介绍一下我们之前介绍的几个方法,没有优劣…

IP代理网络协议介绍

在IP代理页面上,存在HTTP/HTTPS/Socks5三种协议。它们都是客户端与服务器之间交互的协议。 HTTP HTTP又称之为超文本传输协议,在因特网使用范围广泛。它是一种请求/响应模型,客户端向服务器发送请求,服务器解析请求后对客户端作出…