基于多尺度注意力网络单图像超分(MAN)

news2024/10/7 6:51:29

引言

Transformer的自注意力机制可以进行远距离建模,在视觉的各个领域表现出强大的能力。然而在VAN中使用大核分解同样可以得到很好的效果。这也反映了卷积核的发展趋势,从一开始的大卷积核到vgg中采用堆叠的小卷积核代替大卷积核。

 上图展现了MAN网络在同样的性能下具有更少的参数量。

提高模型性能,通常有三种方法:

  • 更大的数据集
  • 更好的训练策略
  • 更好的网络结构

引言部分作者介绍了RCAN,RDN,MSRN的网络结构。作者之所以会提出这样的网络结构是因为:首先,transformer在各个领域大放光彩,但是作者认为transformer中的自注意力机制具有二次复杂度(就是太复杂了)。而出现的VAN(大核注意力机制)简单的堆叠卷积同样可以达到远距离建模的作用。所以作者产生了idea,就是应用VAN中的LKA(大核分解)来组建网络结构。作者为了最大化LKA的作用,作者采取了transformer的结构,而不是应用RCAN的架构。我们都知道transformer中包含了自注意力和MLP。但是作者认为MLP结构太过于复杂,用于低级视觉处理任务有点大材小用,所以作者引入了GSAU.

总结来说:作者首先提出了多尺度大核注意力模块(MLKA),可以得到多尺度远程建模依赖,提高了模型的表示能力。然后作者将门控机制和空间注意力结合在一起,构建了简化了前馈网络,可以减少参数和计算。

模型结构

整个模型分为三个部分。

首先是浅层的特征的提取,有一个3*3的卷积构成。

 然后是级联的多尺度注意力模块(MAB),用于进一步提取特征。在级联的最后添加了一层LKAT,这时因为在传统的SR中都存在一个卷积层。但是,它在建立远程连接方面存在缺陷,因此限制了最终重建功能的代表性能力。为了从堆叠的mab中总结出更合理的信息,我们在tail模块中引入了7-9-1 LKA。实验结果表明,LKAT模块可以有效地聚合有用信息,并显着提高重建质量

最后是重建模块,有卷积层和亚像素卷积层构成,其中应用的残差连接。

损失函数

 MAB

作者受transformer的启发,重新思考SISR中用于特征提取的基本卷积块。

 如上图所示,MAB有多尺度大核注意力(MLKA)和门空间注意力单元(GSAU)组成。

 其中整个MAB可以总结为以下公式:

 \lambda ,f 分别表示可学习参数,和1*1的逐点卷积。这里的1*1卷积是为了保持维度相等。

MLKA(多尺度大内核注意力机制)

注意力机制可以迫使网络专注于关键信息,而忽略不相关的信息。以前的SR模型采用了一系列注意机制,包括通道注意 (CA) 和自注意 (SA),以获得更多的信息性特征。但是,这些方法无法同时吸收本地信息和远程依赖性,并且它们通常会考虑固定感受野的注意力图。在最新的视觉注意研究的启发下 [21],我们提出了多尺度大内核注意 (MLKA),通过结合大内核分解和多尺度学习来解决这些问题。具体来说,MLKA由三个主要功能组成: 用于建立相互依存关系的大核注意 (LKA),用于获得异质尺度相关性的多尺度机制以及用于动态重新校准的门控聚合

LKA(大内核注意力)

 LKA可以表示为以下公式:

 多尺度机制

为了增强LKA,作者引入了组卷积多尺度机制,可以获得更加全面的信息。

 这里意思是说将输入按照通道维度分为n份。对于每一组特征都作用一个LKA。

门控聚合(Gated aggregation. )

与许多高级计算机视觉任务不同,SR任务对膨胀卷积和分组卷积的容忍度较差。

如图4所示,虽然较大的LKA捕获较宽的像素响应,但是阻塞伪像出现在较大的LKA的生成的注意图中。

 对于第i组输入Xi,为了避免块效应,并了解更多本地信息,我们利用空间门通过以下方式将LKAi(·) 动态调整为MLKAi(·):

其中Gi(·) 是ai × ai深度卷积产生的第i门,LKAi(·) 是ai-bi-1分解的LKA。在图4中,我们提供门控聚合的视觉结果。可以观察到,从注意力图中删除了块效果,并且MLKAi更为合理。特别是,具有较大感受野的LKA对远距离依赖性的反应更多,而较小的LKA倾向于保留局部纹理。 

门控空间注意单元(Gated Spatial Attention Unit (GSAU) )

在Transformer中,前馈网络 (FFN) 是增强特征表示的重要组成部分。但是,具有宽中间通道的MLP对于SR来说太重了,尤其是对于大型图像输入而言。受 [444 23,34,24] 的启发,我们将简单的空间注意 (SSA) 和门控线性单元 (GLU) 集成到建议的GSAU中,以实现自适应门控机制并减少参数和计算。

为了更有效地捕获空间信息,我们采用单层深度卷积对特征图进行加权。给定X和Y,GSAU的关键过程可以表示为

 通过应用空间门,GSAU可以在考虑的复杂性下去除非线性层并捕获局部连续性。

Large Kernel Attention Tail (LKAT) 

在以前的SR网络 [7、8、26、27、9] 中,3 × 3卷积层被广泛用作深度提取主干的尾部。但是,它在建立远程连接方面存在缺陷,因此限制了最终重建功能的代表性能力。为了从堆叠的mab中总结出更合理的信息,我们在tail模块中引入了7-9-1 LKA。具体地,如图3所示,LKA由两个1 × 1卷积包裹。实验结果表明,LKAT模块可以有效地聚合有用信息,并显着提高重建质量

实验

数据集和指标

遵循最新的工作 [35,9,12],我们利用包含800和2650训练图像的DIV2K [36] 和Flicker2K [7] 来训练我们的模型。为了进行测试,我们在五个常用数据集上评估了我们的方法: Set5 [13],Set14 [37],BSD100 [38],Urban100 [39] 和Manga109 [40]。此外,在YCbCr图像的Y通道中应用了两个标准评估指标,即峰值信噪比 (PSNR) 和结构相似性指数 (SSIM) [41]。

实验细节

我们训练了三个不同版本的MAN: tiny,light和classical。三个版本的通道数和MAB的层数都不一样,通道数分别为48/60/180,层数分别为5/24/36.MLKA使用了三种多尺度分解模式,分别为3-5-1、5-7-1和7-9-1。GSAU中使用7 × 7深度卷积。

在训练阶段,使用双三次插值来生成lr-hr图像对。训练对通过水平翻转和90 、180 、270的随机旋转进一步增强。

消融实验

各个组件的研究

上表展现了MAN中各个组件对模型的贡献。其中LKAT用卷积层代替,多尺度用LKA(5-7-1),GSAU用MLP代替.因为作者是用的meta-transformer的架构,所以会有MLP模块。其中GSAU是一个很好的替代品对于MLP,在MAN-light中减少了15k的参数且模型性能不下降。

模型架构的选择研究

 上表表明了哪种风格的架构更加有效。从表中可以看出RCAN结构的具有更多的参数且模型性能比较差。

MLKA的研究

 上表展现出了作者设计了LKA和几个MLKA,表明MLKA具有更加好的性能和更快的收敛性。在表中可能存在一点疑问,就是模块数量增加,模块的参数反而没增加,这时因为作者在这里应用分组卷积。

GSAU的研究

 先进算法的对比

与一些最先进的小 [6,45,46] 和轻量级 [31,47,9] SR 进行对比。

 经典的SR方法对比。

 

 总结

在本文中,我们提出了一种多尺度注意力网络 (MAN),用于在多种环境下重新缩放超分辨率图像。MAN采用transformer以获得更好的建模表示能力。为了有效,灵活地在各个区域之间建立长期相关性,我们开发了结合大内核分解和多尺度机制的多尺度大内核关注 (MLKA)。此外,我们提出了一种简化的前馈网络,该网络集成了门机制和空间注意力,以激活本地信息并降低模型复杂性。广泛的实验表明,我们基于CNN的MAN可以以更有效的方式实现比以前的SOTA模型更好的性能。

模型代码:

# -*- coding: utf-8 -*-
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from basicsr.utils.registry import ARCH_REGISTRY

class LKA(nn.Module):
    """
    大核注意力机制,将(K,d)的大卷积层分解为三个卷积层
    """
    def __init__(self, dim):
        super().__init__()
        self.conv0 = nn.Conv2d(dim, dim, 7, padding=7 // 2, groups=dim)
        self.conv_spatial = nn.Conv2d(dim, dim, 9, stride=1, padding=((9 // 2) * 4), groups=dim, dilation=4)
        self.conv1 = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        u = x.clone()
        attn = self.conv0(x)
        attn = self.conv_spatial(attn)
        attn = self.conv1(attn)
        # 这里的*是矩阵的点乘、torch.mul(),矩阵乘法是@、torch.mm()
        return u * attn

class Attention(nn.Module):
    def __init__(self, n_feats):
        super().__init__()
        self.norm = LayerNorm(n_feats, data_format='channels_first')
        ###可学习参数
        self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)
        self.proj_1 = nn.Conv2d(n_feats, n_feats, 1)
        self.spatial_gating_unit = LKA(n_feats)
        self.proj_2 = nn.Conv2d(n_feats, n_feats, 1)
    def forward(self, x):
        shorcut = x.clone()
        x = self.proj_1(self.norm(x))
        x = self.spatial_gating_unit(x)
        x = self.proj_2(x)
        x = x * self.scale + shorcut
        return x
    # ----------------------------------------------------------------------------------------------------------------


class MLP(nn.Module):
    def __init__(self, n_feats):
        super().__init__()

        self.norm = LayerNorm(n_feats, data_format='channels_first')
        self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)

        i_feats = 2 * n_feats

        self.fc1 = nn.Conv2d(n_feats, i_feats, 1, 1, 0)
        self.act = nn.GELU()
        self.fc2 = nn.Conv2d(i_feats, n_feats, 1, 1, 0)

    def forward(self, x):
        shortcut = x.clone()
        x = self.norm(x)
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)

        return x * self.scale + shortcut


class CFF(nn.Module):
    def __init__(self, n_feats, drop=0.0, k=2, squeeze_factor=15, attn='GLKA'):
        super().__init__()
        i_feats = n_feats * 2

        self.Conv1 = nn.Conv2d(n_feats, i_feats, 1, 1, 0)
        self.DWConv1 = nn.Sequential(
            nn.Conv2d(i_feats, i_feats, 7, 1, 7 // 2, groups=n_feats),
            nn.GELU())
        self.Conv2 = nn.Conv2d(i_feats, n_feats, 1, 1, 0)

        self.norm = LayerNorm(n_feats, data_format='channels_first')
        self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)

    def forward(self, x):
        shortcut = x.clone()

        # Ghost Expand
        x = self.Conv1(self.norm(x))
        x = self.DWConv1(x)
        x = self.Conv2(x)

        return x * self.scale + shortcut


class SimpleGate(nn.Module):
    def __init__(self, n_feats):
        super().__init__()
        i_feats = n_feats * 2

        self.Conv1 = nn.Conv2d(n_feats, i_feats, 1, 1, 0)
        # self.DWConv1 = nn.Conv2d(n_feats, n_feats, 7, 1, 7//2, groups= n_feats)
        self.Conv2 = nn.Conv2d(n_feats, n_feats, 1, 1, 0)

        self.norm = LayerNorm(n_feats, data_format='channels_first')
        self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)

    def forward(self, x):
        shortcut = x.clone()

        # Ghost Expand
        x = self.Conv1(self.norm(x))
        a, x = torch.chunk(x, 2, dim=1)
        x = x * a  # self.DWConv1(a)
        x = self.Conv2(x)

        return x * self.scale + shortcut
    # -----------------------------------------------------------------------------------------------------------------


# RCAN-style
class RCBv6(nn.Module):
    def __init__(
            self, n_feats, k, lk=7, res_scale=1.0, style='X', act=nn.SiLU(), deploy=False):
        super().__init__()
        self.LKA = nn.Sequential(
            nn.Conv2d(n_feats, n_feats, 5, 1, lk // 2, groups=n_feats),
            nn.Conv2d(n_feats, n_feats, 7, stride=1, padding=9, groups=n_feats, dilation=3),
            nn.Conv2d(n_feats, n_feats, 1, 1, 0),
            nn.Sigmoid())

        # self.LFE2 = LFEv3(n_feats, attn ='CA')

        self.LFE = nn.Sequential(
            nn.Conv2d(n_feats, n_feats, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(n_feats, n_feats, 3, 1, 1))

    def forward(self, x, pre_attn=None, RAA=None):
        shortcut = x.clone()
        x = self.LFE(x)

        x = self.LKA(x) * x

        return x + shortcut

    # -----------------------------------------------------------------------------------------------------------------


class MLKA_Ablation(nn.Module):
    def __init__(self, n_feats, k=2, squeeze_factor=15):
        super().__init__()
        i_feats = 2 * n_feats

        self.n_feats = n_feats
        self.i_feats = i_feats

        self.norm = LayerNorm(n_feats, data_format='channels_first')
        self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)

        k = 2

        # Multiscale Large Kernel Attention
        self.LKA7 = nn.Sequential(
            nn.Conv2d(n_feats // k, n_feats // k, 7, 1, 7 // 2, groups=n_feats // k),
            nn.Conv2d(n_feats // k, n_feats // k, 9, stride=1, padding=(9 // 2) * 4, groups=n_feats // k, dilation=4),
            nn.Conv2d(n_feats // k, n_feats // k, 1, 1, 0))
        self.LKA5 = nn.Sequential(
            nn.Conv2d(n_feats // k, n_feats // k, 5, 1, 5 // 2, groups=n_feats // k),
            nn.Conv2d(n_feats // k, n_feats // k, 7, stride=1, padding=(7 // 2) * 3, groups=n_feats // k, dilation=3),
            nn.Conv2d(n_feats // k, n_feats // k, 1, 1, 0))
        '''self.LKA3 = nn.Sequential(
            nn.Conv2d(n_feats//k, n_feats//k, 3, 1, 1, groups= n_feats//k),  
            nn.Conv2d(n_feats//k, n_feats//k, 5, stride=1, padding=(5//2)*2, groups=n_feats//k, dilation=2),
            nn.Conv2d(n_feats//k, n_feats//k, 1, 1, 0))'''

        # self.X3 = nn.Conv2d(n_feats//k, n_feats//k, 3, 1, 1, groups= n_feats//k)
        self.X5 = nn.Conv2d(n_feats // k, n_feats // k, 5, 1, 5 // 2, groups=n_feats // k)
        self.X7 = nn.Conv2d(n_feats // k, n_feats // k, 7, 1, 7 // 2, groups=n_feats // k)

        self.proj_first = nn.Sequential(
            nn.Conv2d(n_feats, i_feats, 1, 1, 0))

        self.proj_last = nn.Sequential(
            nn.Conv2d(n_feats, n_feats, 1, 1, 0))

    def forward(self, x, pre_attn=None, RAA=None):
        shortcut = x.clone()

        x = self.norm(x)

        x = self.proj_first(x)

        a, x = torch.chunk(x, 2, dim=1)

        # u_1, u_2, u_3= torch.chunk(u, 3, dim=1)
        a_1, a_2 = torch.chunk(a, 2, dim=1)

        a = torch.cat([self.LKA7(a_1) * self.X7(a_1), self.LKA5(a_2) * self.X5(a_2)], dim=1)

        x = self.proj_last(x * a) * self.scale + shortcut

        return x
    # -----------------------------------------------------------------------------------------------------------------


class LayerNorm(nn.Module):
    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
    with shape (batch_size, channels, height, width).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x


class SGAB(nn.Module):
    def __init__(self, n_feats, drop=0.0, k=2, squeeze_factor=15, attn='GLKA'):
        super().__init__()
        i_feats = n_feats * 2

        self.Conv1 = nn.Conv2d(n_feats, i_feats, 1, 1, 0)
        self.DWConv1 = nn.Conv2d(n_feats, n_feats, 7, 1, 7 // 2, groups=n_feats)
        self.Conv2 = nn.Conv2d(n_feats, n_feats, 1, 1, 0)

        self.norm = LayerNorm(n_feats, data_format='channels_first')
        self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)

    def forward(self, x):
        shortcut = x.clone()

        # Ghost Expand
        x = self.Conv1(self.norm(x))
        a, x = torch.chunk(x, 2, dim=1)
        x = x * self.DWConv1(a)
        x = self.Conv2(x)

        return x * self.scale + shortcut


class GroupGLKA(nn.Module):
    def __init__(self, n_feats, k=2, squeeze_factor=15):
        super().__init__()
        i_feats = 2 * n_feats

        self.n_feats = n_feats
        self.i_feats = i_feats

        self.norm = LayerNorm(n_feats, data_format='channels_first')
        self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)

        # Multiscale Large Kernel Attention
        self.LKA7 = nn.Sequential(
            nn.Conv2d(n_feats // 3, n_feats // 3, 7, 1, 7 // 2, groups=n_feats // 3),
            nn.Conv2d(n_feats // 3, n_feats // 3, 9, stride=1, padding=(9 // 2) * 4, groups=n_feats // 3, dilation=4),
            nn.Conv2d(n_feats // 3, n_feats // 3, 1, 1, 0))
        self.LKA5 = nn.Sequential(
            nn.Conv2d(n_feats // 3, n_feats // 3, 5, 1, 5 // 2, groups=n_feats // 3),
            nn.Conv2d(n_feats // 3, n_feats // 3, 7, stride=1, padding=(7 // 2) * 3, groups=n_feats // 3, dilation=3),
            nn.Conv2d(n_feats // 3, n_feats // 3, 1, 1, 0))
        self.LKA3 = nn.Sequential(
            nn.Conv2d(n_feats // 3, n_feats // 3, 3, 1, 1, groups=n_feats // 3),
            nn.Conv2d(n_feats // 3, n_feats // 3, 5, stride=1, padding=(5 // 2) * 2, groups=n_feats // 3, dilation=2),
            nn.Conv2d(n_feats // 3, n_feats // 3, 1, 1, 0))

        self.X3 = nn.Conv2d(n_feats // 3, n_feats // 3, 3, 1, 1, groups=n_feats // 3)
        self.X5 = nn.Conv2d(n_feats // 3, n_feats // 3, 5, 1, 5 // 2, groups=n_feats // 3)
        self.X7 = nn.Conv2d(n_feats // 3, n_feats // 3, 7, 1, 7 // 2, groups=n_feats // 3)

        self.proj_first = nn.Sequential(
            nn.Conv2d(n_feats, i_feats, 1, 1, 0))

        self.proj_last = nn.Sequential(
            nn.Conv2d(n_feats, n_feats, 1, 1, 0))

    def forward(self, x, pre_attn=None, RAA=None):
        shortcut = x.clone()

        x = self.norm(x)

        x = self.proj_first(x)

        a, x = torch.chunk(x, 2, dim=1)

        a_1, a_2, a_3 = torch.chunk(a, 3, dim=1)

        a = torch.cat([self.LKA3(a_1) * self.X3(a_1), self.LKA5(a_2) * self.X5(a_2), self.LKA7(a_3) * self.X7(a_3)],
                      dim=1)

        x = self.proj_last(x * a) * self.scale + shortcut

        return x

    # MAB


class MAB(nn.Module):
    def __init__(
            self, n_feats):
        super().__init__()

        self.LKA = GroupGLKA(n_feats)

        self.LFE = SGAB(n_feats)

    def forward(self, x, pre_attn=None, RAA=None):
        # large kernel attention
        x = self.LKA(x)

        # local feature extraction
        x = self.LFE(x)

        return x


class LKAT(nn.Module):
    def __init__(self, n_feats):
        super().__init__()

        # self.norm = LayerNorm(n_feats, data_format='channels_first')
        # self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)

        self.conv0 = nn.Sequential(
            nn.Conv2d(n_feats, n_feats, 1, 1, 0),
            nn.GELU())

        self.att = nn.Sequential(
            nn.Conv2d(n_feats, n_feats, 7, 1, 7 // 2, groups=n_feats),
            nn.Conv2d(n_feats, n_feats, 9, stride=1, padding=(9 // 2) * 3, groups=n_feats, dilation=3),
            nn.Conv2d(n_feats, n_feats, 1, 1, 0))

        self.conv1 = nn.Conv2d(n_feats, n_feats, 1, 1, 0)

    def forward(self, x):
        x = self.conv0(x)
        x = x * self.att(x)
        x = self.conv1(x)
        return x


class ResGroup(nn.Module):
    def __init__(self, n_resblocks, n_feats, res_scale=1.0):
        super(ResGroup, self).__init__()
        self.body = nn.ModuleList([
            MAB(n_feats) \
            for _ in range(n_resblocks)])

        self.body_t = LKAT(n_feats)

    def forward(self, x):
        res = x.clone()

        for i, block in enumerate(self.body):
            res = block(res)

        x = self.body_t(res) + x

        return x


class MeanShift(nn.Conv2d):
    def __init__(
            self, rgb_range,
            rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
        for p in self.parameters():
            p.requires_grad = False


@ARCH_REGISTRY.register()
class MAN(nn.Module):
    def __init__(self, n_resblocks=36, n_resgroups=1, n_colors=3, n_feats=180, scale=2, res_scale=1.0):
        super(MAN, self).__init__()

        # res_scale = res_scale
        self.n_resgroups = n_resgroups

        self.sub_mean = MeanShift(1.0)
        self.head = nn.Conv2d(n_colors, n_feats, 3, 1, 1)

        # define body module
        self.body = nn.ModuleList([
            ResGroup(
                n_resblocks, n_feats, res_scale=res_scale)
            for i in range(n_resgroups)])

        if self.n_resgroups > 1:
            self.body_t = nn.Conv2d(n_feats, n_feats, 3, 1, 1)

        # define tail module
        self.tail = nn.Sequential(
            nn.Conv2d(n_feats, n_colors * (scale ** 2), 3, 1, 1),
            nn.PixelShuffle(scale)
        )
        self.add_mean = MeanShift(1.0, sign=1)

    def forward(self, x):
        x = self.sub_mean(x)
        x = self.head(x)
        res = x
        for i in self.body:
            res = i(res)
        if self.n_resgroups > 1:
            res = self.body_t(res) + x
        x = self.tail(res)
        x = self.add_mean(x)
        return x

    def visual_feature(self, x):
        fea = []
        x = self.head(x)
        res = x

        for i in self.body:
            temp = res
            res = i(res)
            fea.append(res)

        res = self.body_t(res) + x

        x = self.tail(res)
        return x, fea

    def load_state_dict(self, state_dict, strict=False):
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                if isinstance(param, nn.Parameter):
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception:
                    if name.find('tail') >= 0:
                        print('Replace pre-trained upsampler to new one...')
                    else:
                        raise RuntimeError('While copying the parameter named {}, '
                                           'whose dimensions in the model are {} and '
                                           'whose dimensions in the checkpoint are {}.'
                                           .format(name, own_state[name].size(), param.size()))
            elif strict:
                if name.find('tail') == -1:
                    raise KeyError('unexpected key "{}" in state_dict'
                                   .format(name))

        if strict:
            missing = set(own_state.keys()) - set(state_dict.keys())
            if len(missing) > 0:
                raise KeyError('missing keys in state_dict: "{}"'.format(missing))

参考文献:

icandle/MAN: Multi-scale Attention Network for Single Image Super-Resolution (github.com)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1408.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用T0,方式2,在P1.0输出周期为400µs,占空比为4:1的矩形脉冲,要求在P1.0引脚接有虚拟示波器,观察P1.0引脚输出的矩形脉冲波形

大家学过一段时间的单片机了,今天我们来说说单片机里的定时器,又叫计数器。首先,我们通过案例来了解一下什么是定时器。 【例】使用T0,方式2,在P1.0输出周期为400s,占空比为4:1的矩形脉冲&…

如何编写优秀的测试用例,建议收藏和转发

1、测试点与测试用例 测试点不等于测试用例,这是我们首先需要认识到的。 问题1:这些测试点在内容上有重复,存在冗余。 问题2:一些测试点的测试输入不明确,不知道测试时要测试哪些。 问题3:总是在搭相似…

串口通信协议【I2C、SPI、UART、RS232、RS485、CAN】

(1)I2C 集成电路互连总线接口(Inter IC):同步串行半双工传输总线,连接嵌入式处理器及其外围器件。 支持器件:LCD驱动器、Flash存储器 特点: ①有两根传输线(时钟线SCL、双向数据线SDA&#…

python基础19-36题

题目: 代码十九二十二十一二十二二十三二十四二十五二十六二十七二十八二十九三十三十一三十二三十三三十四三十五三十六十九 birthday int(input(“请输入生日日期:”)) Set1 [1,3,5,7,9,11,13,15,17,19,21,23,25,27,29,31] Set2 [2,3,6,7,10,11,…

【CV】第 7 章:目标检测基础

🔎大家好,我是Sonhhxg_柒,希望你看完之后,能对你有所帮助,不足请指正!共同学习交流🔎 📝个人主页-Sonhhxg_柒的博客_CSDN博客 📃 🎁欢迎各位→点赞…

几何等变图神经网络综述

许多科学问题都要求以几何图形(geometric graphs)的形式处理数据。与一般图数据不同,几何图显示平移、旋转和反射的对称性。研究人员利用这种对称性的归纳偏差(inductive bias),开发了几何等变图神经网络&a…

SpringMVC | 快速上手SpringMVC

👑 博主简介:    🥇 Java领域新星创作者    🥇 阿里云开发者社区专家博主、星级博主、技术博主 🤝 交流社区:BoBooY(优质编程学习笔记社区) 前言:在上一节中我们了解…

多分类评估指标计算

文章目录混淆矩阵回顾Precision、Recall、F1回顾多分类混淆矩阵宏平均(Macro-average)微平均(Micro-average)加权平均(Weighted-average)总结代码混淆矩阵回顾 若一个实例是正类,并且被预测为正…

Linux(Nginx)

目录 一、Nginx简介 二、Nginx使用 Nginx安装 tomcat负载均衡 Nginx配置 三、Nginx部署项目 项目打包前 将前端项目打包(测试本地项目打包后没问题) ip/host主机映射 完成Nginx动静分离的default.conf的相关配置 将前台项目打包(配合Nginx动静…

real-word super resulution: real-sr, real-vsr, realbasicvsr 三篇超分和视频超分论文

real-world image and video super-resolution 文章目录real-world image and video super-resolution1. Toward Real-World Single Image Super-Resolution:A New Benchmark and A New Model(2019)1.1 real-world数据集制作1.2 LP-KPN网络结构1.3 拉普拉…

近八成中国程序员起薪过万人民币,你过了么?

打工者联盟为了抵抗996、拖欠工资、黑心老板、恶心公司,让我们组成打工者联盟。客观评价自己任职过的公司情况,为其他求职者竖起一座引路的明灯。https://book.employleague.cn/一项调查显示,近八成中国程序员本科毕业生起薪过万(…

Oracle数据库中的数据完整性

目录 1.数据完整性约束作用 2.数据完整性约束的分类 3.完整性约束的状态 4.域完整性的实现 (1)check约束 ①可视化方式创建check约束 ②命令方式创建约束 ③修改表创建的约束 ④删除约束 (2)实体完整性约束实现 ①prim…

思科dhcp服务器动态获取ip地址

项目要求: 某公司共有网管中心、行政部、技术部、三个部门,分别处在一栋大楼中的两个楼层,为了保证公司内部主机始终能够连接Internet,采用双向冗余设计,分别使用路由器R1与路由器R2连接中国电信和中国联通。 1.首先为了避免不必要…

【算法详解】数据结构:7种哈希散列算法,你知道几个?

一、前言 哈希表的历史 哈希散列的想法在不同的地方独立出现。1953 年 1 月,汉斯彼得卢恩 ( Hans Peter Luhn ) 编写了一份IBM内部备忘录,其中使用了散列和链接。开放寻址后来由 AD Linh 在 Luhn 的论文上提出。大约在同一时间,IBM Researc…

项目进度管理

第3 章 项目进度管理 3.1 概述 1.项目进度管理是指在项目实施过程中,对各阶段的进展程度和项目最终完成的期限所进行的管理,是在 规定的时间内,拟定出合理且经济的进度计划(包括多级管理的子汁划),在执行该计划的过程…

常见的限流算法的原理以及优缺点

原文网址:常见的限流算法的原理以及优缺点_IT利刃出鞘的博客-CSDN博客 简介 说明 本文介绍限流常用的算法及其优缺点。 常用的限流算法有: 计数器(固定窗口)算法滑动窗口算法漏桶算法令牌桶算法 下面将对这几种算法进行分别介绍…

tmux的简单使用

文章目录一、认识tmux1.1 会话1.2 tmux的作用1.3 tmux的安装二、tmux的使用2.1 会话管理2.1.1 创建会话2.1.2 退出会话2.1.3 从终端环境进入会话2.1.4 查看会话列表2.1.5 销毁会话2.1.6 重命名会话2.2 窗口管理2.3 窗格管理一、认识tmux 1.1 会话 命令行的典型使用方式是&…

rocketmq是如何消费

拉取消息的请求都在pullRequestQueue队列里, 拉取消息成功后设置下一次需要拉取的offset, boolean dispatchToConsume processQueue.putMessage(pullResult.getMsgFoundList()); 这个方法会把拉取回来的消息放进msgTreeMap里面 然后消费拉取回来的消…

MongoDB副本集成员如何复制新数据

复制是指在多台服务器上保持相同的数据副本。MongoDB 实现此功能的方式是保存操作日志(oplog),其中包含了主节点执行的每一次写操作。oplog 是存在于主节点 local 数据库中的一个固定集合。从节点通过查询此集合以获取需要复制的操作。 每个…

Solving Inverse Problems With Deep_Neural Networks – Robustness Included_

作者:Martin Genzel, Jan Macdonald, and Maximilian Marz期刊:preprint arXiv时间:2020代码链接:代码论文链接:论文 1 动机与研究内容 最近工作发现深度神经网络对于图像重构的不稳定(instabilities),以…