昇思25天学习打卡营第11天 |昇思MindSpore DCGAN 生成漫画头像学习

news2024/9/9 6:19:50

一、GAN 基础原理
DCGAN(深度卷积对抗生成网络,Deep Convolutional Generative Adversarial Networks)是GAN的直接扩展。不同之处在于,DCGAN会分别在判别器和生成器中使用卷积和转置卷积层。

它最早由Radford等人在论文Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks中进行描述。判别器分层的卷积层BatchNorm层LeakyReLU激活层组成。输入是3x64x64的图像,输出是该图像为真图像的概率。生成器则是由转置卷积层BatchNorm层ReLU激活层组成。输入是标准正态分布中提取出的隐向量z
,输出是3x64x64的RGB图像。

二、DCGAN 原理

  1. 简介
    • 深度卷积对抗生成网络,是 GAN 的扩展。
    • 判别器由卷积、BatchNorm 层和 LeakyReLU 激活层组成,输入图像,输出图像为真的概率。
    • 生成器由转置卷积、BatchNorm 层和 ReLU 激活层组成,输入隐向量,输出 RGB 图像。
  2. 论文
    • 最早由 Radford 等人在论文 Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks 中描述。

三、数据准备与处理

  1. 数据下载与解压
    • 使用 download 库的 download 函数下载动漫头像数据集并解压。
    • 函数download(url, path, kind, replace)
    • 参数
      • url:数据集的下载链接。
      • path:保存数据集的路径。
      • kind:数据的类型,如 zip
      • replace:是否替换已存在的文件。
    • 示例download("https://download.mindspore.cn/dataset/Faces/faces.zip", "./faces", kind="zip", replace=True)
  2. 数据处理
    • 定义输入参数:
      • batch_size = 128:批量大小。
      • image_size = 64:训练图像空间大小。
      • nc = 3:图像彩色通道数。
      • nz = 100:隐向量的长度。
      • ngf = 64:特征图在生成器中的大小。
      • ndf = 64:特征图在判别器中的大小。
      • num_epochs = 10:训练周期数。
      • lr = 0.0002:学习率。
      • beta1 = 0.5:Adam 优化器的 beta1 超参数。
    • create_dataset_imagenet 函数:
      • 功能:对数据进行加载、增强和批量处理。
      • 参数dataset_path,数据集的路径。
      • 示例create_dataset_imagenet('./faces')
      • 内部使用 mindspore.dataset 中的模块进行操作:
        • ImageFolderDataset :加载数据集。
        • 各种数据变换操作,如 ResizeCenterCrop 等。
    • 数据可视化:
      • 使用 create_dict_iterator 将数据转换为字典迭代器。
      • 通过 plot_data 函数和 matplotlib.pyplot 可视化部分训练数据。

四、构造网络

  1. 生成器
    • 功能:将隐向量映射到数据空间,生成 RGB 图像。
    • 结构:由一系列 Conv2dTranspose 转置卷积层、BatchNorm2d 层和 ReLU 激活层组成,输出经过 tanh 函数。
    • 代码实现:使用 mindspore 中的模块定义 Generator 类。
    • 函数Generator 类的 __init__construct 方法。
    • 参数:无。
    • 示例generator = Generator()
  2. 判别器
    • 功能:二分类网络,输出图像为真实图的概率。
    • 结构:由一系列 Conv2dBatchNorm2dLeakyReLU 层组成,最后通过 Sigmoid 激活函数。
    • 代码实现:使用 mindspore 中的模块定义 Discriminator 类。
    • 函数Discriminator 类的 __init__construct 方法。
    • 参数:无。
    • 示例discriminator = Discriminator()

五、模型训练

  1. 损失函数
    • 使用 BCELoss 作为损失函数。
    • 函数nn.BCELoss(reduction='mean')
    • 参数reduction ,指定损失的归约方式,如 'mean' 计算平均值。
    • 示例adversarial_loss = nn.BCELoss(reduction='mean')
  2. 优化器
    • 为生成器和判别器分别设置 Adam 优化器,学习率 lr = 0.0002beta1 = 0.5
    • 函数nn.Adam(params, learning_rate, beta1)
    • 参数
      • params :可训练的参数。
      • learning_rate :学习率。
      • beta1 :Adam 优化器的超参数。
    • 示例
      • optimizer_D = nn.Adam(discriminator.trainable_params(), learning_rate=lr, beta1=beta1)
      • optimizer_G = nn.Adam(generator.trainable_params(), learning_rate=lr, beta1=beta1)
  3. 训练模型
    • 分为训练判别器和训练生成器两部分。
      • 训练判别器:最大化 logD(x) + log(1 - D(G(z))) 的值,提高判别图像真伪的概率。
      • 训练生成器:最小化 log(1 - D(G(z))) ,产生更好的虚假图像。
    • 实现逻辑:
      • generator_forward 函数:计算生成器的损失和生成的图像。
      • discriminator_forward 函数:计算判别器的损失。
      • grad_generator_fngrad_discriminator_fn :计算梯度。
      • train_step 函数:执行训练步骤,更新参数。
    • 训练过程:
      • 循环训练网络,每 50 次迭代收集损失。
      • 每个 epoch 结束后,生成一组图片并保存模型参数。

六、结果展示

  1. 绘制损失函数图像:展示判别器和生成器损失与训练迭代的关系。
  2. 生成动画:可视化训练过程中生成的图像。
  3. 加载模型生成图像:通过加载生成器网络模型参数文件生成新的图像。

七、调用的库及功能

  1. download :用于数据集的下载和解压。
  2. mindspore.dataset
    • ImageFolderDataset :加载数据集。
    • shufflenum_parallel_workers 等:数据处理相关设置。
    • 各种数据变换操作,如 ResizeCenterCrop 等。
  3. mindspore
    • 构建神经网络,如 Conv2dTransposeBatchNorm2d 等。
    • 定义损失函数和优化器。
    • 执行梯度计算和参数更新。
  4. matplotlib.pyplot :数据可视化,绘制损失曲线和图像。
  5. matplotlib.animation :生成动画展示训练过程中的图像变化。

八、操作流程

  1. 准备数据:
    • 下载数据集:使用 download 函数下载数据。
    • 处理数据:调用 create_dataset_imagenet 函数进行加载、增强和批量处理。
    • 可视化数据:使用 plot_data 函数展示部分数据。
  2. 构建网络:
    • 定义生成器 Generator 类。
    • 定义判别器 Discriminator 类。
    • 实例化生成器和判别器。
  3. 训练模型:
    • 定义损失函数。
    • 设置优化器。
    • 执行训练逻辑,循环更新参数,保存模型。
  4. 结果展示:
    • 绘制损失曲线。
    • 生成动画展示生成图像的变化。
    • 加载模型生成新图像。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1962208.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【git】git常用命令提交规范

Git 是程序员工作中不可或缺的版本控制工具,以下是一些优化后的常用 Git 命令列表,旨在帮助你更高效地使用 Git 进行版本控制。 基础操作 拉取代码 git clone xxx.git创建分支 git branch dev切换分支 git checkout dev # 或者 git switch dev创建并切换…

Python酷库之旅-第三方库Pandas(056)

目录 一、用法精讲 211、pandas.Series.truncate方法 211-1、语法 211-2、参数 211-3、功能 211-4、返回值 211-5、说明 211-6、用法 211-6-1、数据准备 211-6-2、代码示例 211-6-3、结果输出 212、pandas.Series.where方法 212-1、语法 212-2、参数 212-3、功能…

论报文加密加签场景下如何高效的进行渗透测试

前言 最新的测试中,经常遇到HTTP报文加密/加签传输的情况,这导致想要查看和修改明文报文很不方便。 之前应对这种情况我们有几种常见的办法解决,比如使用burpy插件、在Burp上下游使用mitmproxy进行代理等,但这些使用起来不太方便…

LSTM详解总结

LSTM(Long Short-Term Memory)是一种用于处理和预测时间序列数据的递归神经网络(RNN)的改进版本。其设计初衷是为了解决普通RNN在长序列训练中出现的梯度消失和梯度爆炸问题。以下是对LSTM的详细解释,包括原理、公式、…

面向非结构化数据的知迟抽取

文章目录 实体抽取关系抽取事件抽取大量的数据以非结构化数据(即自由文本)的形式存在,如新闻报道、科技文献和政府文件等,面向文本数据的知识抽取一直是广受关注的问题。在前文介绍的知识抽取领域的评测竞赛中,评测数据大多属于非结构化文本数据。本节将对这一类知识抽取技…

Prometheus-部署

Prometheus-部署 Server端安装配置部署Node Exporters监控系统指标监控MySQL数据库监控nginx安装grafana Server端安装配置 1、上传安装包,并解压 cd /opt/ tar xf prometheus-2.30.3.linux-amd64.tar.gz mv prometheus-2.30.3.linux-amd64 /usr/local/prometheus…

【音频识别】十大数据集合集,宝藏合集,不容错过!

本文将为您介绍10个经典、热门的数据集,希望对您在选择适合的数据集时有所帮助。 1 RenderMe-360 发布方: 上海人工智能实验室 发布时间: 2023-05-24 简介: RenFace是一个大规模多视角人脸高清视频数据集,包含多样的…

2024年最强网络安全学习路线,详细到直接上清华的教材!

关键词:网络安全入门、渗透测试学习、零基础学安全、网络安全学习路线 首先咱们聊聊,学习网络安全方向通常会有哪些问题前排提示:文末有CSDN官方认证Python入门资料包 ! 1、打基础时间太长 学基础花费很长时间,光语…

Redis内存管理

文章目录 Redis内存管理删除策略淘汰策略LRU算法 Redis内存管理 长期把Redis做缓存用,总有一天Redis内存总会满的。有没有思考过这个问题,Redis内存满了会怎么样?在redis.conf中把Redis内存设置为1个字节,做一个测试:…

【随机链表的复制】python刷题记录

R3-哈希表 参考k神题解 哈希表法: """ # Definition for a Node. class Node:def __init__(self, x: int, next: Node None, random: Node None):self.val int(x)self.next nextself.random random """class Solution:def copy…

“打破常规:评估八股文对工作的真正影响“

“八股文”在实际工作中是助力、阻力还是空谈? 作为现在各类大中小企业面试程序员时的必问内容,“八股文”似乎是很重要的存在。但“八股文”是否能在实际工作中发挥它“敲门砖”应有的作用呢?有IT人士不禁发出疑问:程序员面试考什…

基于深度学习的结肠炎严重度诊断

基于深度学习的结肠炎严重度诊断 本文所涉及所有资源均在传知代码平台可获取 文章目录 基于深度学习的结肠炎严重度诊断1.概述1.1 数据集展示1.2 Resnet50介绍1.2.1结构与特点1.2.2关键优势1.2.3总结 2.创新点3.结果可视化展示结果展示4.核心逻辑5.部署及使用方式5.1 环境配置5…

彻底搞清楚SSR同构渲染的首屏

作为.NET技术栈的全干工程师,Blazor、Vue/Nuxt.js和React/Next.js都会接触到。它们(准确的说是Blazor、Nuxt和Next),都实现了SSR同构渲染。要了解同构渲染,需要从服务端渲染开始。 传统的服务端渲染 如下图所示&…

开放式耳机什么牌子的好?看这6大品牌就够了

移动互联网时代,听歌、追剧、网课、短视频……这几年全球青年人对于耳机和耳朵的依赖程度,可谓前所未有的提升。但选择一款好的耳机,也不是一件容易的事,入耳式耳机戴久了耳道会疼,还可能引起一系列不必要的炎症&#…

【C语言】C语言期末突击/考研--数据的类型

目录 一、编程环境的搭建 二、数据的类型、数据的输入输出 2.1.数据类型 2.2.常量 2.3.变量 2.4.整型数据 2.4.1.符号常量 2.4.2.整型变量 2.5.浮点型数据 2.5.1.浮点型常量 2.5.2.浮点型变量 2.6.字符型数据 2.6.1字符型常量 2.6.2.字符数据在内存中的存储形式及…

Python 【机器学习】 进阶 之 【实战案例】房价数据中位数分析 | 1/3(含分析过程)

Python 【机器学习】 进阶 之 【实战案例】房价数据中位数分析 | 1/3(含分析过程) 目录 Python 【机器学习】 进阶 之 【实战案例】房价数据中位数分析 | 1/3(含分析过程) 一、简单介绍 二、机器学习 1、为什么使用机器学习&a…

react antd upload custom request处理多个文件上传

react antd upload custom request处理多个文件上传的问题 背景:第一次请求需要请求后端返回aws 一个link,再往link push文件,再调用另一个接口告诉后端已经上传成功,拿到返回值。 再把返回值传给业务api... 多文件上传一直是循环…

字体表绘制的理解

下载字体到项目根目录下,我们通过一些在写预览本地字体的网站,简单看一下 通过图片不难看出阴书与原文的对应关系,接下来通过程序去完成这一过程,通过 fonttools 处理 ttf,然后获取字体和文字对应的 xml 文件 下面简单…

分布式SQL查询引擎之ByConity

ByConity 是字节跳动面向现代数据栈的一款开源数仓系统,应用了大量数据库成熟技术,如列存引擎,MPP 执行,智能查询优化,向量化执行,Codegen,indexing,数据压缩,适合用于 O…

线程池和进程池,输出有区别吗?

from concurrent.futures import ThreadPoolExecutor,ProcessPoolExecutor def fn(name):for i in range(1000):print(name,i)if __name__ __main__:with ThreadPoolExecutor(10) as t:for i in range(100):t.submit(fn,namef"线程{i}")with ProcessPoolExecutor(10…