论文阅读 | Video Super-Resolution Transformer

news2024/11/23 15:12:26

引言:2021年用Transformer实现视频超分VSR的文章,改进了SA并在FFN中加入了光流引导
论文:【here】
代码:【here】

Video Super-Resolution Transformer

引言

视频超分中有一组待超分的图片,因此视频超分也经常被看做是一个序列问题。这种序列问题的解决方法通常有RNN,SLTM,和transformer。由于transformer并不需要递归也更适合,并备受关注

注意力机制

目前用transformer处理图片的思路是全连接注意力机制,the fully connected self-attention(FCSA)
(这里的这个FCSA的概念是作者提出来的,作者举例VIT和PIT都是FCSA,因此我把它当做对整个图像的分成的块做自注意力)
然而,作者认为这样的FCSA的机制并不能很好的提取空间局部信息,但是局部信息对于VSR来说又是很关键的。
此外,除了空间局部信息,时域信息也是很重要的,视频中的图片中的信息可以通过相邻的图片进行补充。现在,该如何用transformer来处理时域信息也是没有被探索过的(这个领域还没有人做)

前馈网络
现有的前馈网络token-wise feed-forward layer不能实现图像之间的对齐,这里强调token,即是指全连接都是在每一个token中实现的,token和token之间没有关联。token 之间的特征关联在FCSA模块中实现的,但是在FFN中没有特征传播。因此,在这个模块中,作者实现了以像元为单位(而不是token),实现了特征传播和特征对齐

问题定义

第一个定义是映射函数的loss
在这里插入图片描述
第二个定义是神经元组成/参数传播的定义,同时这个连接了不同神经元的映射与真实映射之间的loss应该小于一个epsilon
在这里插入图片描述
第三个定义为视频超分的定义和目标
在这里插入图片描述
第四个定义为transformer的架构
在这里插入图片描述
(这一块有点枯燥,主要是作者的第二个定义,神经元的参数传播为后面的公式推导奠定基础)

视频超分 Transformer

作者介绍了这样一个公式,来证明FCSA不太适合视频超分tranformer(公式没有看懂,我这里就跳过了)
在这里插入图片描述
总之,作者通过这个公式论证了全连接注意力机制FCSA会导致梯度消失的问题

When q is not sufficiently large, the fully connected attention layer may result in the gradient vanishing issue. It implies that the gradient descent will be “stuck” upon the initialization, and thus will fail to learn the k-pattern function. Therefore, the fully connected self-attention layer cannot use the spatial information of each frame since the local information is not encoded in the embeddings of all tokens. Moreover, this issue may become more serious when directly using such layers in video super-resolution.

而如果用作者提出的,则时空卷积自注意力机制STCSA很好的解决了这个问题
在这里插入图片描述

STCSA的实现
即将图片划成8 * 8 * 5的小块(这里的5是指连续的图片数),在8 * 8 * 5的3D块中实现块中的像元单位的特征自注意力
在这里插入图片描述
同时,作者还加入了一个3D位置编码信息,编码规则如下
在这里插入图片描述

代码(作者非常贴心的加上了尺寸的注释)

class globalAttention(nn.Module):
    def __init__(self, num_feat=64, patch_size=8, heads=1):
        super(globalAttention, self).__init__()
        self.heads = heads
        self.dim = patch_size ** 2 * num_feat
        self.hidden_dim = self.dim // heads
        self.num_patch = (64 // patch_size) ** 2
        
        self.to_q = nn.Conv2d(in_channels=num_feat, out_channels=num_feat, kernel_size=3, padding=1, groups=num_feat) 
        self.to_k = nn.Conv2d(in_channels=num_feat, out_channels=num_feat, kernel_size=3, padding=1, groups=num_feat)
        self.to_v = nn.Conv2d(in_channels=num_feat, out_channels=num_feat, kernel_size=3, padding=1)

        self.conv = nn.Conv2d(in_channels=num_feat, out_channels=num_feat, kernel_size=3, padding=1)

        self.feat2patch = torch.nn.Unfold(kernel_size=patch_size, padding=0, stride=patch_size)
        self.patch2feat = torch.nn.Fold(output_size=(64, 64), kernel_size=patch_size, padding=0, stride=patch_size)

    def forward(self, x):
        b, t, c, h, w = x.shape                                # B, 5, 64, 64, 64
        H, D = self.heads, self.dim
        n, d = self.num_patch, self.hidden_dim

        q = self.to_q(x.view(-1, c, h, w))                     # [B*5, 64, 64, 64]    
        k = self.to_k(x.view(-1, c, h, w))                     # [B*5, 64, 64, 64]   
        v = self.to_v(x.view(-1, c, h, w))                     # [B*5, 64, 64, 64]

        unfold_q = self.feat2patch(q)                          # [B*5, 8*8*64, 8*8]
        unfold_k = self.feat2patch(k)                          # [B*5, 8*8*64, 8*8]  
        unfold_v = self.feat2patch(v)                          # [B*5, 8*8*64, 8*8] 

        unfold_q = unfold_q.view(b, t, H, d, n)                # [B, 5, H, 8*8*64/H, 8*8]
        unfold_k = unfold_k.view(b, t, H, d, n)                # [B, 5, H, 8*8*64/H, 8*8]
        unfold_v = unfold_v.view(b, t, H, d, n)                # [B, 5, H, 8*8*64/H, 8*8]

        unfold_q = unfold_q.permute(0,2,3,1,4).contiguous()    # [B, H, 8*8*64/H, 5, 8*8]
        unfold_k = unfold_k.permute(0,2,3,1,4).contiguous()    # [B, H, 8*8*64/H, 5, 8*8]
        unfold_v = unfold_v.permute(0,2,3,1,4).contiguous()    # [B, H, 8*8*64/H, 5, 8*8]

        unfold_q = unfold_q.view(b, H, d, t*n)                 # [B, H, 8*8*64/H, 5*8*8]
        unfold_k = unfold_k.view(b, H, d, t*n)                 # [B, H, 8*8*64/H, 5*8*8]
        unfold_v = unfold_v.view(b, H, d, t*n)                 # [B, H, 8*8*64/H, 5*8*8]

        attn = torch.matmul(unfold_q.transpose(2,3), unfold_k) # [B, H, 5*8*8, 5*8*8]
        attn = attn * (d ** (-0.5))                            # [B, H, 5*8*8, 5*8*8]
        attn = F.softmax(attn, dim=-1)                         # [B, H, 5*8*8, 5*8*8]

        attn_x = torch.matmul(attn, unfold_v.transpose(2,3))   # [B, H, 5*8*8, 8*8*64/H]
        attn_x = attn_x.view(b, H, t, n, d)                    # [B, H, 5, 8*8, 8*8*64/H]
        attn_x = attn_x.permute(0, 2, 1, 4, 3).contiguous()    # [B, 5, H, 8*8*64/H, 8*8]
        attn_x = attn_x.view(b*t, D, n)                        # [B*5, 8*8*64, 8*8]
        feat = self.patch2feat(attn_x)                         # [B*5, 64, 64, 64]
        
        out = self.conv(feat).view(x.shape)                    # [B, 5, 64, 64, 64]
        out += x                                               # [B, 5, 64, 64, 64]

        return out

这样就完全考虑8 * 8感受野内的局部特征了,但是块与块之间的边缘只能朝一个方向进行特征传播
于是作者提出了一种新型的FFN

feed-forward Network实现
作者首先将5张图片中的相邻光流求出来,如果边缘图像的另一边没有图了,则跟自己作光流
在这里插入图片描述
这样可以得到5 * 2张光流图(这里的5指视频图片数),每张图片都有它的前向流图和后向流图,然后前向warp和后向warp后,原有的每张图的时间位置上都可以多加两张图,分别来自前一时刻图片前向warp得到,和后一时刻的图片后向warp得到
在这里插入图片描述
然后通过两组图片的融合,即生成了最终结果
在这里插入图片描述
值得一提的是,这里的FFN和传统FFN不同,由于前面的SA部分保留的图片的原有尺寸,这里的FFN直接用3*3卷积实现

class FeedForward(nn.Module):
    def __init__(self, num_feat):
        super().__init__()
        
        self.backward_resblocks = ResidualBlocksWithInputConv(num_feat+3, num_feat, num_blocks=30)
        self.forward_resblocks = ResidualBlocksWithInputConv(num_feat+3, num_feat, num_blocks=30)
        self.fusion = nn.Conv2d(num_feat*2, num_feat, 1, 1, 0, bias=True)
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        
    def forward(self, x, lrs=None, flows=None):
        b, t, c, h, w = x.shape
        x1 = torch.cat([x[:, 1:, :, :, :], x[:, -1, :, :, :].unsqueeze(1)], dim=1)  # [B, 5, 64, 64, 64]
        flow1 = flows[1].contiguous().view(-1, 2, h, w).permute(0, 2, 3, 1)         # [B*5, 64, 64, 2]
        x1 = flow_warp(x1.view(-1, c, h, w), flow1)                                 # [B*5, 64, 64, 64]
        x1 = torch.cat([lrs.view(b*t, -1, h, w), x1], dim=1)                        # [B*5, 67, 64, 64]
        x1 = self.backward_resblocks(x1)                                            # [B*5, 64, 64, 64]

        x2 = torch.cat([x[:, 0, :, :, :].unsqueeze(1), x[:, :-1, :, :, :]], dim=1)  # [B, 5, 64, 64, 64]
        flow2 = flows[0].contiguous().view(-1, 2, h, w).permute(0, 2, 3, 1)         # [B*5, 64, 64, 2]
        x2 = flow_warp(x2.view(-1, c, h, w), flow2)                                 # [B*5, 64, 64, 64]
        x2 = torch.cat([lrs.view(b*t, -1, h, w), x2], dim=1)                        # [B*5, 67, 64, 64]
        x2 = self.forward_resblocks(x2)                                             # [B*5, 64, 64, 64]

        # fusion the backward and forward features
        out = torch.cat([x1, x2], dim=1)      # [B*5, 128, 64, 64]
        out = self.lrelu(self.fusion(out))    # [B*5, 64, 64, 64]
        out = out.view(x.shape)               # [B, 5, 64, 64, 64] 

        return out

实验

在这里插入图片描述
在这里插入图片描述
在别的文章里有看到块与块之间会出现伪影,然而文章的结果挺完美的

总结

用transformer解决VSR的问题,虽然在空间小范围内进行attention是可行的,也不会造成太大的计算量,但是总觉得对于transformer的优势没有发挥出来,大的感受野和全局信息的利用才是transformer的优势所在

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/346839.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【网络原理4】TCP特性篇

目录 一、滑动窗口 传统发送接收机制的缺点 滑动窗口的特性 发送方什么时候会接着发送下一条报文 如果在滑动窗口的机制下面发生了丢包会怎样处理 情况1:ack丢了 情况2:主动发送的syn丢包了 滑动窗口的应用场景 二、TCP流量控制:根据接…

大数据下Flink on YarnSession 高可用集群环境部署开辟资源发布任务

前言:搭建大数据环境集群环境算是比较麻烦的一个事情,并且对硬件要求也比较高其中搭建大数据环境需要准备jdk环境和zk环境,还有hdfs,还有ssh之间的免密操作,还有主机别名访问不通的问题 等。必然会出现的问题&#xff…

拆个微波炉,分析一下电路

微波炉是用2450MHz的超高频电磁波来加热食品,它能无损穿越塑料,陶瓷,不能穿越金属,碰到金属会反射,但穿过含水食物,食物内的分子会高速摩擦,产生热量,使食物变熟。在厨房电器中&…

自学180天,我从功能测试进阶到自动化测试了...

大家好,我是彭于晏,一个7年测试工作的老司机。因为我一直在分享自动化测试技术,所以时常会被问到这个问题:“靓仔,功能测试想转自动化测试,请问要怎么入手?” 那么,接下来我就结合自…

不愧是GitHub点赞飙升的Java10W字面经,面面俱到,太全了!

最新的喜报啊,话不多说,先看图!(为了保护朋友的隐私,同时还有我自己的隐私,楼主就都打码了~!) 朋友说到这儿时候我就跟他说,不要只看眼前,要看长远一些&#…

前端开发之防抖与节流

前端开发中我们经常会通过监听某些事件来完成项目需求 1.通过监听 scroll 事件,检测滚动位置,根据滚动位置显示返回顶部按钮 2.通过监听 resize 事件,对某些自适应页面调整DOM的渲染(通过CSS实现的自适应不再此范围内)…

动态库和静态库的区别

什么是库文件 一般来说,一个程序,通常都会包含目标文件和若干个库文件。经过汇编得到的目标文件再经过和库文件的链接,就能构成可执行文件。库文件像是一个代码仓库或代码组件的集合,为目标文件提供可直接使用的变量、函数、类等…

Hadoop3.3.0--Linux编译安装

Hadoop3.3.0–Linux编译安装 基础环境:Centos 7.7 编译环境软件安装目录 mkdir -p /export/server一、Hadoop编译安装(选做) 可以直接使用课程提供已经编译好的安装包。 安装编译相关的依赖 yum install gcc gcc-c make autoconf automake…

leaflet 上传CSV文件,导出geojson格式文件(064)

第064个 点击查看专栏目录 本示例的目的是介绍演示如何在vue+leaflet中加载CSV文件,将图形显示在地图上。点击导出geojson,下载成geojson文件。 直接复制下面的 vue+openlayers源代码,操作2分钟即可运行实现效果. 文章目录 示例效果配置方式示例源代码(共114行)安装插件…

如何判断是否ChatGPT回答出来的问题?解决方法详解

目录 前言 一、人工智能(“ChatGPT”等)能淘汰人类吗? 二、完全禁止或严格限制使用ChatGPT,是利大于弊还是? 1、ChatGPT与造纸术优点 2、人有悲欢离合,月有阴晴圆缺,此事古难全&#xff01…

Python基础-数据类型之序列

序列:一种数据结构,序列中的每个元素都会被分配到一个序号(元素的位置)。 常用的序列有:列表、元组、字符串。 一、序列的操作: 1:通过索引取值 nums_list [1,2,3,4] print(nums_list[0]) …

消息队列的特点

一、背景:在分布式系统中是如何处理高并发的由于在高并发的环境下,来不及同步处理用户发送的请求,则会导致请求发送阻塞。比如说,大量的insert、update之类的请求同时到达数据库MYSQL,直接导致无数的行锁表锁&#xff…

零基础机器学习做游戏辅助第十一课--原神自动钓鱼(一)

一、序言 前面我们已经学习了神经网络,卷积神经网络和强化学习等内容,也都做了对应的实例。但是我们的课是做游戏辅助,那么肯定要去游戏里实战一番。 今天就带领大家用我们所学的知识对近两年非常火爆的游戏《原神》进行实战。我们以自动钓鱼为例。 二、观察游戏玩法制定方案…

Seata源码学习(五)- Seata服务端(TC)源码解读

Seata源码分析- Seata服务端(TC)源码解读 上节课我们已经分析到了SQL语句最终的执行器,但是再往下分析之前,我们需要先来分析一下TM客户端与TC端通讯以后,TC端的具体操作 服务端表解释 我们的Seata服务端在应用的时…

RabbitMq及其他消息队列

消息队列中间价都有哪些 先进先出 Kafka、Pulsar、RocketMQ、RabbitMQ、NSQ、ActiveMQ 架构 消费推拉模式 客户端消费者获取消息的方式,Kafka和RocketMQ是通过长轮询Pull的方式拉取消息,RabbitMQ、Pulsar、NSQ都是通过Push的方式。 pull类型的消息队…

OpenCV制作Mask图像掩码

一、掩膜(mask) 在有些图像处理的函数中有的参数里面会有mask参数,即此函数支持掩膜操作,首先何为掩膜以及有什么用,如下: 数字图像处理中的掩膜的概念是借鉴于PCB制版的过程,在半导体制造中&am…

PowerShell Install VNC-Server VNC-Viewer

前言 VNCConnect是一款屏幕共享、远程控制电脑软件,可以让您连接到世界上任何地方的远程计算机,实时观看其屏幕,并像坐在它前面一样进行控制。RealVNC可以将人和设备连接到任何地方,实现控制、支持、管理、监控、培训、协作等等。…

Java——不同的子序列

题目链接 leetcode在线oj题——不同的子序列 题目描述 给定一个字符串 s 和一个字符串 t ,计算在 s 的子序列中 t 出现的个数。 字符串的一个 子序列 是指,通过删除一些(也可以不删除)字符且不干扰剩余字符相对位置所组成的新…

【C语言学习笔记】:数组、指针相关面试题

无特殊说明情况下,下面所有题s目都是linux下的32位C程序。 「1、计算以下sizeof的值。」 char str1[] {a, b, c, d, e}; char str2[] "abcde";char *ptr "abcde";char book[][80]{"计算机应用基础","C语言","C程…

Apple Safari 16.3 - macOS 专属免费浏览器 (独立安装包免费下载)

Safari 浏览器 16 for macOS Montery, Big Sur 请访问原文链接:https://sysin.org/blog/apple-safari-16/,查看最新版。原创作品,转载请保留出处。 作者主页:www.sysin.org 之前 Safari 浏览器伴随 macOS 更新一起发布&#xff…