详解SwinIR的论文和代码(SwinIR: Image Restoration Using Swin Transformer)

news2024/11/24 17:22:17

paper:https://arxiv.org/abs/2108.10257
code:https://github.com/JingyunLiang/SwinIR

目录

  • 1. Swin Transformer layers
    • 1.1 局部注意力
    • 1.2 移动窗口机制
    • 1.3 关键代码理解
  • 2. 整体网络结构
    • 2.1 浅层特征提取
    • 2.2 深层特征提取
    • 2.3 图像重建
  • 3.总结

SwinIR将Swin transformer1应用到low level领域的图像增强任务,结合卷积设计了网络结构,在以下三个任务上取得了很好的效果:图像超分辨率(包括classical、lightweight和real-world SR)、图像去噪(包括灰度图和彩色图像去噪)和 JPEG压缩失真去除。本文将结合代码对SwinIR进行详解。

SwinIR的网络结构并不复杂,关键部件就是Swin Transformer layers(STL)卷积层残差连接。卷积和残差连接大家都比较熟悉了,因此我首先结合代码介绍一下swin transformer层,然后自底向上的介绍SwinIR的全貌

1. Swin Transformer layers

SwinIR使用的Swin Transformer layers(STL)是在swin transformer中提出的,并未有改动。STL基于原始的多头注意力transformer层进行优化,主要的不同点在于:1. 局部注意力(local attention);2. 移动窗口机制(shifted window mechanism);

1.1 局部注意力

原始的全局注意力会将图像分成若干个patch,所有的patch之间做自注意力计算;所谓的局部注意力就是首先将图像划分成若干个window,每个window内在进行patch的划分,然后在window内部进行自注意力的计算,而不在一个window内的patch是没有交互的。也就是说,只考虑一个window内的patch,他们之间的计算和全局注意力操作是一样的。

理解局部注意力具体是怎么做的,很好的一个办法是看代码和分析tensor在不同层之间的shape整理出来。下面是我整理的tensor shape变化:

请添加图片描述
其中,b: batchsize, h: 输入高, w:输入宽, ws: 窗口大小, C: channel数, num_heads:attention的head数

1.2 移动窗口机制

由于基于窗口的多头注意力(W-MSA)没有考虑跨窗口的连接,模型建模长距离关联的能力受损。因此swin transformer提出了移动窗口多头注意力机制(SW-MSA),可在保证计算高效性的前提下,扩大感受野。

如下图所示,W-MSA的窗口大小为M*M(图中M=4),那么SW-MSA的窗口划分将向右下移动 ⌊ M / 2 ⌋ ∗ ⌊ M / 2 ⌋ \lfloor M/2 \rfloor *\lfloor M/2 \rfloor M/2M/2

请添加图片描述

但是经过位移之后,窗口数量会变多,由原来的 ⌊ h / M ⌋ ∗ ⌊ w / M ⌋ \lfloor h/M \rfloor *\lfloor w/M \rfloor h/Mw/M变成 ( ⌊ h / M ⌋ + 1 ) ∗ ( ⌊ w / M ⌋ + 1 ) (\lfloor h/M \rfloor + 1) *(\lfloor w/M \rfloor +1) (⌊h/M+1)(⌊w/M+1),而且窗口大小不一致。因此swin transformer提出了循环位移,减少窗口数量,同时可以获得相同大小的窗口进行并行计算。循环位移如下图所示。
请添加图片描述

在代码中,循环位移通过torch.roll实现,shifts为负,代表从下往上移动,从右往左移动,最上和最左循环移动到最下和最右。

# cyclic shift
if self.shift_size > 0:
    shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))

关于torch.roll可参考:https://blog.csdn.net/weixin_42899627/article/details/116095067
如上图所示,经过循环移位后,有三个窗口中有一些patch是本不相邻的,它们不应该做自注意力,所以swin transformer建立了mask机制来完成最终的注意力计算。

关于mask的理解可参考https://github.com/microsoft/Swin-Transformer/issues/38

1.3 关键代码理解

下面来看一下关键代码及注释,首先是WindowAttention的forward函数:

def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape  # 此处的输入是经过window partition的
        # self.qkv(x): num_windows*B, window_size*window_size, 3*C
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # 通过一个全连接层获取所有头的qkv,(3, num_windows*B, num_heads, window_size*window_size, C // num_heads)
        q, k, v = qkv[0], qkv[1], qkv[2] 
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1)) # num_windows*B, num_heads, window_size*window_size, window_size*window_size
		# 可学习的相对位置bias
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0) # num_windows*B, num_heads, window_size*window_size, window_size*window_size

        if mask is not None:
            nW = mask.shape[0]
            # 将mask和attn相加,mask只有两种取值0和-100,因此为0时对attn无影响,为-100时,self.softmax(attn)将变为接近于0
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N) # num_windows*B, num_heads, window_size*window_size, window_size*window_size
            attn = self.softmax(attn) # num_windows*B, num_heads, window_size*window_size, window_size*window_size
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)
        # v:num_windows*B, num_heads, window_size*window_size, C // num_heads
        # attn:num_windows*B, num_heads, window_size*window_size, window_size*window_size
        # attn @ v: num_windows*B, num_heads, window_size*window_size, C // num_heads
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C) # num_windows*B, window_size*window_size, C
        x = self.proj(x) # 全连接层
        x = self.proj_drop(x)
        return x

接下来是SwinTransformerBlock的forward函数

    def forward(self, x, x_size):
        H, W = x_size
        B, L, C = x.shape
        # assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # (num_windows*B, window_size, window_size, C)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # num_windows*B, window_size*window_size, C

        # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
        if self.input_resolution == x_size:
            attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
        else:
            attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

可以看到每个SwinTransformerBlock内部完成的是:
X = M S A ( L N ( X ) ) + X X = MSA(LN(X)) + X X=MSA(LN(X))+X
X = M L P ( L N ( X ) ) + X X = MLP(LN(X)) + X X=MLP(LN(X))+X
其中MSA为W-MSA和SW-MSA交替。

2. 整体网络结构

请添加图片描述
如上图所示,SwinIR包括三个modules,浅层特征提取、深层特征提取和图像重建。其中特征提取模块对所有任务都是一样的,但是图像重建对于不同的任务是不同的。

2.1 浅层特征提取

一个3×3卷积层将特征图通道转成embed_dim:(b, embed_dim, h, w)

self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)

2.2 深层特征提取

深层特征提取的基本模块则是第一节中讲解的STL和卷积层和残差连接。STL和卷积组成RSTB,RSTB和卷积组成了深层特征提取。

2.3 图像重建

以下代码可以看到对于不同的任务,图像重建模块是不同的,有的采用最邻近插值+卷积,有的采用pixelshuffle+卷积,有的直接采用卷积。


if self.upsampler == 'pixelshuffle':
    # for classical SR
    x = self.conv_first(x)
    x = self.conv_after_body(self.forward_features(x)) + x
    x = self.conv_before_upsample(x)
    x = self.conv_last(self.upsample(x))
elif self.upsampler == 'pixelshuffledirect':
    # for lightweight SR
    x = self.conv_first(x)
    x = self.conv_after_body(self.forward_features(x)) + x
    x = self.upsample(x)
elif self.upsampler == 'nearest+conv':
    # for real-world SR
    x = self.conv_first(x) # (b, embed_dim, h, w)
    x = self.conv_after_body(self.forward_features(x)) + x
    x = self.conv_before_upsample(x)
    x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
    x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
    x = self.conv_last(self.lrelu(self.conv_hr(x)))
else:
    # for image denoising and JPEG compression artifact reduction
    x_first = self.conv_first(x)
    res = self.conv_after_body(self.forward_features(x_first)) + x_first
    x = x + self.conv_last(res)

SwinIR可以很灵活配置网络的复杂度。影响W-MSA计算复杂度: 4 h w C 2 + 2 M 2 h w C 4hwC^2 + 2M^2hwC 4hwC2+2M2hwC
请添加图片描述

3.总结

  1. 结构简单,性能全面超过cnn-based的方法,适用于多种任务,可做为Low-level的基线模型;
  2. 作者发现与以往基于transformer的方法不同,Swinir不需要比cnn更多的训练数据,收敛速度也更快;
  3. 结构模块化,可以方便调整出不同复杂度的模型;

  1. Liu Z, Lin Y, Cao Y, et al. Swin transformer: Hierarchical vision transformer using shifted windows[C]//Proceedings of the IEEE/CVF international conference on computer vision. 2021: 10012-10022. ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1229939.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

服务器探针-serverstatus

{alert type"info"} 之前给大家介绍过一个简单的服务器监控。uptime-kuma 今天给各位带来一个酷炫的多服务器探针和多服务器监控。ServerStatus {/alert} 作者的开源项目地址如下:https://github.com/cppla/ServerStatus 作者的项目体验地址如下 https://…

运动耳机什么牌子好?十大运动蓝牙耳机品牌排行榜

​运动耳机需求各有不同,对于我们每个人来说,选择最适合自己的耳机是一项重要任务。在这个耳机类型繁多,五花八门的时代,如何找到一款适合自己的运动耳机呢?选对运动耳机很重要,所以接下来安利五款相当不错…

element-ui中怎样使用iconfont的图标

1 登录 https://www.iconfont.cn/ 2 搜索合适的图 这里可以找到这个图所在的图库。这样就可以一次查找到对应的所有同款图标 3 选择同款加入购物车 4 将购物车的icon加入项目,注意是新建项目,除非你是确定需要前面已经加过的icon 5 下载icon 选择fon…

SpringBoot的启动流程

一、SpringBoot是什么? springboot是依赖于spring的,比起spring,除了拥有spring的全部功能以外,springboot无需繁琐的xml配置,这取决于它自身强大的自动装配功能;并且自身已嵌入Tomcat、Jetty等web容器&am…

课程设计:C++实现哈夫曼编码

功能实现: //1:先计算每个字符的权重//2:构建哈夫曼树//3:得出每个字符的哈夫曼编码。//4:根据哈夫曼编码转化为字符 代码实现: // 哈夫曼编码.cpp : 此文件包含 "main" 函数。程序执行将在此处开始并结束。 //1:先计…

vue动态配置路由

文章目录 前言定义项目页面格式一、vite 配置动态路由新建 /router/utils.ts引入 /router/utils.ts 二、webpack 配置动态路由总结如有启发,可点赞收藏哟~ 前言 项目中动态配置路由可以减少路由配置时间,并可减少配置路由出现的一些奇奇怪怪的问题 路由…

如何将文字、图片、视频、链接等内容生成一个二维码?

通过二维彩虹的【H5编辑】功能,就可以将文字、图片、视频、文件、链接等多种格式的内容编辑在一个页面,然后生成一个自定义的二维码——H5编辑二维码。扫描后,即可查看二维码中的详细图文视频等内容了。这个功能大受欢迎! 这个H5…

深度学习之基于CT影像图像分割检测系统

欢迎大家点赞、收藏、关注、评论啦 ,由于篇幅有限,只展示了部分核心代码。 文章目录 一项目简介 二、功能三、系统四. 总结 一项目简介 基于CT影像的图像分割检测系统可以被设计成能够自动地检测出CT图像中的病变部位或解剖结构,以协助医生进…

庖丁解牛:NIO核心概念与机制详解 05 _ 文件锁定

文章目录 Pre概述锁定文件 (lock)Code文件锁定和可移植性 Pre 庖丁解牛:NIO核心概念与机制详解 01 庖丁解牛:NIO核心概念与机制详解 02 _ 缓冲区的细节实现 庖丁解牛:NIO核心概念与机制详解 03 _ 缓冲区分配、包装和…

calibre更新 环境变量设置

我这里是从别的地方copy过来的calibre,所以不用安装。 如果需要安装请参考: Caibre2022.3_17版本安装及遇到问题 - 梅希的日志 - EETOP 创芯网论坛 (原名:电子顶级开发网) -将copy过来的calibre放在原来calibre的位置。 打开工作路径下的.b…

【Vue】Vue3 超简单拖拽条动态修改容器宽度

demo 代码 const leftBoxWidth ref(200); // 默认宽度 const leftResize (e: MouseEvent) > {const startX e.clientX;const startWidth leftBoxWidth.value;const mouseMove (documentE: MouseEvent) > {// 80 是左侧菜单宽度leftBoxWidth.value startWidth docu…

不懂找伦敦银趋势?3个方法搞定

趋势是我们的朋友,但是这个朋友却很喜欢跟我们开玩笑,如果我们不留意,根本发觉不了它的存在。怎么找到趋势本体并且和它做个好朋友呢?下面我们就来介绍三个方法。 数波段的高点和低点。我们以当前的市场波动价格为轴,向…

快手运营的必备的10个工具

一、引言 快手作为短视频领域的佼佼者,为众多创作者提供了广阔的舞台。要想在快手运营中取得成功,掌握一些必备的工具是必不可少的。本文将为您介绍快手运营的10个必备工具,帮助您提高工作效率,优化内容创作。 二、工具推荐 1. …

现货白银MACD实战分析例子

MACD这个技术指标的全称是平滑异同移动平均线,主要表示经过平滑处理后均线的差异程度,一般用来研判现货白银价格变化的方向、强度和趋势。MT4中的MACD指标,主要是由信号线、(上升/下跌)动能柱、0轴这三部分组成。 MACD…

键盘映射笔记

dumpkeys命令用于显示当前系统中定义的键盘映射表。它可以帮助用户查看和理解系统中的键盘布局和键盘映射规则。 当用户执行dumpkeys命令时,它会读取系统中的键盘映射表文件(通常是/etc/keymaps或/etc/console/boottime.kmap.gz),…

chatglm-6B模型下载

从huggingface上面下载chatglm-6B模型是比较简捷的方式,下面记录一下下载安装过程。 huggingface的官方文档如下: https://huggingface.co/docs/huggingface_hub/v0.14.0.rc1/guides/download 1.配置conda环境 服务器上使用的是miniconda,…

如何在公网环境下使用笔记本的Potplayer访问本地群晖webdav中的影视资源

文章目录 如何在公网环境下使用笔记本的Potplayer访问本地群晖webdav中的影视资源**那么问题来了,potplayer只能局域网内访问资源,那我不在家中怎么看本地电影?** 本教程解决的问题是:按照本教程方法操作后,达到的效果…

什么是域欺骗?域欺骗的主要类型有哪些?

域欺骗是指网络犯罪分子假冒网站名称或电子邮件域来欺骗用户。域欺骗的目的是将恶意电子邮件或网络钓鱼网站伪装成合法电子邮件或网站,诱使用户与之交互。域欺骗就像骗子一样,向人们展示伪造的凭据以获得信任,然后再利用其获得好处。 域欺骗…

【PCB学习】几种接地符号

声明 该图并非原创,原文出处不可考,因此在这里附加说明。 示意图

在vue-cli中快速使用webpack-bundle-analyzer

webpack-bundle-analyzer 是一个可视化资源分析工具,可以直观地分析打包出的文件有哪些,及它们的大小、占比情况、各文件 Gzip压缩后的大小、模块包含关系、依赖项等。 从vue-cli官方的更新记录可以看到,从vue-cli3开始集成report命令 当前环…