【低照度图像增强系列(3)】EnlightenGAN算法详解与代码实现

news2025/1/9 15:46:50

前言  

☀️ 在低照度场景下进行目标检测任务,常存在图像RGB特征信息少提取特征困难目标识别和定位精度低等问题,给检测带来一定的难度。

     🌻使用图像增强模块对原始图像进行画质提升,恢复各类图像信息,再使用目标检测网络对增强图像进行特定目标检测,有效提高检测的精确度。

      ⭐本专栏会介绍传统方法、Retinex、EnlightenGAN、SCI、Zero-DCE、IceNet、RRDNet、URetinex-Net等低照度图像增强算法。

👑完整代码已打包上传至资源→低照度图像增强代码汇总资源-CSDN文库

目录

前言  

🚀一、EnlightenGAN介绍 

☀️1.1 EnlightenGAN简介

☀️1.2 EnlightenGAN网络结构

(1)生成器模块  

(2)判别器模块  

(3)损失函数  

🚀二、EnlightenGAN核心代码讲解

🎄2.1 Functions

🎄2.2 Class

🚀三、EnlightenGAN源码运行

 

🚀一、EnlightenGAN介绍 

相关资料: 

  • EnlightenGAN 论文:https://arxiv.org/abs/1906.06972
  • EnlightenGAN 论文详细解读:《EnlightenGAN: Deep Light Enhancement withoutPaired Supervision》论文超详细解读(翻译+精读)
  • EnlightenGAN 源码:https://github.com/VITA-Group/EnlightenGAN

☀️1.1 EnlightenGAN简介

目前,基于深度学习的低照度图像增强方法取得了一些不错的成效。但是一直以来存在着一个问题,就是它们大部分都属于监督学习,也就是说需要大量配对数据(paired data)来进行训练,但现实生活中,我们很难获取大量的同场景下的低光和正常光图像来作为数据对。

因此,作者和他的团队提出了一种无监督生成对抗网络来实现图像增强,即EnlightenGAN。这个模型并不需要配对数据来进行训练,但却能在多种场景下表现良好。为了提高模型性能,同时也弥补数据未成对造成的一些不足,作者和他的团队提出了一系列的新处理方法,包括全局-局部判别器结构,自正则化感知损失,以及自正则注意机制。


☀️1.2 EnlightenGAN网络结构

下图是EnlightenGAN网络结构。

EnlightenGAN网络结构 = 生成器(带自注意力机制的U-Net)+ 判别器(全局-局部鉴别器)

(1)生成器模块  

首先,我们来看看生成器模块

生成器模块就是一个引入了自注意力机制的U-Net,自正则化注意力图的生成方式如下:

  1. 把输入的RGB图像转为灰度图

  2. 将灰度图(I)归一化到 [ 0,1 ]

  3. 运算1 - I(element-wise difference 逐元素作差),突出暗部部分

  4. 得到了注意力图(attention map),重点关注暗部部分

可以理解为对于光照越弱的地方注意力越强。因为网络中得到的每个特征图大小都不一样,所以这里把注意力图resize为各中间的特征图对应的大小,然后对应相乘最后得到了我们的输出图像。

整个U-Net 生成器由8个卷积块组成,每个卷积块由两个3*3的卷积层一个BN层和LeakReLU层。

为什么把ReLU层换为LeakyReLU层?

由于稀疏梯度虽然在大多数网络中通常是理想的目标,但是在GAN中,它会妨碍训练过程,影响GAN的稳定性,所以作者的网络中没有maxpool层和ReLU层,而是用LeakReLU层替代ReLU层。

此外,为了减小棋盘效应,作者用一个双线性上采样层一个卷积层来代替原本的标准反卷积层。

棋盘效应:由于反卷积的”不均匀重叠“,会导致图像中的某部位比别的部位颜色深,造成的伪影看上去像棋盘格子一般。而这种”不均匀重叠“,是因为卷积核(kernel)尺寸不能被步长(stride)整除导致的。


(2)判别器模块  

  • 全局鉴别器:上面的灰色块,判断生成的图像和真实图像之间的整体光照差异,改善图像的全局光照特征对抗性损失来最小化真实图像和输出图像的光照分布的距离。但全局鉴别器,对于一些暗场景下存在明亮区域的图像,适应性不够
  • 局部鉴别器:下面的灰色块,判断生成的图像和真实图像之间的局部细节差异。改善图像的细节特征,用的 PatchGAN来鉴别真/假 来鉴别真/假。从输出图像和真实图像中随机采样 5 个图像块(上图),来判断是真实图像还是模型增强出来的图像。解决全局鉴别器带来的局部曝光不足或过度的情况了。

(3)损失函数  

相对论鉴别器函数:

  • C:表示网络
  • x_{r} 和x_{f} :是从真实的和伪分布中采样的
  • \sigma:表示S形函数

全局鉴别器D和生成器G的损失函数:

局部鉴别器D和生成器G的损失函数:

自特征保持损失LSFP定义:

  • I^{L}表示输入低光图像
  • G(I^{L})表示生成器的增强输出
  • \phi _{i,j}表示从ImageNet上预训练的VGG16模型中提取的特征图
  • i表示第i个最大池化层
  • j表示第i个最大池化层之后的第j个卷积层
  • W _{i,j}H _{i,j}是提取的特征图的维度

EnlightenGAN的整体损失函数:


🚀二、EnlightenGAN核心代码讲解

这一部分我们主要讲EnlightenGAN模型的网络生成器这部分的核心,也就是models文件夹中的networks.py

🎄2.1 Functions

① pad_tensor

def pad_tensor(input):
    height_org, width_org = input.shape[2], input.shape[3] #获取张量的高度和宽度
    divide = 16

    if width_org % divide != 0 or height_org % divide != 0:# 判断输入张量的宽度和高度是否不能被divide整除

        width_res = width_org % divide
        height_res = height_org % divide
        if width_res != 0:
            width_div = divide - width_res # 需要填充的宽度
            pad_left = int(width_div / 2) # 填充的左侧宽度
            pad_right = int(width_div - pad_left) # 填充的右侧宽度
        else:
            pad_left = 0
            pad_right = 0

        if height_res != 0:
            height_div = divide - height_res # 需要填充的高度
            pad_top = int(height_div / 2) # 填充的左侧高度
            pad_bottom = int(height_div - pad_top) # 填充的右侧高度
        else:
            pad_top = 0
            pad_bottom = 0

        padding = nn.ReflectionPad2d((pad_left, pad_right, pad_top, pad_bottom)) # 在输入张量的四个边上进行反射填充
        input = padding(input)
    else:
        pad_left = 0
        pad_right = 0
        pad_top = 0
        pad_bottom = 0

    height, width = input.data.shape[2], input.data.shape[3]
    assert width % divide == 0, 'width cant divided by stride'
    assert height % divide == 0, 'height cant divided by stride'

    return input, pad_left, pad_right, pad_top, pad_bottom

这段代码的主要作用是对输入的二维张量进行填充,以确保其高度和宽度能够被指定的divide参数整除。

具体而言,该函数执行以下操作:

  1. 如果输入张量的宽度或高度不能被divide整除,计算需要进行填充的数量,并使用反射填充(nn.ReflectionPad2d)对输入进行填充。
  2. 如果宽度和高度已经能够被divide整除,则不进行填充。
  3. 返回填充后的张量以及进行填充的左、右、上、下四个方向的填充量。

主要参数含义: 

  • width_org height_org 是输入张量的原始宽度和高度。
  • divide 是用于指定张量宽度和高度整除性的参数。
  • pad_leftpad_rightpad_top pad_bottom 是填充的左、右、上、下四个方向的填充量。

② pad_tensor_back

def pad_tensor_back(input, pad_left, pad_right, pad_top, pad_bottom):
    height, width = input.shape[2], input.shape[3]
    return input[:, :, pad_top: height - pad_bottom, pad_left: width - pad_right]

这段代码主要作用是与前面 pad_tensor 函数相对应的逆操作,用于反向去除填充。这个函数的目的是从填充后的张量中截取出原始尺寸的部分

具体来说,函数通过切片操作,从填充后的张量中截取出原始尺寸(不包括填充的部分)的子张量。返回的结果就是去除填充后的张量,恢复到原始尺寸的部分。

这样的操作通常在对图像或特征图进行处理后,需要将其还原到原始尺寸时使用。这可以确保在网络的前向传播和反向传播过程中,输入和输出的尺寸保持一致。


③ weights_init

def weights_init(m):
    classname = m.__class__.__name__ # 初始化权重
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02) # 卷积层权重正态分布初始化,均值为0,标准差为0.02
    elif classname.find('BatchNorm2d') != -1:
        m.weight.data.normal_(1.0, 0.02) # 批量归一化层正态分布初始化
        m.bias.data.fill_(0) # 批量归一化层偏置项设置为0

这段代码主要作用是初始化神经网络模型中的权重。具体来说,它对卷积层和批量归一化层的权重进行初始化。

函数通过遍历模型的每个模块(m),根据模块的类别进行不同的权重初始化。

具体做法如下:

  • 如果模块属于卷积层,则将卷积层的权重进行正态分布初始化,均值为0,标准差为0.02。
  • 如果模块属于批量归一化层,则将批量归一化层的权重进行正态分布初始化,均值为1,标准差为0.02,并将偏置项设置为0。

(这样的初始化策略有助于在训练初期使得权重处于较小的范围,有助于网络的稳定训练。这是一种常见的初始化方法,尤其在使用卷积和批量归一化的深度学习模型中。)


④ get_norm_layer

def get_norm_layer(norm_type='instance'):
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
    elif norm_type == 'synBN':
        norm_layer = functools.partial(SynBN2d, affine=True)
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm)
    return norm_layer

这段代码主要作用是返回指定类型的归一化层。归一化层在深度学习中用于提高训练的稳定性和收敛速度。

函数接受一个参数 norm_type,根据这个参数的值返回不同类型的归一化层。具体来说:

  • 如果 norm_type 的值为 'batch',则返回批量归一化层,并设置 affine 参数为 True
  • 如果 norm_type 的值为 'instance',则返回实例归一化层,并设置 affine 参数为 False
  • 如果 norm_type 的值为 'synBN',则返回一个自定义的 SynBN2d 归一化层,该归一化层也设置 affine 参数为 True
  • 如果 norm_type 的值不是上述任何一种,则抛出 NotImplementedError 异常,表示未找到指定类型的归一化层。

⑤ define_G

def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, gpu_ids=[], skip=False,
             opt=None):
    # 定义生成器(全局生成器或局部增强生成器)和特征编码器
    netG = None
    use_gpu = len(gpu_ids) > 0
    norm_layer = get_norm_layer(norm_type=norm)

    if use_gpu:
        assert (torch.cuda.is_available())

    if which_model_netG == 'resnet_9blocks':
        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9,
                               gpu_ids=gpu_ids)
    elif which_model_netG == 'resnet_6blocks':
        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6,
                               gpu_ids=gpu_ids)
    elif which_model_netG == 'unet_128':
        netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout,
                             gpu_ids=gpu_ids)
    elif which_model_netG == 'unet_256':
        netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout,
                             gpu_ids=gpu_ids, skip=skip, opt=opt)
    elif which_model_netG == 'unet_512':
        netG = UnetGenerator(input_nc, output_nc, 9, ngf, norm_layer=norm_layer, use_dropout=use_dropout,
                             gpu_ids=gpu_ids, skip=skip, opt=opt)
    elif which_model_netG == 'sid_unet':
        netG = Unet(opt, skip)
    elif which_model_netG == 'sid_unet_shuffle':
        netG = Unet_pixelshuffle(opt, skip)
    elif which_model_netG == 'sid_unet_resize':
        netG = Unet_resize_conv(opt, skip)
    elif which_model_netG == 'DnCNN':
        netG = DnCNN(opt, depth=17, n_channels=64, image_channels=1, use_bnorm=True, kernel_size=3)
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
    if len(gpu_ids) >= 0:
        netG.cuda(device=gpu_ids[0])
        netG = torch.nn.DataParallel(netG, gpu_ids)
    netG.apply(weights_init)
    return netG

这段代码主要作用是定义了一个生成器网络的创建函数 define_G。这个函数根据指定的参数创建不同类型的生成器网络,支持的生成器类型包括 ResNet 生成器、U-Net 生成器等。此外,函数也支持在 GPU 上运行,并对生成器进行权重初始化。

主要参数:

  • input_nc输入通道数。
  • output_nc输出通道数。
  • ngf生成器中特征图的数量。
  • which_model_netG选择的生成器模型的名称。
  • norm归一化层的类型('batch'、'instance'等)。
  • use_dropout是否使用 dropout。
  • gpu_ids指定在哪些 GPU 上运行。
  • skip是否使用 skip connection(跳跃连接)。
  • opt其他选项,可能用于某些生成器类型的参数设置。

函数首先根据输入的 which_model_netG 参数选择相应的生成器模型。然后,根据其他参数,如归一化类型、是否使用 dropout 等,构建生成器。最后,将生成器应用权重初始化,如果指定了 GPU,将其移动到 GPU 上,并进行 DataParallel 包装。


⑥ define_D

def define_D(input_nc, ndf, which_model_netD,
             n_layers_D=3, norm='batch', use_sigmoid=False, gpu_ids=[], patch=False):
    # 定义多层鉴别器
    netD = None
    use_gpu = len(gpu_ids) > 0
    norm_layer = get_norm_layer(norm_type=norm)

    if use_gpu:
        assert (torch.cuda.is_available())
    if which_model_netD == 'basic':
        netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid,
                                   gpu_ids=gpu_ids)
    elif which_model_netD == 'n_layers':
        netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid,
                                   gpu_ids=gpu_ids)
    elif which_model_netD == 'no_norm':
        netD = NoNormDiscriminator(input_nc, ndf, n_layers_D, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
    elif which_model_netD == 'no_norm_4':
        netD = NoNormDiscriminator(input_nc, ndf, n_layers_D, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
    elif which_model_netD == 'no_patchgan':
        netD = FCDiscriminator(input_nc, ndf, n_layers_D, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids, patch=patch)
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' %
                                  which_model_netD)
    if use_gpu:
        netD.cuda(device=gpu_ids[0])
        netD = torch.nn.DataParallel(netD, gpu_ids)
    netD.apply(weights_init)
    return netD

这段代码主要作用是定义了一个判别器网络的创建函数 define_D。这个函数根据指定的参数创建不同类型的判别器网络,支持的判别器类型包括基础的多层判别器、带有 n 层的判别器、无归一化的判别器等。

主要参数:

  • input_nc:输入通道数。
  • ndf:判别器中特征图的数量。
  • which_model_netD:选择的判别器模型的名称。
  • n_layers_D:判别器的层数。
  • norm:归一化层的类型('batch'、'instance'等)。
  • use_sigmoid:是否使用 Sigmoid 函数作为激活函数。
  • gpu_ids:指定在哪些 GPU 上运行。
  • patch:是否使用 patchGAN 结构。

函数首先根据输入的 which_model_netD 参数选择相应的判别器模型。然后,根据其他参数,如归一化类型、是否使用 Sigmoid 等,构建判别器。最后,将判别器应用权重初始化,如果指定了 GPU,将其移动到 GPU 上,并进行 DataParallel 包装。


⑦ print_network

def print_network(net):
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print(net)
    print('Total number of parameters: %d' % num_params)

这段代码主要作用是用于打印神经网络的结构信息和总参数数量。


🎄2.2 Class

①class GANLoss

class GANLoss(nn.Module):
    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label # 真实标签为1
        self.fake_label = target_fake_label # 虚假标签为0
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        if use_lsgan: # 是否使用lsgan的loss损失
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real): # 获取目标标签张量
        target_tensor = None
        if target_is_real: # 表示获取真实标签的目标张量
            create_label = ((self.real_label_var is None) or
                            (self.real_label_var.numel() != input.numel()))
            if create_label:
                real_tensor = self.Tensor(input.size()).fill_(self.real_label)
             # 创建一个形状与输入相同的张量,
             # 并填充为真实标签值,
             # 然后将其封装为不可训练的 PyTorch 变量 Variable,
             # 并赋值给 self.real_label_var。
             # 最终,返回真实标签变量 self.real_label_var。
                self.real_label_var = Variable(real_tensor, requires_grad=False)
            target_tensor = self.real_label_var
        else: # 表示获取生成标签的目标张量
            create_label = ((self.fake_label_var is None) or
                            (self.fake_label_var.numel() != input.numel()))
            if create_label:
                fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
            # 创建一个形状与输入相同的张量,
            # 并填充为生成标签值,
            # 然后将其封装为不可训练的 PyTorch 变量Variable,
            # 并赋值给 self.fake_label_var。
            # 最终,返回生成标签变量 self.fake_label_var 。
                self.fake_label_var = Variable(fake_tensor, requires_grad=False)
            target_tensor = self.fake_label_var
        return target_tensor

    def __call__(self, input, target_is_real):
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)

这段代码主要作用是定义了一个 GAN 损失的类 GANLoss,用于计算生成对抗网络 (GAN) 的生成器和判别器的损失

主要参数:

  • use_lsgan一个布尔值,表示是否使用均方误差损失(True)还是二进制交叉熵损失(False)。
  • target_real_label真实标签的目标值,默认为1.0。
  • target_fake_label生成标签的目标值,默认为0.0。
  • tensor用于创建标签张量的 PyTorch 张量类型,默认为torch.FloatTensor

主要方法和属性包括:

  • loss根据 use_lsgan 初始化的时候选择使用 MSELoss 还是 BCELoss。
  • get_target_tensor用于获取目标标签张量,根据 target_is_real 和类内部的真假标签值。
  • __call__计算 GAN 损失,传入输入张量 input 和一个布尔值 target_is_real,表示是否计算真实标签的损失。

② class DiscLossWGANGP

class DiscLossWGANGP():
    def __init__(self):
        self.LAMBDA = 10

    def name(self):
        return 'DiscLossWGAN-GP'

    def initialize(self, opt, tensor):
        # DiscLossLS.initialize(self, opt, tensor)
        self.LAMBDA = 10

    # def get_g_loss(self, net, realA, fakeB):
    #     # First, G(A) should fake the discriminator
    #     self.D_fake = net.forward(fakeB)
    #     return -self.D_fake.mean()

    def calc_gradient_penalty(self, netD, real_data, fake_data):
        alpha = torch.rand(1, 1)
        alpha = alpha.expand(real_data.size())
        alpha = alpha.cuda()

        interpolates = alpha * real_data + ((1 - alpha) * fake_data)

        interpolates = interpolates.cuda()
        interpolates = Variable(interpolates, requires_grad=True)

        disc_interpolates = netD.forward(interpolates)

        gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                        grad_outputs=torch.ones(disc_interpolates.size()).cuda(),
                                        create_graph=True, retain_graph=True, only_inputs=True)[0]

        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.LAMBDA
        return gradient_penalty

这段代码主要作用是定义了一个用于计算 Wasserstein GAN with Gradient Penalty (WGAN-GP) 损失的类 DiscLossWGANGP

主要的方法和属性包括:

  • __init__构造函数,初始化 LAMBDA 参数,该参数用于控制渐变惩罚的强度,默认为10。
  • name返回损失的名称,这里为 'DiscLossWGAN-GP'。
  • initialize初始化方法,用于设定一些参数。在这里,对 LAMBDA 进行了重新设置为10。
  • calc_gradient_penalty计算渐变惩罚项的方法。该方法接受判别器网络 netD、真实数据 real_data 和生成数据 fake_data 作为输入。首先,通过插值方法创建一个介于真实数据和生成数据之间的样本集合。然后,计算这些插值样本通过判别器的输出,并计算相对于插值样本的梯度。最终,计算渐变惩罚项,即梯度的范数减1的平方的均值乘以 LAMBDA 参数。

③ class ResnetGenerator

class ResnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6,
                 gpu_ids=[], padding_type='reflect'):
        assert (n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        self.gpu_ids = gpu_ids

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2 ** n_downsampling
        for i in range(n_blocks):
            model += [
                ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout)]

        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        else:
            return self.model(input)

这段代码主要作用是定义一个生成器网络类 ResnetGenerator,用于实现带残差块的生成器结构。生成器的主要目标是将输入图像转换为目标域的图像。

主要参数和方法包括:

  • __init__构造函数,定义了生成器的结构。接受一系列参数,包括输入通道数 input_nc,输出通道数 output_nc,生成器的特征数 ngf,规范化层 norm_layer,是否使用 dropout use_dropout,残差块的数量 n_blocks,GPU 设备的列表 gpu_ids 以及填充类型 padding_type

  • forward前向传播方法,将输入张量通过生成器网络进行转换。在这里,根据是否使用 GPU,选择在单个 GPU 上运行或在多个 GPU 上并行运行。

生成器的网络结构包括:

  1. 一个反射填充层 (ReflectionPad2d),将输入图像进行填充。
  2. 一个卷积层 (Conv2d),将填充后的输入映射到特征图,使用 ReLU 激活函数。
  3. 一系列下采样层 (Conv2d,规范化层,ReLU 激活函数),通过多个下采样层减小特征图的大小。
  4. 一系列残差块 (ResnetBlock),通过多个残差块学习图像的细节和结构。
  5. 一系列上采样层 (ConvTranspose2d,规范化层,ReLU 激活函数),通过多个上采样层增加特征图的大小。
  6. 一个反射填充层 (ReflectionPad2d)。
  7. 一个卷积层 (Conv2d),将最终的特征图映射到输出通道。
  8. Tanh 激活函数,将输出限制在 -1 到 1 的范围内。

④  class ResnetBlock

class ResnetBlock(nn.Module):
    def __init__(self, dim, padding_type, norm_layer, use_dropout):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout):
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
                       norm_layer(dim),
                       nn.ReLU(True)]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
                       norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x)
        return out

这段代码主要作用是定义 ResNet 块的类 ResnetBlock,用于构建生成器中的残差连接块。每个 ResNet 块包含两个卷积层,每个卷积层后跟着归一化层和 ReLU 激活函数。


⑤ class UnetGenerator

class UnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False, gpu_ids=[], skip=False, opt=None):
        super(UnetGenerator, self).__init__()
        self.gpu_ids = gpu_ids
        self.opt = opt
        # currently support only input_nc == output_nc
        assert (input_nc == output_nc)

        # construct unet structure
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, norm_layer=norm_layer, innermost=True, opt=opt)
        for i in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block, norm_layer=norm_layer,
                                                 use_dropout=use_dropout, opt=opt)
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, unet_block, norm_layer=norm_layer, opt=opt)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, unet_block, norm_layer=norm_layer, opt=opt)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, unet_block, norm_layer=norm_layer, opt=opt)
        unet_block = UnetSkipConnectionBlock(output_nc, ngf, unet_block, outermost=True, norm_layer=norm_layer, opt=opt)

        if skip == True:
            skipmodule = SkipModule(unet_block, opt)
            self.model = skipmodule
        else:
            self.model = unet_block

    def forward(self, input):
        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        else:
            return self.model(input)

这段代码主要作用是实现 U-Net 生成器的类 UnetGenerator,用于图像到图像的转换任务

U-Net 生成器的结构包括:

  1. 通过堆叠多个 UnetSkipConnectionBlock 模块实现 U-Net 结构。
  2. 对于每个下采样,都使用 UnetSkipConnectionBlock 模块进行堆叠。
  3. 最终的输出通道数为 output_nc

如果设置了 skip 参数为 True,则会使用 SkipModule 对 U-Net 结构进行进一步的封装。


⑥ class SkipModule

class SkipModule(nn.Module):
    def __init__(self, submodule, opt):
        super(SkipModule, self).__init__()
        self.submodule = submodule
        self.opt = opt

    def forward(self, x):
        latent = self.submodule(x)
        return self.opt.skip * x + latent, latent

这段代码主要作用是通过SkipModule 模块添加跳跃连接


⑦ class UnetSkipConnectionBlock

class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False,
                 opt=None):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost

        downconv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if opt.use_norm == 0:
            if outermost:
                upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                            kernel_size=4, stride=2,
                                            padding=1)
                down = [downconv]
                up = [uprelu, upconv, nn.Tanh()]
                model = down + [submodule] + up
            elif innermost:
                upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                            kernel_size=4, stride=2,
                                            padding=1)
                down = [downrelu, downconv]
                up = [uprelu, upconv]
                model = down + up
            else:
                upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                            kernel_size=4, stride=2,
                                            padding=1)
                down = [downrelu, downconv]
                up = [uprelu, upconv]

                if use_dropout:
                    model = down + [submodule] + up + [nn.Dropout(0.5)]
                else:
                    model = down + [submodule] + up
        else:
            if outermost:
                upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                            kernel_size=4, stride=2,
                                            padding=1)
                down = [downconv]
                up = [uprelu, upconv, nn.Tanh()]
                model = down + [submodule] + up
            elif innermost:
                upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                            kernel_size=4, stride=2,
                                            padding=1)
                down = [downrelu, downconv]
                up = [uprelu, upconv, upnorm]
                model = down + up
            else:
                upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                            kernel_size=4, stride=2,
                                            padding=1)
                down = [downrelu, downconv, downnorm]
                up = [uprelu, upconv, upnorm]

                if use_dropout:
                    model = down + [submodule] + up + [nn.Dropout(0.5)]
                else:
                    model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([self.model(x), x], 1)

这段代码主要作用是通过UnetSkipConnectionBlock 模块构建 U-Net 中的下采样和上采样块。它可以包含子模块,并具有跳跃连接。

主要参数:

  • outer_nc:  输出通道数。
  • inner_nc:  内部通道数。
  • submodule:  可选的子模块。
  • outermost:  是否为最外层模块。
  • innermost:  是否为最内层模块。
  • norm_layer:  规范化层的类型。
  • use_dropout:  是否使用 dropout。
  • opt 一些其他选项。

该模块包含以下组件:

  • 下采样(卷积、LeakyReLU、规范化)。
  • 子模块(如果存在)。
  • 上采样(ReLU、转置卷积、Tanh)。

⑧ class NLayerDiscriminator

class NLayerDiscriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[]):
        super(NLayerDiscriminator, self).__init__()
        self.gpu_ids = gpu_ids

        kw = 4
        padw = int(np.ceil((kw - 1) / 2))
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]

        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=1, padding=padw),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]

        if use_sigmoid:
            sequence += [nn.Sigmoid()]

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        # if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor):
        #     return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        # else:
        return self.model(input)

这段代码主要作用是通过NLayerDiscriminator 多层鉴别器模块判别输入图像的真实性。它包含多个卷积层,每一层都包括卷积、规范化和 LeakyReLU 激活函数。

主要参数:

  • input_nc 输入通道数。
  • ndf初始卷积层的输出通道数。
  • n_layers 鉴别器包含的卷积层的数量。
  • norm_layer规范化层的类型。
  • use_sigmoid 是否在输出层使用 Sigmoid 激活函数。
  • gpu_ids:  GPU 的 ID 列表。

该模块的结构包括:

  1. 初始卷积层:输入图像经过一个卷积层,然后应用 LeakyReLU 激活函数。
  2. 多个卷积块:每个卷积块包括卷积层、规范化层和 LeakyReLU 激活函数。这些卷积块用于逐渐降低特征图的空间分辨率。
  3. 最终卷积层:最后一个卷积块后有一个额外的卷积层,用于生成最终的鉴别输出。
  4. Sigmoid 激活函数(可选):如果 use_sigmoidTrue,则在最后添加 Sigmoid 激活函数。

⑨ class NoNormDiscriminator

class NoNormDiscriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=3, use_sigmoid=False, gpu_ids=[]):
        super(NoNormDiscriminator, self).__init__()
        self.gpu_ids = gpu_ids

        kw = 4
        padw = int(np.ceil((kw - 1) / 2))
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]

        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=1, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]

        if use_sigmoid:
            sequence += [nn.Sigmoid()]

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        # if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor):
        #     return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        # else:
        return self.model(input)

NoNormDiscriminator 是一个没有规范化层的鉴别器模块。它与 NLayerDiscriminator 的区别在于去除了规范化层,每个卷积层后面直接接 LeakyReLU 激活函数。


 ⑩ class FCDiscriminator

class FCDiscriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=3, use_sigmoid=False, gpu_ids=[], patch=False):
        super(FCDiscriminator, self).__init__()
        self.gpu_ids = gpu_ids
        self.use_sigmoid = use_sigmoid
        kw = 4
        padw = int(np.ceil((kw - 1) / 2))
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]

        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=1, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
        if patch:
            self.linear = nn.Linear(7 * 7, 1)
        else:
            self.linear = nn.Linear(13 * 13, 1)
        if use_sigmoid:
            self.sigmoid = nn.Sigmoid()

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        batchsize = input.size()[0]
        output = self.model(input)
        output = output.view(batchsize, -1)
        # print(output.size())
        output = self.linear(output)
        if self.use_sigmoid:
            print("sigmoid")
            output = self.sigmoid(output)
        return output

FCDiscriminator 是一个基于卷积神经网络的鉴别器模块,用于图像分类任务。它的主要特点是可以根据 patch 参数选择输出全局分类还是局部分类。


 ⑪ class Unet_resize_conv

class Unet_resize_conv(nn.Module):
    def __init__(self, opt, skip):
        super(Unet_resize_conv, self).__init__()

        self.opt = opt
        self.skip = skip
        p = 1
        # self.conv1_1 = nn.Conv2d(4, 32, 3, padding=p)
        if opt.self_attention:
            self.conv1_1 = nn.Conv2d(4, 32, 3, padding=p)
            # self.conv1_1 = nn.Conv2d(3, 32, 3, padding=p)
            self.downsample_1 = nn.MaxPool2d(2)
            self.downsample_2 = nn.MaxPool2d(2)
            self.downsample_3 = nn.MaxPool2d(2)
            self.downsample_4 = nn.MaxPool2d(2)
        else:
            self.conv1_1 = nn.Conv2d(3, 32, 3, padding=p)
        self.LReLU1_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn1_1 = SynBN2d(32) if self.opt.syn_norm else nn.BatchNorm2d(32)
        self.conv1_2 = nn.Conv2d(32, 32, 3, padding=p)
        self.LReLU1_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn1_2 = SynBN2d(32) if self.opt.syn_norm else nn.BatchNorm2d(32)
        self.max_pool1 = nn.AvgPool2d(2) if self.opt.use_avgpool == 1 else nn.MaxPool2d(2)

        self.conv2_1 = nn.Conv2d(32, 64, 3, padding=p)
        self.LReLU2_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn2_1 = SynBN2d(64) if self.opt.syn_norm else nn.BatchNorm2d(64)
        self.conv2_2 = nn.Conv2d(64, 64, 3, padding=p)
        self.LReLU2_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn2_2 = SynBN2d(64) if self.opt.syn_norm else nn.BatchNorm2d(64)
        self.max_pool2 = nn.AvgPool2d(2) if self.opt.use_avgpool == 1 else nn.MaxPool2d(2)

        self.conv3_1 = nn.Conv2d(64, 128, 3, padding=p)
        self.LReLU3_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn3_1 = SynBN2d(128) if self.opt.syn_norm else nn.BatchNorm2d(128)
        self.conv3_2 = nn.Conv2d(128, 128, 3, padding=p)
        self.LReLU3_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn3_2 = SynBN2d(128) if self.opt.syn_norm else nn.BatchNorm2d(128)
        self.max_pool3 = nn.AvgPool2d(2) if self.opt.use_avgpool == 1 else nn.MaxPool2d(2)

        self.conv4_1 = nn.Conv2d(128, 256, 3, padding=p)
        self.LReLU4_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn4_1 = SynBN2d(256) if self.opt.syn_norm else nn.BatchNorm2d(256)
        self.conv4_2 = nn.Conv2d(256, 256, 3, padding=p)
        self.LReLU4_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn4_2 = SynBN2d(256) if self.opt.syn_norm else nn.BatchNorm2d(256)
        self.max_pool4 = nn.AvgPool2d(2) if self.opt.use_avgpool == 1 else nn.MaxPool2d(2)

        self.conv5_1 = nn.Conv2d(256, 512, 3, padding=p)
        self.LReLU5_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn5_1 = SynBN2d(512) if self.opt.syn_norm else nn.BatchNorm2d(512)
        self.conv5_2 = nn.Conv2d(512, 512, 3, padding=p)
        self.LReLU5_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn5_2 = SynBN2d(512) if self.opt.syn_norm else nn.BatchNorm2d(512)

        # self.deconv5 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.deconv5 = nn.Conv2d(512, 256, 3, padding=p)
        self.conv6_1 = nn.Conv2d(512, 256, 3, padding=p)
        self.LReLU6_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn6_1 = SynBN2d(256) if self.opt.syn_norm else nn.BatchNorm2d(256)
        self.conv6_2 = nn.Conv2d(256, 256, 3, padding=p)
        self.LReLU6_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn6_2 = SynBN2d(256) if self.opt.syn_norm else nn.BatchNorm2d(256)

        # self.deconv6 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.deconv6 = nn.Conv2d(256, 128, 3, padding=p)
        self.conv7_1 = nn.Conv2d(256, 128, 3, padding=p)
        self.LReLU7_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn7_1 = SynBN2d(128) if self.opt.syn_norm else nn.BatchNorm2d(128)
        self.conv7_2 = nn.Conv2d(128, 128, 3, padding=p)
        self.LReLU7_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn7_2 = SynBN2d(128) if self.opt.syn_norm else nn.BatchNorm2d(128)

        # self.deconv7 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.deconv7 = nn.Conv2d(128, 64, 3, padding=p)
        self.conv8_1 = nn.Conv2d(128, 64, 3, padding=p)
        self.LReLU8_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn8_1 = SynBN2d(64) if self.opt.syn_norm else nn.BatchNorm2d(64)
        self.conv8_2 = nn.Conv2d(64, 64, 3, padding=p)
        self.LReLU8_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn8_2 = SynBN2d(64) if self.opt.syn_norm else nn.BatchNorm2d(64)

        # self.deconv8 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.deconv8 = nn.Conv2d(64, 32, 3, padding=p)
        self.conv9_1 = nn.Conv2d(64, 32, 3, padding=p)
        self.LReLU9_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn9_1 = SynBN2d(32) if self.opt.syn_norm else nn.BatchNorm2d(32)
        self.conv9_2 = nn.Conv2d(32, 32, 3, padding=p)
        self.LReLU9_2 = nn.LeakyReLU(0.2, inplace=True)

        self.conv10 = nn.Conv2d(32, 3, 1)
        if self.opt.tanh:
            self.tanh = nn.Tanh()

    def depth_to_space(self, input, block_size):
        block_size_sq = block_size * block_size
        output = input.permute(0, 2, 3, 1)
        (batch_size, d_height, d_width, d_depth) = output.size()
        s_depth = int(d_depth / block_size_sq)
        s_width = int(d_width * block_size)
        s_height = int(d_height * block_size)
        t_1 = output.resize(batch_size, d_height, d_width, block_size_sq, s_depth)
        spl = t_1.split(block_size, 3)
        stack = [t_t.resize(batch_size, d_height, s_width, s_depth) for t_t in spl]
        output = torch.stack(stack, 0).transpose(0, 1).permute(0, 2, 1, 3, 4).resize(batch_size, s_height, s_width,
                                                                                     s_depth)
        output = output.permute(0, 3, 1, 2)
        return output

    def forward(self, input, gray):
        flag = 0
        if input.size()[3] > 2200:
            avg = nn.AvgPool2d(2)
            input = avg(input)
            gray = avg(gray)
            flag = 1
            # pass
        input, pad_left, pad_right, pad_top, pad_bottom = pad_tensor(input)
        gray, pad_left, pad_right, pad_top, pad_bottom = pad_tensor(gray)
        if self.opt.self_attention:
            gray_2 = self.downsample_1(gray)
            gray_3 = self.downsample_2(gray_2)
            gray_4 = self.downsample_3(gray_3)
            gray_5 = self.downsample_4(gray_4)
        if self.opt.use_norm == 1:
            if self.opt.self_attention:
                x = self.bn1_1(self.LReLU1_1(self.conv1_1(torch.cat((input, gray), 1))))
                # x = self.bn1_1(self.LReLU1_1(self.conv1_1(input)))
            else:
                x = self.bn1_1(self.LReLU1_1(self.conv1_1(input)))
            conv1 = self.bn1_2(self.LReLU1_2(self.conv1_2(x)))
            x = self.max_pool1(conv1)

            x = self.bn2_1(self.LReLU2_1(self.conv2_1(x)))
            conv2 = self.bn2_2(self.LReLU2_2(self.conv2_2(x)))
            x = self.max_pool2(conv2)

            x = self.bn3_1(self.LReLU3_1(self.conv3_1(x)))
            conv3 = self.bn3_2(self.LReLU3_2(self.conv3_2(x)))
            x = self.max_pool3(conv3)

            x = self.bn4_1(self.LReLU4_1(self.conv4_1(x)))
            conv4 = self.bn4_2(self.LReLU4_2(self.conv4_2(x)))
            x = self.max_pool4(conv4)

            x = self.bn5_1(self.LReLU5_1(self.conv5_1(x)))
            x = x * gray_5 if self.opt.self_attention else x
            conv5 = self.bn5_2(self.LReLU5_2(self.conv5_2(x)))

            conv5 = F.upsample(conv5, scale_factor=2, mode='bilinear')
            conv4 = conv4 * gray_4 if self.opt.self_attention else conv4
            up6 = torch.cat([self.deconv5(conv5), conv4], 1)
            x = self.bn6_1(self.LReLU6_1(self.conv6_1(up6)))
            conv6 = self.bn6_2(self.LReLU6_2(self.conv6_2(x)))

            conv6 = F.upsample(conv6, scale_factor=2, mode='bilinear')
            conv3 = conv3 * gray_3 if self.opt.self_attention else conv3
            up7 = torch.cat([self.deconv6(conv6), conv3], 1)
            x = self.bn7_1(self.LReLU7_1(self.conv7_1(up7)))
            conv7 = self.bn7_2(self.LReLU7_2(self.conv7_2(x)))

            conv7 = F.upsample(conv7, scale_factor=2, mode='bilinear')
            conv2 = conv2 * gray_2 if self.opt.self_attention else conv2
            up8 = torch.cat([self.deconv7(conv7), conv2], 1)
            x = self.bn8_1(self.LReLU8_1(self.conv8_1(up8)))
            conv8 = self.bn8_2(self.LReLU8_2(self.conv8_2(x)))

            conv8 = F.upsample(conv8, scale_factor=2, mode='bilinear')
            conv1 = conv1 * gray if self.opt.self_attention else conv1
            up9 = torch.cat([self.deconv8(conv8), conv1], 1)
            x = self.bn9_1(self.LReLU9_1(self.conv9_1(up9)))
            conv9 = self.LReLU9_2(self.conv9_2(x))

            latent = self.conv10(conv9)

            if self.opt.times_residual:
                latent = latent * gray

            # output = self.depth_to_space(conv10, 2)
            if self.opt.tanh:
                latent = self.tanh(latent)
            if self.skip:
                if self.opt.linear_add:
                    if self.opt.latent_threshold:
                        latent = F.relu(latent)
                    elif self.opt.latent_norm:
                        latent = (latent - torch.min(latent)) / (torch.max(latent) - torch.min(latent))
                    input = (input - torch.min(input)) / (torch.max(input) - torch.min(input))
                    output = latent + input * self.opt.skip
                    output = output * 2 - 1
                else:
                    if self.opt.latent_threshold:
                        latent = F.relu(latent)
                    elif self.opt.latent_norm:
                        latent = (latent - torch.min(latent)) / (torch.max(latent) - torch.min(latent))
                    output = latent + input * self.opt.skip
            else:
                output = latent

            if self.opt.linear:
                output = output / torch.max(torch.abs(output))


        elif self.opt.use_norm == 0:
            if self.opt.self_attention:
                x = self.LReLU1_1(self.conv1_1(torch.cat((input, gray), 1)))
            else:
                x = self.LReLU1_1(self.conv1_1(input))
            conv1 = self.LReLU1_2(self.conv1_2(x))
            x = self.max_pool1(conv1)

            x = self.LReLU2_1(self.conv2_1(x))
            conv2 = self.LReLU2_2(self.conv2_2(x))
            x = self.max_pool2(conv2)

            x = self.LReLU3_1(self.conv3_1(x))
            conv3 = self.LReLU3_2(self.conv3_2(x))
            x = self.max_pool3(conv3)

            x = self.LReLU4_1(self.conv4_1(x))
            conv4 = self.LReLU4_2(self.conv4_2(x))
            x = self.max_pool4(conv4)

            x = self.LReLU5_1(self.conv5_1(x))
            x = x * gray_5 if self.opt.self_attention else x
            conv5 = self.LReLU5_2(self.conv5_2(x))

            conv5 = F.upsample(conv5, scale_factor=2, mode='bilinear')
            conv4 = conv4 * gray_4 if self.opt.self_attention else conv4
            up6 = torch.cat([self.deconv5(conv5), conv4], 1)
            x = self.LReLU6_1(self.conv6_1(up6))
            conv6 = self.LReLU6_2(self.conv6_2(x))

            conv6 = F.upsample(conv6, scale_factor=2, mode='bilinear')
            conv3 = conv3 * gray_3 if self.opt.self_attention else conv3
            up7 = torch.cat([self.deconv6(conv6), conv3], 1)
            x = self.LReLU7_1(self.conv7_1(up7))
            conv7 = self.LReLU7_2(self.conv7_2(x))

            conv7 = F.upsample(conv7, scale_factor=2, mode='bilinear')
            conv2 = conv2 * gray_2 if self.opt.self_attention else conv2
            up8 = torch.cat([self.deconv7(conv7), conv2], 1)
            x = self.LReLU8_1(self.conv8_1(up8))
            conv8 = self.LReLU8_2(self.conv8_2(x))

            conv8 = F.upsample(conv8, scale_factor=2, mode='bilinear')
            conv1 = conv1 * gray if self.opt.self_attention else conv1
            up9 = torch.cat([self.deconv8(conv8), conv1], 1)
            x = self.LReLU9_1(self.conv9_1(up9))
            conv9 = self.LReLU9_2(self.conv9_2(x))

            latent = self.conv10(conv9)

            if self.opt.times_residual:
                latent = latent * gray

            if self.opt.tanh:
                latent = self.tanh(latent)
            if self.skip:
                if self.opt.linear_add:
                    if self.opt.latent_threshold:
                        latent = F.relu(latent)
                    elif self.opt.latent_norm:
                        latent = (latent - torch.min(latent)) / (torch.max(latent) - torch.min(latent))
                    input = (input - torch.min(input)) / (torch.max(input) - torch.min(input))
                    output = latent + input * self.opt.skip
                    output = output * 2 - 1
                else:
                    if self.opt.latent_threshold:
                        latent = F.relu(latent)
                    elif self.opt.latent_norm:
                        latent = (latent - torch.min(latent)) / (torch.max(latent) - torch.min(latent))
                    output = latent + input * self.opt.skip
            else:
                output = latent

            if self.opt.linear:
                output = output / torch.max(torch.abs(output))

        output = pad_tensor_back(output, pad_left, pad_right, pad_top, pad_bottom)
        latent = pad_tensor_back(latent, pad_left, pad_right, pad_top, pad_bottom)
        gray = pad_tensor_back(gray, pad_left, pad_right, pad_top, pad_bottom)
        if flag == 1:
            output = F.upsample(output, scale_factor=2, mode='bilinear')
            gray = F.upsample(gray, scale_factor=2, mode='bilinear')
        if self.skip:
            return output, latent
        else:
            return output

这段代码主要作用是定义了一个Unet_resize_conv类,用于图像处理任务的深度学习模型,通常用于图像分割等任务。

代码的主要结构和功能:

  1. 初始化函数 (__init__):

  • 接受两个参数optskip
  • 初始化模型的一些参数,包括选择是否使用自注意力机制(opt.self_attention)、是否使用归一化(opt.use_norm)、是否使用平均池化(opt.use_avgpool)等。
  1. 前向传播函数 (forward):

  • 接受两个输入张量inputgray
  • 根据输入的配置参数进行一系列卷积、激活函数、归一化等操作,构建了一个 U-Net 结构的神经网络。
  • 根据模型配置选择是否使用自注意力机制。
  • 最终输出图像结果。
  1. 深度到空间函数 (depth_to_space):

  • 用于将深度张量转换为空间张量。通常在图像处理任务中,将高分辨率图像转换为低分辨率图像时使用。
  1. 一些辅助函数:

  • 例如,对输入进行填充(pad_tensor)和反向填充(pad_tensor_back)等。

剩下非重点的就不再解读了~


🚀三、EnlightenGAN源码运行

在本文最上面已经放了项目地址,作者给出了源码,数据集等,这些都可以在里面下载到,ReadMe中也给出了详细的运行方法,对小白来说还是比较友好的。

我跑的过程没记录,哈哈~

这块网上有很多博主讲解的比较详细,大家可以参考一下:

EnlightenGAN训练复现记录_enlightengan代码复现-CSDN博客

代码调试记录EnlightenGAN 一_代码调试记录怎么写-CSDN博客

EnlightenGAN的运行环境搭建和训练自己的数据 - 知乎 (zhihu.com)

EnlightenGAN的代码运行过程问题记录_enlightengan运行不了-CSDN博客

踩坑记录:

EnlightenGAN: Deep Light Enhancement without Paired Supervision源码实现_./final_dataset/traina is not a valid directory-CSDN博客

EnlightenGAN代码复现错误总结-CSDN博客

实现效果:

可以看到,增强效果还是不错滴~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1372105.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何降低成本,制作个性化电子产品宣传册呢

​随着科技的飞速发展,电子产品已经深入到我们生活的每一个角落。然而,如何让你的产品在众多竞争者中脱颖而出呢?制作一份个性化的宣传册,不仅可以吸引潜在客户,还能有效降低成本,提升销售效果。 一、明确目…

基于JAVA+SpringBoot的高校学术报告系统

✌全网粉丝20W,csdn特邀作者、博客专家、CSDN新星计划导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取项目下载方式🍅 一、项目背景介绍: 智慧高校学术报告系统…

职场日常英语口语,成人英语培训学校,柯桥学英语推荐哪里

“玩手机”用英语怎么说?你的第一反应是不是:play the phone? 在英语中,play这个动词通常表示“玩耍、娱乐、操纵”等意思,而手机是一种工具,不是玩耍的对象。 换句话说,我们“玩手机”&#xf…

CUDA编程:执行模型

SM 在SM中,共享内存和寄存器是非常重要的资源。共享内存被分配在SM上的常驻线程 块中,寄存器在线程中被分配。线程块中的线程通过这些资源可以进行相互的合作和通 信。 WARP CUDA采用单指令多线程(SIMT)架构来管理和执行线程&am…

机器学习中的隐马尔可夫模型及Python实现示例

隐马尔可夫模型(HMM)是一种统计模型,用于描述观测序列和隐藏状态序列之间的概率关系。它通常用于生成观测值的底层系统或过程未知或隐藏的情况,因此它被称为“隐马尔可夫模型”。 它用于根据生成数据的潜在隐藏过程来预测未来的观…

跟我学java|Stream流式编程——并行流

什么是并行流 并行流是 Java 8 Stream API 中的一个特性。它可以将一个流的操作在多个线程上并行执行,以提高处理大量数据时的性能。 在传统的顺序流中,所有的操作都是在单个线程上按照顺序执行的。而并行流则会将流的元素分成多个小块,并在多…

【Java集合篇】 ConcurrentHashMap在哪些地方做了并发控制

ConcurrentHashMap在哪些地方做了并发控制 ✅典型解析✅初始化桶阶段🟢桶满了会自动扩容吗🟠自动扩容的时间频率是多少 ✅put元素阶段✅扩容阶段🟠 拓展知识仓🟢ConcurrentSkipListMap和ConcurrentHashMap有什么区别☑️简单介绍一…

2024年第九届计算机与通信系统国际会议(ICCCS2024) ,邀您相约西安!

会议官网: ICCCS2024 | Xian China 时间: 2024年4月19-22日 地点: 中国西安 会议简介: 近年来,信息通信在不断发展,为计算机网络的进步与发展提供了先进可靠的技术支持。随着计算机网络与通信技术的深入发展,计算机通信技术、数…

报错解决:RuntimeError: Error building extension ‘bias_act_plugin‘

系统: Ubuntu22.04, nvcc -V:11.8 , torch:2.0.0cu118 一:BUG内容 运行stylegan项目的train.py时遇到报错👇 Setting up PyTorch plugin "bias_act_plugin"... Failed! /home/m…

docker+jmeter实现windows作为主控机,linux作为负载机的分布式压测环境搭建

dockerjmeter实现windows作为主控机,linux作为负载机的分布式压测环境搭建 1、搭建环境说明2、windows主控机安装Jmeter3、linux负载机安装Jmeter3.1、安装docker环境3.2、使用docker安装jmeter 4、windows主控机分发测试任务 1、搭建环境说明 准备一台windows主机…

京东(天猫淘宝)数据分析工具-鲸参谋系统全功能解析——行业大盘、红蓝海市场、品牌分析、店铺分析、商品分析、竞品监控(区分自营和POP)

作为第三方电商数据平台,鲸参谋电商大数据系统能够为品牌方和商家提供包括行业趋势、热门品牌、店铺分析、单品分析在内的多个层面数据分析,帮助商家做出更加准确的经营决策,提升经营效率,实现精准营销。 下面,我们针…

YOLOv8优化策略:轻量化改进 | 超越RepVGG!浙大阿里提出OREPA:在线卷积重参数化

🚀🚀🚀本文改进:在线卷积重参数化巧妙的和YOLOV8结合,并实现轻量化 🚀🚀🚀YOLOv8改进专栏:http://t.csdnimg.cn/hGhVK 学姐带你学习YOLOv8,从入门到创新,轻轻松松搞定科研; 1.OREPA介绍 论文:https://arxiv.org/pdf/2204.00826.pdf 摘要:结构重新参数化在…

软件测试|Python Selenium 库安装使用指南

简介 Selenium 是一个用于自动化浏览器操作的强大工具,它可以模拟用户在浏览器中的行为,例如点击、填写表单、导航等。在本指南中,我们将详细介绍如何安装和使用 Python 的 Selenium 库。 安装 Selenium 库 使用以下命令可以通过 pip 安装…

内网渗透之CobaltStrike(CS)

目录 一、Cobalt Strike简介 二、Cobalt Strike基本用法 1、启动服务端 2、客户端连接 3、设置监听器(Listeners) 4、脚本管理器(Script Manager) 5、攻击(最常用的是生成后门) 6、CS上线 7、Beaco…

图神经网络 7大高效创新思路分享,附17篇最新顶会论文和代码

2024年了,图神经网络方向还好发论文吗?答案当然是能。 图神经网络在处理非欧空间数据和复杂特征方面具有明显的优势,且已成为了深度学习领域的热点,在学术界和工业界都有着广泛的研究和应用。不仅如此,图神经网络与CV…

如何在集简云中调用GPTs(Assistant) API

我们在OpenAI中创建了GPTs(Assistant)后,希望放到其它软件中使用,比如 抖音私信,抖音评论,微信公众号,钉钉,飞书,企业微信...... 要如何实现这样的功能呢? 您可以使用集简云的 “数…

科研上新 | 第3期:大模型推进科研边界;大模型的道德价值对齐;优化动态稀疏深度学习模型;十亿规模向量搜索的高效更新

者按:欢迎阅读“科研上新”栏目!“科研上新”汇聚了微软亚洲研究院最新的创新成果与科研动态。在这里,你可以快速浏览研究院的亮点资讯,保持对前沿领域的敏锐嗅觉,同时也能找到先进实用的开源工具。 本期内容速览 01…

JavaScript日期和时间处理手册

🧑‍🎓 个人主页:《爱蹦跶的大A阿》 🔥当前正在更新专栏:《VUE》 、《JavaScript保姆级教程》、《krpano》 ​ ​ ✨ 前言 日期和时间在应用开发中是非常常用的功能。本文将全面介绍JavaScript中处理日期和时间的方…

vscode设置python脚本运行参数

1 添加配置文件 点击到你要配置的python文件,然后右上角点击 运行 ,再点击 添加配置 再点击 “Pyhton文件” 选项(其实就是在选择 当前的python文件 进行配置) 接着就生成了配置文件 lanunch.json 2 参数配置 再上面代码的基础上…

1月10号代码随想录左叶子之和

404.左叶子之和 给定二叉树的根节点 root ,返回所有左叶子之和。 示例 1: 输入: root [3,9,20,null,null,15,7] 输出: 24 解释: 在这个二叉树中,有两个左叶子,分别是 9 和 15,所以返回 24示例 2: 输入: root [1]…