gan, pixel2pixel, cyclegan, srgan图像超分辨率

news2025/3/11 0:41:29

文章目录

    • 1.gan
    • 2.DCgan
    • 3.cgan
    • 4.pixel2pixel(Image-to-Image Translation with Conditional Adversarial Networks)
    • 5.CycleGAN
    • 6.Deep learning for in vivo near-infrared imaging
    • 11..Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial (srgan, srresnet) (2017)
      • 11.1. 一篇经典的超分论文。
      • 11.2. 网络结构
      • 11.3.关于训练
    • 12.ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks
    • 13.GAN
    • 14. dasr oppo

1.gan

通俗理解生成对抗网络GAN
对抗生成网络GAN系列——GAN原理及手写数字生成小案例
就是随机生成噪声,假如128维度,Gnet 输出 28x28的图像
Dnet输出label,1或者0 , 二分类网络。

判别器就是 输入真实图 分类为1
输入生成图 分类为0

生成器就是 希望输入生成图到判别器,分类为1.

注意这里的网络模型不能保证生成的数字到底是几,给定一个随机噪声,生成的数字可能是0-9
在这里插入图片描述

或者
在这里插入图片描述

2.DCgan

这里主要是更改了一些生成器和判别器的结构,比如用卷积替换全连接,假如batchnorm等,提升生成的效果。
后续可以使用UNet等进一步提升。
在这里插入图片描述

https://zhuanlan.zhihu.com/p/35983991 生成对抗网络系列(3)——cGAN及图像条件 这一系列博客写的也很好。

3.cgan

Conditional Generative Adversarial Nets,即条件生成对抗网络。
就是通过添加限制条件,来控制GAN生成数据的特征(类别),比如之前我们的随机噪声可以生成数字0-9但是我们并不能控制生成的是0还是1,还是2.
在这里插入图片描述

这里要把类别标签一起输入到网络。
另外损失函数没有采用二分类交叉熵,而是使用mse.
在这里插入图片描述

https://zhuanlan.zhihu.com/p/302720602

这里分析一下其原理:
gan之所以有效,只凭了三个损失函数:
fake(gen) 输入判别器 得到0
real 输入判别器 得到1
那么判别器学到了 什么是0,什么是1:即 生成的图像是 0,real图是1
噪声z 输入生成器,希望判别器得到 1, 即希望生成器生成的图 输入判别器时 是 1,即希望生成器生成的图,和real更接近。

CGAN 加入了类别label, label的形式可以是0-N的数字,也可以是one-hot编码, 也可是 和 噪声z同维度的一个tensor。
损失函数仍然是三个。
希望 噪声z+ 类别label 输入 生成器后 得到该label对应的图像。

4.pixel2pixel(Image-to-Image Translation with Conditional Adversarial Networks)

是cgan的一种,只不过输入的不是噪声,输入的是一些hint提示,理所应当比cgan效果好才对。
https://www.jianshu.com/p/066e2c274887

看代码很清晰:https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
pixel2pixel是一种图像转换,不是从噪声直接生成的。
在这里插入图片描述

特点就是

  1. 不是从噪声直接生成图像,而是从某一类图像转换为零一类图像。假如从噪声图转换为无噪声图是否可以,也是可以的呀。

  2. 判别器的2个损失函数和之前的是类似的,就是判别真假。只是pixel2pixel中不是得到一个数字作为lable而是一个矩阵求平均,其实差异也不大。

  3. 那么生成器呢,除了原来的损失,再加上一个L1损失。这是理所应当的。作者实验假如不利用gan,只有L1来损失,这其实就是一个简单的图像转换网络,发现不清晰,缺少高频,再加上cgan 图像更生动清晰。 想想srgan就是gan在超分中的应用。
    在这里插入图片描述

https://aistudio.baidu.com/projectdetail/1119048

5.CycleGAN

https://cloud.tencent.com/developer/article/1064970
pix2pix是用GAN解决image-to-image translation的开山之作,他的主要思路就是用成对的图像(paired image)去训练生成器和判别器,最后向训练好的生成器输入图片就可以得到目标图片(aim image)
在这里插入图片描述

看下图
在这里插入图片描述

相比于pixel2pixel具体是如何改进的呢?
第一个理解:
上图的左上部分如下就是1个 gan, gan生成目标B, 但是没有label条件约束,因此pixel2pixel中的L1损失就没法使用了,那么如何保持生成的图像目标图像的一致性呢? 加上右边的网络和 cycle consistency lose.

第二个理解:
首先是重建网络重建A,然后重建网络中间的输出建立一个gan损失,是生成的图像符合目标B的风格.
在这里插入图片描述

在这里插入图片描述

参考:https://zhuanlan.zhihu.com/p/38752336

6.Deep learning for in vivo near-infrared imaging

体内 红外一区 和 红外二区图像转换。
在这里插入图片描述

11…Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial (srgan, srresnet) (2017)

11.1. 一篇经典的超分论文。

作者提出两个网络:SRResNet 和 SRGAN。 SRResNet 的图像 psnr 和 ssim都比较高,但是细节不够生动。
SRGAN的psnr,ssim没有那么高,但是细节会更丰富。
关于论文和code 可以搜到很多,毕竟是经典方法。

在这里插入图片描述

11.2. 网络结构

srresnet 网络结构也是 srgan的生成器部分。
srgan的生成器是 srresnet, 判别器部分是vgg 类型的网络。

在这里插入图片描述

网络结构相对简单清晰


import torch
import torch.nn as nn
import math

class _Residual_Block(nn.Module):
    def __init__(self):
        super(_Residual_Block, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
        self.in1 = nn.InstanceNorm2d(64, affine=True)
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
        self.in2 = nn.InstanceNorm2d(64, affine=True)

    def forward(self, x):
        identity_data = x
        output = self.relu(self.in1(self.conv1(x)))
        output = self.in2(self.conv2(output))
        output = torch.add(output,identity_data)
        return output 

class _NetG(nn.Module):
    def __init__(self):
        super(_NetG, self).__init__()

        self.conv_input = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=9, stride=1, padding=4, bias=False)
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        
        self.residual = self.make_layer(_Residual_Block, 16)

        self.conv_mid = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_mid = nn.InstanceNorm2d(64, affine=True)

        self.upscale4x = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.PixelShuffle(2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.PixelShuffle(2),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.conv_output = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=9, stride=1, padding=4, bias=False)
        
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()

    def make_layer(self, block, num_of_layer):
        layers = []
        for _ in range(num_of_layer):
            layers.append(block())
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.relu(self.conv_input(x))
        residual = out
        out = self.residual(out)
        out = self.bn_mid(self.conv_mid(out))
        out = torch.add(out,residual)
        out = self.upscale4x(out)
        out = self.conv_output(out)
        return out

class _NetD(nn.Module):
    def __init__(self):
        super(_NetD, self).__init__()

        self.features = nn.Sequential(
        
            # input is (3) x 96 x 96
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            # state size. (64) x 96 x 96
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False),            
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),

            # state size. (64) x 96 x 96
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False),            
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            
            # state size. (64) x 48 x 48
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            # state size. (128) x 48 x 48
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            # state size. (256) x 24 x 24
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            # state size. (256) x 12 x 12
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1, bias=False),            
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),

            # state size. (512) x 12 x 12
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=4, stride=2, padding=1, bias=False),            
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.LeakyReLU = nn.LeakyReLU(0.2, inplace=True)
        self.fc1 = nn.Linear(512 * 6 * 6, 1024)
        self.fc2 = nn.Linear(1024, 1)
        self.sigmoid = nn.Sigmoid()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0.0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.normal_(1.0, 0.02)
                m.bias.data.fill_(0)

    def forward(self, input):

        out = self.features(input)

        # state size. (512) x 6 x 6
        out = out.view(out.size(0), -1)

        # state size. (512 x 6 x 6)
        out = self.fc1(out)

        # state size. (1024)
        out = self.LeakyReLU(out)

        out = self.fc2(out)
        out = self.sigmoid(out)
        return out.view(-1, 1).squeeze(1)

11.3.关于训练

  1. srresnet 的损失函数就是

在这里插入图片描述

训练的代码也比较常规。

  1. srgan的损失函数是有三部分组成
    除了上面的pixel-wise MSE loss, 还有 VGG-f loss(feature map的MSE loss),VGG-f将图片输入到直接训练好的模型VGG的特定层的feature map, 这个VGG的weight是不训练的,相当于一个特征提取器,区别于判别器的vgg网络:

在这里插入图片描述

对抗损失 训练判别器的时候有一个,训练生成器的时候有2个。

三个损失函数
第一步训练判别器

        # Transfer in-memory data to CUDA devices to speed up training
        gt = batch_data["gt"].to(device=srgan_config.device, non_blocking=True)
        lr = batch_data["lr"].to(device=srgan_config.device, non_blocking=True)

        # Set the real sample label to 1, and the false sample label to 0
        batch_size, _, height, width = gt.shape
        real_label = torch.full([batch_size, 1], 1.0, dtype=gt.dtype, device=srgan_config.device)
        fake_label = torch.full([batch_size, 1], 0.0, dtype=gt.dtype, device=srgan_config.device)

        # Start training the discriminator model
        # During discriminator model training, enable discriminator model backpropagation
        for d_parameters in d_model.parameters():
            d_parameters.requires_grad = True

        # Initialize the discriminator model gradients
        d_model.zero_grad(set_to_none=True)

        # Calculate the classification score of the discriminator model for real samples(计算 gt 的分数)
        gt_output = d_model(gt)
        d_loss_gt = adversarial_criterion(gt_output, real_label)
        # Call the gradient scaling function in the mixed precision API to
        # back-propagate the gradient information of the fake samples
        d_loss_gt.backward(retain_graph=True)

        # Calculate the classification score of the discriminator model for fake samples(计算 生成的sr 的分数)
        # Use the generator model to generate fake samples
        sr = g_model(lr)
        sr_output = d_model(sr.detach().clone())
        d_loss_sr = adversarial_criterion(sr_output, fake_label)
        # Call the gradient scaling function in the mixed precision API to
        # back-propagate the gradient information of the fake samples
        d_loss_sr.backward()

        # Calculate the total discriminator loss value
        d_loss = d_loss_gt + d_loss_sr

        # Improve the discriminator model's ability to classify real and fake samples
        d_optimizer.step()
        # Finish training the discriminator model

然后固定判别器

        # Start training the generator model
        # During generator training, turn off discriminator backpropagation
        for d_parameters in d_model.parameters():
            d_parameters.requires_grad = False

训练生成器,利用三个损失函数

# Initialize generator model gradients
        g_model.zero_grad(set_to_none=True)

        # Calculate the perceptual loss of the generator, mainly including pixel loss, feature loss and adversarial loss
        pixel_loss = srgan_config.pixel_weight * pixel_criterion(sr, gt)
        content_loss = srgan_config.content_weight * content_criterion(sr, gt)
        adversarial_loss = srgan_config.adversarial_weight * adversarial_criterion(d_model(sr), real_label)
        # Calculate the generator total loss value
        g_loss = pixel_loss + content_loss + adversarial_loss
        # Call the gradient scaling function in the mixed precision API to
        # back-propagate the gradient information of the fake samples
        g_loss.backward()

        # Encourage the generator to generate higher quality fake samples, making it easier to fool the discriminator
        g_optimizer.step()
        # Finish training the generator model

当然也可以先训练生成器,再训练判别器。反正两个也是交替训练的。
关于gan最常见的训练方式 查看 code 和
loss解释

12.ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks

ESRGAN是对SRGAN的改进:

  1. 去掉BN,网络的基本单元从基本的残差单元变为Residual-in-Residual Dense Block (RRDB)
  2. GAN网络改进为Relativistic average GAN (RaGAN);
  3. 改进感知域损失函数,使用激活前的VGG特征,这个改进会提供更尖锐的边缘和更符合视觉的结果。
  4. 首先训练常规模型,然后再训练GAN模型。 则通过插值生成器部分可以得到不同程度的超分模型,调节平滑度和细节丰富度

1 很好的解释

13.GAN

下面两篇升级版都是对 图像退化的改进。

Designing a Practical Degradation Model for Deep Blind Image Super-Resolution (ICCV, 2021, BSRGAN)
(https://github.com/vvictoryuki/BSRGAN_implementation) 对于实际图像效果很好

Real-ESRGAN: TrainingReal-World Blind Super-Resolution with Pure Synthetic Data
Real-ESRGAN: (https://zhuanlan.zhihu.com/p/401387995)
(https://zhuanlan.zhihu.com/p/542750836)

振铃线性:https://blog.csdn.net/fengye2two/article/details/79895542

14. dasr oppo

https://blog.csdn.net/tywwwww/article/details/128036503

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1495588.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【AI视野·今日NLP 自然语言处理论文速览 第八十二期】Tue, 5 Mar 2024

AI视野今日CS.NLP 自然语言处理论文速览 Tue, 5 Mar 2024 (showing first 100 of 175 entries) Totally 100 papers 👉上期速览✈更多精彩请移步主页 Daily Computation and Language Papers Key-Point-Driven Data Synthesis with its Enhancement on Mathematica…

论文研读_多目标部署优化:无人机在能源高效无线覆盖中的应用(ImMOGWO)精简版

此篇文章为Multi-objective Deployment Optimization of UAVs for Energy-Efficient Wireless Coverage的论文学习笔记,只供学习使用,不作商业用途,侵权删除。并且本人学术功底有限如果有思路不正确的地方欢迎批评指正! 创新点 RD算法 混合…

少儿编程机器人技术开发公司的创新之路

行业背景,国家政策利好 随着科技的不断发展,少儿编程机器人技术作为一种新兴的教育方式逐渐受到人们的关注。这项技术将编程与机器人技术相结合,通过互动性强、趣味性高的方式,帮助儿童学习编程知识,培养逻辑思维和创…

使用J-Link Commander通过J-LINK以命令的形式来访问ARM通用MCU

通常我们的操作是写好程序然后将程序下载到芯片里面,然后运行程序来进行相应的操作,其实还可以使用 J − L i n k C o m m a n d e r J-Link\quad Commander J−LinkCommander通过 J − L I N K J-LINK J−LINK以命令的形式来简单访问ARM通用MCU&#xf…

IP劫持的危害及应对策略

随着互联网的发展,网络安全问题日益凸显,其中IP劫持作为一种常见的网络攻击手段,对个人和企业的信息安全造成了严重的威胁。IP数据云将分析IP劫持的危害,并提出相应的应对策略。 IP地址查询:IP数据云 - 免费IP地址查询…

目标检测评估指标

目录 一、检测精度1、TP、FP、TN、FN概念正样本和负样本TP(True Positive---正确的正向预测)FP(False Positive---错误的正向预测)FN(False Negative---错误的负向预测)TN(True Negative---正确的负向预测) 2、Precision(准确率)和Recall(召回率)3、P-R curve &…

【开源】SpringBoot框架开发网上药店系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 数据中心模块2.2 药品类型模块2.3 药品档案模块2.4 药品订单模块2.5 药品收藏模块2.6 药品资讯模块 三、系统设计3.1 用例设计3.2 数据库设计3.2.1 角色表3.2.2 药品表3.2.3 药品订单表3.2.4 药品收藏表3.2.5 药品留言表…

归并排序总结

1.归并排序 归并排序的步骤如下: ①枚举中点,将区间分为左右两段; ②对左右两段区间分别排序; 这个过程以递归的方式进行。 ③合并两段区间。 是一个模拟的过程。用两个指针分别指向左右区间,判断当前哪个数小&…

【Linux C | 网络编程】广播概念、UDP实现广播的C语言例子

😁博客主页😁:🚀https://blog.csdn.net/wkd_007🚀 🤑博客内容🤑:🍭嵌入式开发、Linux、C语言、C、数据结构、音视频🍭 🤣本文内容🤣&a…

酷开科技服务升级,酷开系统给消费者更好的使用体验!

看电视的时候你是不是也会有选择困难症?不知道要看哪个?不知道如何操作?体验不够顺畅?现在,有了酷开系统9.2,这些通通不再是问题!酷开科技,一直致力于服务升级,给消费者更…

springBoot整合Redis(三、整合Spring Cache)

缓存的框架太多了,各有各的优势,比如Redis、Memcached、Guava、Caffeine等等。 如果我们的程序想要使用缓存,就要与这些框架耦合。聪明的架构师已经在利用接口来降低耦合了,利用面向对象的抽象和多态的特性,做到业务代…

uniapp+vue3+vites使用lime-echart问题记录

问题记录 1.vue3使用echarts,H5和微信小程序兼容问题 1.vue3使用echarts,H5和微信小程序兼容问题 问题描述,正常使用echarts,H5正常,小程序报错 报错信息如下 解决方案: 注意要点一:vue3需要使用esm文件 地址&#x…

Elasticsearch:从 ES|QL 到 Python 数据帧

在我之前的文章 “Elasticsearch:ES|QL 查询展示”,我展示了如何在 Kibana 中使用 ES|QL 对索引来进行查询及统计。在很多的情况下,我们需要在客户端中来对数据进行查询,那么我们该怎么办呢?我们需要使用到 Elasticsea…

怎么将电脑excel文档内的数据转换为图片形式

你平时在办公室会遇到格式转换的问题吗?比如PDF转Word,WPS转PDF,PDF转TXT,图片转PDF等。边肖最近在工作过程中遇到了类似的问题。为了更方便的查看表格,需要将Excel表格转换成图片格式。遇到这样的问题,很多…

CPP编程-CPP11中的内存管理策略模型与名称空间管理探幽(时隔一年,再谈C++抽象内存模型)

CPP编程-CPP11中的内存管理策略模型与名称空间管理探幽 CPP的四大内存分区模型 在 C 中,**内存分区是一种模型,用于描述程序运行时内存的逻辑组织方式,但在底层操作系统中,并不存在严格意义上的内存分区。**操作系统通常将内存分…

计算机大数据毕业设计-基于Flask的旅游推荐可视化系统的设计与实现

基于Flask的旅游推荐可视化系统的设计与实现 编程语言:Python3.10 涉及技术:FlaskMySQL8.0Echarts 开发工具:PyCharm 摘要:以Pycharm为旅游推荐系统开发工具,采用B/S结构,使用Python语言开发旅游景点推…

【python】成功解决ModuleNotFoundError: No module named ‘tensorboardX‘

【python】成功解决ModuleNotFoundError: No module named ‘tensorboardX’ 🌈 个人主页:高斯小哥 🔥 高质量专栏:Matplotlib之旅:零基础精通数据可视化、Python基础【高质量合集】、PyTorch零基础入门教程&#x1f…

二维码门楼牌管理系统应用场景:助力环保部门提升监管效率

文章目录 前言一、二维码门楼牌管理系统的环保应用场景二、二维码门楼牌管理系统如何助力环保监管三、二维码门楼牌管理系统与环保部门的联动效应 前言 随着城市化进程的加速,环保问题日益受到人们的关注。二维码门楼牌管理系统的出现,为环保部门提供了…

【论文阅读】DeepLab:语义图像分割与深度卷积网络,自然卷积,和完全连接的crf

【论文阅读】DeepLab:语义图像分割与深度卷积网络,自然卷积,和完全连接的crf 文章目录 【论文阅读】DeepLab:语义图像分割与深度卷积网络,自然卷积,和完全连接的crf一、介绍二、联系工作三、方法3.1 整体结构3.2 使用空间金字塔池…

机器人期刊:Science Robotics and IEEE Transactions

文章目录 1. Science Robotics (出版商 AAAS)2. IEEE Transactions on RoboticsReference1. Science Robotics (出版商 AAAS) https://www.science.org/journal/scirobotics 2. IEEE Transactions on Robotics