昇思25天学习打卡营第22天|GAN图像生成

news2024/11/26 23:32:16

今天是参加昇思25天学习打卡营的第22天,今天打卡的课程是“GAN图像生成”,这里做一个简单的分享。

1.简介

今天来学习“GAN图像生成”,这是一个基础的生成式模型。

生成式对抗网络(Generative Adversarial Networks,GAN)是一种生成式机器学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。

最初,GAN由Ian J. Goodfellow于2014年发明,并在论文Generative Adversarial Nets中首次进行了描述,其主要由两个不同的模型共同组成——生成器(Generative Model)和判别器(Discriminative Model):

  • 生成器的任务是生成看起来像训练图像的“假”图像;
  • 判别器需要判断从生成器输出的图像是真实的训练图像还是虚假的图像。

GAN通过设计生成模型和判别模型这两个模块,使其互相博弈学习产生了相当好的输出。

2.模型架构

  • 模型原理

GAN模型的核心在于提出了通过对抗过程来估计生成模型这一全新框架。在这个框架中,将会同时训练两个模型——捕捉数据分布的生成模型 𝐺 和估计样本是否来自训练数据的判别模型 𝐷。

在训练过程中,生成器会不断尝试通过生成更好的假图像来骗过判别器,而判别器在这过程中也会逐步提升判别能力。这种博弈的平衡点是,当生成器生成的假图像和训练数据图像的分布完全一致时,判别器拥有50%的真假判断置信度。

用 𝑥 代表图像数据,用 𝐷(𝑥)表示判别器网络给出图像判定为真实图像的概率。在判别过程中,𝐷(𝑥) 需要处理作为二进制文件的大小为 1×28×28的图像数据。当 𝑥 来自训练数据时,𝐷(𝑥) 数值应该趋近于 1 ;而当 𝑥 来自生成器时,𝐷(𝑥)𝐷数值应该趋近于 00 。因此 𝐷(𝑥) 也可以被认为是传统的二分类器。

用 𝑧 代表标准正态分布中提取出的隐码(隐向量),用 𝐺(𝑧):表示将隐码(隐向量) 𝑧 映射到数据空间的生成器函数。函数 𝐺(𝑧) 的目标是将服从高斯分布的随机噪声 𝑧 通过生成网络变换为近似于真实分布 𝑝𝑑𝑎𝑡𝑎(𝑥) 的数据分布,我们希望找到 θ 使得 𝑝𝐺(𝑥;𝜃) 和𝑝𝑑𝑎𝑡𝑎(𝑥) 尽可能的接近,其中𝜃 代表网络参数。

𝐷(𝐺(𝑧))表示生成器 𝐺𝐺生成的假图像被判定为真实图像的概率,如Generative Adversarial Nets中所述,𝐷 和 𝐺 在进行一场博弈,𝐷 想要最大程度的正确分类真图像与假图像,也就是参数 log⁡𝐷(𝑥);而 𝐺 试图欺骗 𝐷 来最小化假图像被识别到的概率,也就是参数log⁡(1−𝐷(𝐺(𝑧)))。因此GAN的损失函数为:
在这里插入图片描述
从理论上讲,此博弈游戏的平衡点是𝑝𝐺(𝑥;𝜃)=𝑝𝑑𝑎𝑡𝑎(𝑥),此时判别器会随机猜测输入是真图像还是假图像。下面我们简要说明生成器和判别器的博弈过程:

  1. 在训练刚开始的时候,生成器和判别器的质量都比较差,生成器会随机生成一个数据分布。
  2. 判别器通过求取梯度和损失函数对网络进行优化,将靠近真实数据分布的数据判定为1,将靠近生成器生成出来数据分布的数据判定为0。
  3. 生成器通过优化,生成出更加贴近真实数据分布的数据。
  4. 生成器所生成的数据和真实数据达到相同的分布,此时判别器的输出为1/2。
  • 核心代码

生成器代码:

from mindspore import nn
import mindspore.ops as ops

img_size = 28  # 训练图像长(宽)

class Generator(nn.Cell):
    def __init__(self, latent_size, auto_prefix=True):
        super(Generator, self).__init__(auto_prefix=auto_prefix)
        self.model = nn.SequentialCell()
        # [N, 100] -> [N, 128]
        # 输入一个100维的0~1之间的高斯分布,然后通过第一层线性变换将其映射到256维
        self.model.append(nn.Dense(latent_size, 128))
        self.model.append(nn.ReLU())
        # [N, 128] -> [N, 256]
        self.model.append(nn.Dense(128, 256))
        self.model.append(nn.BatchNorm1d(256))
        self.model.append(nn.ReLU())
        # [N, 256] -> [N, 512]
        self.model.append(nn.Dense(256, 512))
        self.model.append(nn.BatchNorm1d(512))
        self.model.append(nn.ReLU())
        # [N, 512] -> [N, 1024]
        self.model.append(nn.Dense(512, 1024))
        self.model.append(nn.BatchNorm1d(1024))
        self.model.append(nn.ReLU())
        # [N, 1024] -> [N, 784]
        # 经过线性变换将其变成784维
        self.model.append(nn.Dense(1024, img_size * img_size))
        # 经过Tanh激活函数是希望生成的假的图片数据分布能够在-1~1之间
        self.model.append(nn.Tanh())

    def construct(self, x):
        img = self.model(x)
        return ops.reshape(img, (-1, 1, 28, 28))

net_g = Generator(latent_size)
net_g.update_parameters_name('generator')

判别器代码:

# 判别器
class Discriminator(nn.Cell):
    def __init__(self, auto_prefix=True):
        super().__init__(auto_prefix=auto_prefix)
        self.model = nn.SequentialCell()
        # [N, 784] -> [N, 512]
        self.model.append(nn.Dense(img_size * img_size, 512))  # 输入特征数为784,输出为512
        self.model.append(nn.LeakyReLU())  # 默认斜率为0.2的非线性映射激活函数
        # [N, 512] -> [N, 256]
        self.model.append(nn.Dense(512, 256))  # 进行一个线性映射
        self.model.append(nn.LeakyReLU())
        # [N, 256] -> [N, 1]
        self.model.append(nn.Dense(256, 1))
        self.model.append(nn.Sigmoid())  # 二分类激活函数,将实数映射到[0,1]

    def construct(self, x):
        x_flat = ops.reshape(x, (-1, img_size * img_size))
        return self.model(x_flat)

net_d = Discriminator()
net_d.update_parameters_name('discriminator')
  • 损失函数和优化器
lr = 0.0002  # 学习率

# 损失函数
adversarial_loss = nn.BCELoss(reduction='mean')

# 优化器
optimizer_d = nn.Adam(net_d.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g = nn.Adam(net_g.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g.update_parameters_name('optim_g')
optimizer_d.update_parameters_name('optim_d')

3.小结

今天学了GAN用于图像生成的基本理论和编码方法。GAN模型由生成器(Generative Model)和判别器(Discriminative Model)构成两个相互对抗的模型。生成器负责生成图像进行,判别器用于判定图像真假,通过对抗的模式使得真假判定的结果接近1:1,进而完成训练。这样训练好的生成器即可用于图形生成。这里面着重要掌握对抗网络损失函数的意义,这是的对抗网络能够输出最正确结果的要点。

以上是第22天的学习内容,附上今日打卡记录:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1929782.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

springboot系列九: 接收参数相关注解

文章目录 基本介绍接收参数相关注解应用实例PathVariableRequestHeaderRequestParamCookieValueRequestBodyRequestAttributeSessionAttribute 复杂参数基本介绍应用实例 自定义对象参数-自动封装基本介绍应用实例 ⬅️ 上一篇: springboot系列八: springboot静态资源访问&…

02-Redis未授权访问漏洞

免责声明 本文仅限于学习讨论与技术知识的分享,不得违反当地国家的法律法规。对于传播、利用文章中提供的信息而造成的任何直接或者间接的后果及损失,均由使用者本人负责,本文作者不为此承担任何责任,一旦造成后果请自行承担&…

【Windows】Microsoft PC Manager

使用 Microsoft PC Manager,用户可以轻松执行基本的计算机维护,并通过一键操作提升设备速度。这款应用程序提供了一系列功能,包括磁盘清理、启动应用管理、病毒扫描、Windows 更新检查、进程监控和存储管理。 Microsoft PC Manager 的关键特…

React学习笔记03-----手动创建和运行

一、项目创建与运行【手动】 react-scripts集成了webpack、bable、提供测试服务器 1.目录结构 public是静态目录,提供可以供外部直接访问的文件,存放不需要webpack打包的文件,比如静态图片、CSS、JS src存放源码 (1&#xff09…

xss复习总结及ctfshow做题总结xss

xss复习总结 知识点 1.XSS 漏洞简介 ​ XSS又叫CSS(Cross Site Script)跨站脚本攻击是指恶意攻击者往Web页面里插入恶意Script代码,当用户浏览该页之时,嵌入其中Web里面的Script代码会被执行,从而达到恶意攻击用户的…

ASF平台

最近一直在研究滑坡,但是insar数据处理很麻烦,自己手动处理gamma有很慢,而且据师兄说,gamma处理还很看经验,我就又去看了很多python库和其他工具: isce mintpy pyint 也使用了asf上面的处理产品,虽然也…

H. Beppa and SwerChat【双指针】

思路分析&#xff1a;运用双指针从后往前扫一遍&#xff0c;两次分别记作数组a&#xff0c;b&#xff0c;分别使用双指针i和j来扫&#xff0c;如果一样就往前&#xff0c;如果不一样&#xff0c;i–,ans #include<iostream> #include<cstring> #include<string…

C#绘制含流动块的管道

1&#xff0c;效果。 2&#xff0c;绘制技巧。 1&#xff0c;流动块的实质是使用Pen的自定义DashStyle绘制的线&#xff0c;并使用线的偏移值呈现出流动的效果。 Pen barPen new Pen(BarColor, BarHeight);barPen.DashStyle DashStyle.Custom;barPen.DashOffset startOffse…

解读InnoDB数据库索引页与数据行的紧密关联

目录 一、快速走进索引页结构 &#xff08;一&#xff09;整体展示说明 &#xff08;二&#xff09;内容说明 File Header&#xff08;文件头部&#xff09; Page Header&#xff08;页面头部&#xff09; Infimum Supremum&#xff08;最小记录和最大记录&#xff09; …

太速科技-FMC207-基于FMC 两路QSFP+光纤收发子卡

FMC207-基于FMC 两路QSFP光纤收发子卡 一、板卡概述 本卡是一个FPGA夹层卡&#xff08;FMC&#xff09;模块&#xff0c;可提供高达2个QSFP / QSFP 模块接口&#xff0c;直接插入千兆位级收发器&#xff08;MGT&#xff09;的赛灵思FPGA。支持利用Spartan-6、Virtex-6、Kin…

Webpack看这篇就够了

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 非常期待和您一起在这个小…

java.sql.SQLException: Unknown system variable ‘query_cache_size‘【Pyspark】

1、问题描述 学习SparkSql中&#xff0c;将spark中dataframe数据结构保存为jdbc的格式并提交到本地的mysql中&#xff0c;相关代码见文章末尾。 运行代码时报出相关配置文件错误&#xff0c;如下。 根据该报错&#xff0c;发现网络上多数解决方都是基于java开发的解决方案&a…

GPT-4从0到1搭建一个Agent简介

GPT-4从0到1搭建一个Agent简介 1. 引言 在人工智能领域&#xff0c;Agent是一种能够感知环境并采取行动以实现特定目标的系统。本文将简单介绍如何基于GPT-4搭建一个Agent。 2. Agent的基本原理 Agent的核心是感知-行动循环&#xff08;Perception-Action Loop&#xff09;…

【Windows】系统盘空间不足?WizTree 和 DISM++ 来帮忙

当您的系统盘空间接近饱和时&#xff0c;了解硬盘空间的使用情况变得尤为重要。在这种情况下&#xff0c;您可以利用 Windows 内置的存储使用工具来快速查看哪些文件和应用程序占用了大量空间&#xff0c;并采取相应措施进行清理。此外&#xff0c;第三方工具如 WizTree 可以提…

Java NIO合并多个文件

NIO API java.nio (Java Platform SE 8 ) 直接上代码 package com.phil.aoplog.util;import lombok.extern.slf4j.Slf4j;import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; import java.nio.channels.FileChannel;Slf4j public…

勒索防御第一关 亚信安全AE防毒墙全面升级 勒索检出率提升150%

亚信安全信舷AE高性能防毒墙完成能力升级&#xff0c;全面完善勒索边界“全生命周期”防御体系&#xff0c;筑造边界勒索防御第一关&#xff01; 勒索之殇&#xff0c;银狐当先 当前勒索病毒卷携着AI技术&#xff0c;融合“数字化”的运营模式&#xff0c;形成了肆虐全球的网…

数据结构进阶:使用链表实现栈和队列详解与示例(C, C#, C++)

文章目录 1、 栈与队列简介栈&#xff08;Stack&#xff09;队列&#xff08;Queue&#xff09; 2、使用链表实现栈C语言实现C#语言实现C语言实现 3、使用链表实现队列C语言实现C#语言实现C语言实现 4、链表实现栈和队列的性能分析时间复杂度空间复杂度性能特点与其他实现的比较…

VBA学习(21):遍历文件夹(和子文件夹)中的文件

很多时候&#xff0c;我们都想要遍历文件夹中的每个文件&#xff0c;例如在工作表中列出所有文件名、对每个文件进行修改。VBA给我们提供了一些方式&#xff1a;&#xff08;1&#xff09;Dir函数&#xff1b;&#xff08;2&#xff09;File System Object。 使用Dir函数 Dir…

31.RAM-IP核的配置、调用、仿真全流程

&#xff08;1&#xff09;RAM IP核简介 RAM是随机存取存储器&#xff08;Random Access Memory&#xff09;的简称&#xff0c;是一个易失性存储器&#xff0c;其工作时可以随时对任何一个指定地址写入或读出数据。&#xff08;掉电数据丢失&#xff09; &#xff08;2&#…

Spring Cloud Gateway 入门与实战

一、网关 在微服务框架中&#xff0c;网关是一个提供统一访问地址的组件&#xff0c;它充当了客户端和内部微服务之间的中介。网关主要负责流量路由和转发&#xff0c;将外部请求引导到相应的微服务实例上&#xff0c;同时提供一些功能&#xff0c;如身份认证、授权、限流、监…