Swin Transformer代码实战篇

news2024/11/17 11:48:15

🍊作者简介:秃头小苏,致力于用最通俗的语言描述问题

🍊往期回顾:CV攻城狮入门VIT(vision transformer)之旅——近年超火的Transformer你再不了解就晚了! CV攻城狮入门VIT(vision transformer)之旅——VIT原理详解篇 CV攻城狮入门VIT(vision transformer)之旅——VIT代码实战篇

🍊近期目标:写好专栏的每一篇文章

🍊支持小苏:点赞👍🏼、收藏⭐、留言📩

 

Swin Transformer代码实战篇

写在前面

​  上一篇我们已经介绍了Swin Transformer的原理,对此还不了解的点击☞☞☞了解详情。此篇文章参考B站UP霹雳吧啦Wz 的视频,大家若对Swin Transformer代码没有一点基础,建议先去观看视频。有一说一,这位UP的视频质量做的是真高,到目前为止,我已经不知道推荐过多少次了。但是呢,这部分视频时间确实长,有的地方也难以听懂,所以我听了20分钟就听不下去了,于是自己慢慢的调试起代码,这个过程挺漫长也挺难的,但是你坚持下来就会有所收获。当然了,光靠我慢慢摸索代码并没有把整个框架都弄清楚,仍然存在许多搞不明白的地方。这时候我就又观看了一篇视频,二刷的感觉明显不一样,UP说到的点基本都能理解了。但还是存在一些疑难杂症,后来又进一步调试摸索,最后基本都弄明白了。🥤🥤🥤

​  说这些,只是为大家提供一个学习代码的路线,具体怎么做,还是仁者见仁智者见智,只要找到最符合你习惯的就好。这篇文章不会把每句代码都讲的十分详细,重点会挑一些我觉得理解起来有一定难度,UP也没有细讲的点,所以此篇文章和UP主的视频更配喔!!!🍟🍟🍟

​  准备好了嘛,开始发车!!!🚖🚖🚖

 

模型整体设计框架

​  为方便大家理解代码,我画出了代码中几个关键的类,如下图:

​  首先,最大的一个类就是SwinTransformer,它定义了整个Swin Transformer的框架。接着是BasicLayer类,它是Swin Transformer Block和Patch Merging的组合。【注意,代码中是Swin Transformer Block+patch merging组合在一起,而不是理论部分的Patch merging+Swin Transformer Block】 然后是SwinTransformer Block类,它定义了Swin Transformer的结构。还有一个是WindowAttention类,它定义了W-MSA和SW-MSA结构。

 

Patch partition+Linear Embedding实现

​  和ViT相同,这部分采用一个卷积实现,代码如下:

## 定义PatchEmbed类
class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding
    """
    def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = (patch_size, patch_size)
        self.patch_size = patch_size
        self.in_chans = in_c
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)  #定义卷积
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        _, _, H, W = x.shape

        # padding
        # 如果输入图片的H,W不是patch_size的整数倍,需要进行padding
        pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)
        if pad_input:
            # to pad the last 3 dimensions,
            # (W_left, W_right, H_top,H_bottom, C_front, C_back)
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1],
                          0, self.patch_size[0] - H % self.patch_size[0],
                          0, 0))

        # 下采样patch_size倍
        x = self.proj(x)
        _, _, H, W = x.shape
        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W

 

Patch Merging实现

​  这部分原理在上一篇已经详细介绍,代码如下:

## 定义PatchMerging类
class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """
        x: B, H*W, C
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        # padding
        # 如果输入feature map的H,W不是2的整数倍,需要进行padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            # to pad the last 3 dimensions, starting from the last dimension and moving forward.
            # (C_front, C_back, W_left, W_right, H_top, H_bottom)
            # 注意这里的Tensor通道是[B, H, W, C],所以会和官方文档有些不同
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        x0 = x[:, 0::2, 0::2, :]  # [B, H/2, W/2, C]
        x1 = x[:, 1::2, 0::2, :]  # [B, H/2, W/2, C]
        x2 = x[:, 0::2, 1::2, :]  # [B, H/2, W/2, C]
        x3 = x[:, 1::2, 1::2, :]  # [B, H/2, W/2, C]
        x = torch.cat([x0, x1, x2, x3], -1)  # [B, H/2, W/2, 4*C]
        x = x.view(B, -1, 4 * C)  # [B, H/2*W/2, 4*C]

        x = self.norm(x)
        x = self.reduction(x)  # [B, H/2*W/2, 2*C]

        return x

​  关于这部分,稍难理解的是这部分代码,如下图所示:

image-20220823101633470

​  这几行代码就对应我们理论部分所说的划分成四个小patch。以x0 = x[:, 0::2, 0::2, :]为例,它表示取所以Batch和Chanel的数据,从H的第0位和W的第0位开始取,行列都每隔两个取一个数据。其它三个表达的含义类似。

​  上面这样解释不知道大家能否听懂,我再举个例子,代码如下:【这里忽略了Batch和Chanel维度】

import torch
x= [[1,2,3,4],[5,6,7,8],[9,10,11,12],[13,14,15,16]]
x = torch.tensor(x)

​ 这样我们定义了一个四行四列的元素,来看一下其结果:

image-20220823102336800

​ 接着,我们对上述x进行切片,代码如下:

x0 = x[0::2, 0::2]
x1 = x[1::2, 0::2]
x2 = x[0::2, 1::2]
x3 = x[1::2, 1::2]

​ 此时,我们来看看x0、x1、x2、x3的输出结果,如下图所示:

image-20220823102607042

相信通过这个例子大家就一目了然了。🥂🥂🥂

 

SW-MSA

​ 这部分我主要讲讲窗口移动的代码,其实就一行,如下图所示:

x = torch.roll(shifted_x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))

​ 这行代码到底干了什么呢?我们同样以一个例子来讲解,如下:

import torch
x= [[1,2,3,4],[5,6,7,8],[9,10,11,12],[13,14,15,16]]
x = torch.tensor(x)

​ 先定义一个四行四列的元素,我们打印出x看一看:

image-20220823104230636

​ 接着我们执行这行代码:

shifted_x1 = torch.roll(x, shifts=(-1, -1), dims=(0, 1))

​ 来看看输出的shifted_x1结果:

image-20220823104344150

​  是不是发现就是先将x的第一行移动到最后一行,然后将第一列移动到最后一列的结果呢。是不是发现代码实现这一步非常的简单呢。至于self.shift_size为 ⌊ M 2 ⌋ \left\lfloor {\frac{{\rm{M}}}{2}} \right\rfloor 2M,M为窗口大小。【注意:只有在SW-MSA是才使用此步骤】

​  我们在理论部分谈到,执行完SW-MSA后,要将移动后的窗口还原回去,代码也很简单,就是一个反向的移动,如下:

x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))

​ 我们在来通过刚刚的例子理解一下:

shifted_x2 = torch.roll(shifted_x1, shifts=(1, 1), dims=(0, 1))

​ 来看看输出的shifted_x2结果:

image-20220823105226293

​ 会发现shifted_x2和原始的x是一致的!!!🥗🥗🥗

 

训练结果展示

以下结果为花的五分类训练结果:

  • 使用预训练模型:swin_tiny_patch4_window7_224.pth ,一共训练10轮,结果如下:

image-20220913115442306

  • 使用预训练模型:swin_base_patch4_window7_224_in22k ,一共训练10轮,结果如下:

image-20220913115527591

  • 不使用预训练模型:swin_base_patch4_window7_224_in22k ,一共训练10轮,结果如下:

image-20220913115618993

​ 通过上面几个实验可以看出,swin Transformer的效果还是很不错的,特别是使用了预训练模型后。


​ 我也在swin transformer的代码中尝试加上可学习的位置编码,发现效果较之前也有一定的提升,如下:

  • 使用预训练模型:swin_tiny_patch4_window7_224.pth,一共训练10轮 ,加入可学习位置编码。

image-20220913120115588

 

小结

​  这部分就写这么多了,用文字来讲解代码感觉确实有难度,所以后期可能会打算出一些视频教学,当然这都是后话了。本篇其实主要就为大家整理了两个点,通过两个例子帮助大家进行理解。其它的内容相信你通过调试或者看我推荐的视频是可以解决的,最后希望大家学有所成。🌼🌼🌼

 

参考链接

使用Pytorch搭建Swin-Transformer网络 🍁🍁🍁

 
 

如若文章对你有所帮助,那就🛴🛴🛴

         一键三连 (1).gif

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/173259.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

列表元素 有序列表 无序列表 定义列表 ol: ul: dl: dd: dt:

目录列表元素有序列表无序列表定义列表列表元素 有序列表 ol: ordered list 表示整个列表 li: list item 表示单个列表,列表的子元素 reversed: 导则 列表的写法: 但实际开发中一般不用type来设置列表的序,而是用css 把大象装冰箱分几步…

SpringBoot 项目

不得不佩服 Spring Boot 的生态如此强大,今天我给大家推荐几款 Gitee 上优秀的后台开源版本的管理系统,小伙伴们再也不用从头到尾撸一个项目了,简直就是接私活,挣钱的利器啊。SmartAdmin我们开源一套漂亮的代码和一套整洁的代码规…

2.4、进程通信

整体框架 1、什么是进程通信? 顾名思义,进程通信就是指进程之间的信息交换。 进程是分配系统资源的单位(包括内存地址空间), 因此各进程\color{red}各进程各进程拥有的内存地址空间相互独立\color{red}内存地址空间相互独立内存…

知识分享-商业数据分析业务全流程

🤵‍♂️ 个人主页:艾派森的个人主页 ✍🏻作者简介:Python学习者 🐋 希望大家多多支持,我们一起进步!😄 如果文章对你有帮助的话, 欢迎评论 💬点赞&#x1f4…

力扣sql简单篇练习(二)

力扣sql简单篇练习(二) 1 从不订购的客户 1.1 题目内容 1.1.1 基本题目信息 1.1.2 示例输入输出 1.2 示例sql语句 # 一个人也是有可能下多个订单的 SELECT name Customers FROM Customers WHERE id not in(SELECT distinct Customerid FROM Orders)1.3 运行截图 2 删除重…

【LINUX修行之路】工具篇——Vim的使用及配置

🍿本节主题:vim的使用 🎈推荐阅读:回溯算法 、C入门(上篇) 💕我的主页:蓝色学者的主页 文章目录一、前言二、文本编辑器和IDE三、选择vim的理由四、vim操作模式4.1普通模式(command …

springboot项目实现腾讯云的短信验证

前言:可以先去看下腾讯云开通短息服务需要哪些信息。我这里使用自己很久之前申请过的公众号,其他的比如网站,小程序啥的也没有,哈哈哈~。 腾讯云地址: https://console.cloud.tencent.com/smsv2/csms-sign/create接下…

非父子组件的通信

在开发中,我们构建了组件树之后,除了父子组件之间的通信之外,还会有非父子组件之间的通信。这里我们主要讲两种方式: Provide/Inject;全局事件总线; 1、Provide和Inject Provide/Inject用于非父子组件之间…

什么是OAuth2

2.3 什么是OAuth2 2.3.1 OAuth2认证流程 在前边我们提到微信扫码认证,这是一种第三方认证的方式,这种认证方式是基于OAuth2协议实现, OAUTH协议为用户资源的授权提供了一个安全的、开放而又简易的标准。同时,任何第三方都可以使…

离散数学-图论-图的基本概念(11)

图的基本概念 1 图 1.1 图的定义 定义1&#xff1a; 一个无向图G是一个有序的二元组<V,E>&#xff0c;其中 &#xff08;1&#xff09;V是一个非空有穷集&#xff0c;称为顶点集&#xff0c;其元素称为顶点或结点。 &#xff08;2&#xff09;E是无序积V&V的有穷多…

什么是安卓版 UI 业务包 SDK?如何接入?

涂鸦 Android 业务包是指包含业务逻辑和 UI 界面的涂鸦垂直业务模块&#xff0c;旨在为基于涂鸦智能生活 SDK 开发的应用提供快速的一站式接入涂鸦业务模块的能力。 概述 目前提供的业务包种类繁多&#xff0c;例如&#xff1a; H5 商城设备配网设备控制IP 摄像机智能场景常…

“深度学习”学习日记。与学习相关的技巧 -- 参数的更新

2023.1.20 在神经网络的学习这一章&#xff0c;学习过了利用 梯度下降法 对参数进行更新&#xff0c;目的是找到是损失函数的值尽量小的参数&#xff1b;像解决这样的问题称为 最优化 。 由于参数空间十分复杂、参数规模十分庞大&#xff0c;导致“最优化”的过程变得困难。 …

C规范编辑笔记(十二)

往期文章&#xff1a; C规范编辑笔记(一) C规范编辑笔记(二) C规范编辑笔记(三) C规范编辑笔记(四) C规范编辑笔记(五) C规范编辑笔记(六) C规范编辑笔记(七) C规范编辑笔记(八) C规范编辑笔记(九) C规则编辑笔记(十) C规范编辑笔记(十一) 正文&#xff1a; 放假了&#xff…

【数据结构】万字深入浅出讲解顺序表(附原码 | 超详解)

&#x1f680;write in front&#x1f680; &#x1f4dd;个人主页&#xff1a;认真写博客的夏目浅石. &#x1f381;欢迎各位→点赞&#x1f44d; 收藏⭐️ 留言&#x1f4dd; &#x1f4e3;系列专栏&#xff1a;C语言实现数据结构 &#x1f4ac;总结&#xff1a;希望你看完…

智能矿山电子封条系统 YOLOv5

智能矿山电子封条系统通过yolov5深度学习技术&#xff0c;对现场画面进出口以及主要的井口等重要地方对矿井人员变化、生产作业执勤状态及出入井人员等状况实时监控分析监测。我们使用YOLO(你只看一次)算法进行对象检测。YOLO是一个聪明的卷积神经网络(CNN)&#xff0c;用于实时…

Google AIY Vision Kit安装及国内配置

Google AIY Vision Kit安装及国内配置1. AIY Vision Kit组装环节Step 1&#xff1a;收集其他附件选择1&#xff1a;使用AIY项目应用程序选择2&#xff1a;使用显示器、鼠标和键盘Step 2&#xff1a;检查硬件清单Step 3&#xff1a;构建AIY Vision KitStep 3.1&#xff1a;获取最…

旺店通·企业奇门和用友BIP接口打通对接实战

旺店通企业奇门和用友BIP接口打通对接实战接通系统&#xff1a;旺店通企业奇门旺店通是北京掌上先机网络科技有限公司旗下品牌&#xff0c;国内的零售云服务提供商&#xff0c;基于云计算SaaS服务模式&#xff0c;以体系化解决方案&#xff0c;助力零售企业数字化智能化管理升级…

Mac和Windows局域网互传文件iPhone和Windows局域网互传文件

生活中&#xff0c;我们可以通过微信和QQ或网盘等等传输工具进而实现文件互传&#xff0c;但是面临一个问题&#xff0c;大文件无法上传&#xff0c;而且受到网速的限制等诸多因素影响&#xff0c;如今我们可以通过局域网进行实现文件互传&#xff0c;进而改变此种囧境。 首先在…

17道Redis 面试题

Redis 持久化机制缓存雪崩、缓存穿透、缓存预热、缓存更新、缓存降级等问题热点数据和冷数据是什么Memcache与Redis的区别都有哪些&#xff1f;单线程的redis为什么这么快redis的数据类型&#xff0c;以及每种数据类型的使用场景&#xff0c;Redis 内部结构redis的过期策略以及…

KVM安装部署 | 举例安装虚机Windows2012R2

目录 1、基础环境准备 2、KVM的安装 3、开启服务 4、开启图形化界面 5、也可以通过浏览器管理KVM 6、举例安装一个windows2012R2 1、基础环境准备 【关闭防火墙】 systemctl stop firewalld systemctl disable firewalld 【关闭selinux】 修改文件/etc/selinux/config…