昇思MindSpore学习笔记4-02生成式--DCGAN生成漫画头像

news2025/1/15 17:39:14

摘要:

        记录了昇思MindSpore AI框架使用70171张动漫头像图片训练一个DCGAN神经网络生成式对抗网络,并用来生成漫画头像的过程、步骤。包括环境准备、下载数据集、加载数据和预处理、构造网络、模型训练等。

一、概念

深度卷积对抗生成网络DCGAN

Deep Convolutional Generative Adversarial Networks

        扩展GAN

        判别器

                组成

                        卷积层

                        BatchNorm层

                        LeakyReLU激活层

                功能

                        输入是3*64*64图像

                        输出是真图像概率

        生成器

                组成

                        转置卷积层

                        BatchNorm层

                        ReLU激活层

                功能

                        输入是标准正态分布中提取出的隐向量z

                        输出是3*64*64 RGB图像。

  • 环境准备
%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14

三、数据准备与处理

1.下载数据集

下载到指定目录下并解压代码如下

from download import download
url = "https://download.mindspore.cn/dataset/Faces/faces.zip"
path = download(url, "./faces", kind="zip", replace=True)

输出:

Downloading data from https://download-mindspore.osinfra.cn/dataset/Faces/faces.zip (274.6 MB)

file_sizes: 100%|████████████████████████████| 288M/288M [00:52<00:00, 5.49MB/s]
Extracting zip file...
Successfully downloaded / unzipped to ./faces

2.数据集介绍

使用的动漫头像数据集共有70,171张动漫头像图片,图片大小均为96*96。

数据集目录结构如下:

./faces/faces
├── 0.jpg
├── 1.jpg
├── 2.jpg
├── 3.jpg
├── 4.jpg
    ...
├── 70169.jpg
└── 70170.jpg

3.数据处理

(1) 执行过程参数定义:

batch_size = 128          # 批量大小
image_size = 64           # 训练图像空间大小
nc = 3                    # 图像彩色通道数
nz = 100                  # 隐向量的长度
ngf = 64                  # 特征图在生成器中的大小
ndf = 64                  # 特征图在判别器中的大小
num_epochs = 3            # 训练周期数
lr = 0.0002               # 学习率
beta1 = 0.5               # Adam优化器的beta1超参数

(2) 数据处理和增强

create_dataset_imagenet函数

import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
​
def create_dataset_imagenet(dataset_path):
    """数据加载"""
    dataset = ds.ImageFolderDataset(dataset_path,
                                    num_parallel_workers=4,
                                    shuffle=True,
                                    decode=True)
​
    # 数据增强操作
    transforms = [
        vision.Resize(image_size),
        vision.CenterCrop(image_size),
        vision.HWC2CHW(),
        lambda x: ((x / 255).astype("float32"))
    ]
​
    # 数据映射操作
    dataset = dataset.project('image')
    dataset = dataset.map(transforms, 'image')
​
    # 批量操作
    dataset = dataset.batch(batch_size)
    return dataset
​
dataset = create_dataset_imagenet('./faces')

(3) 查看训练数据

matplotlib模块

数据转换成字典迭代器

        create_dict_iterator函数

import matplotlib.pyplot as plt
​
def plot_data(data):
    # 可视化部分训练数据
    plt.figure(figsize=(10, 3), dpi=140)
    for i, image in enumerate(data[0][:30], 1):
        plt.subplot(3, 10, i)
        plt.axis("off")
        plt.imshow(image.transpose(1, 2, 0))
    plt.show()
​
sample_data = next(dataset.create_tuple_iterator(output_numpy=True))
plot_data(sample_data)

四、构造网络

模型权重随机初始化

范围:mean为0,sigma为0.02的正态分布【数学不好】

1. 生成器

生成器G

        隐向量z映射数据空间

        数据源是图像

        生成与图像大小相同的 RGB 图像

        Conv2dTranspose转置卷积层

        每个层与BatchNorm2d层和ReLu激活层配对

        tanh函数

        输出[-1,1]范围内数据

DCGAN生成图像过程如下所示:

生成器结构参数:

        nz         隐向量z的长度

        ngf         有关生成器传播的特征图大小

        nc         输出图像通道数

生成器代码:

import mindspore as ms
from mindspore import nn, ops
from mindspore.common.initializer import Normal
​
weight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)
​
class Generator(nn.Cell):
    """DCGAN网络生成器"""
​
    def __init__(self):
        super(Generator, self).__init__()
        self.generator = nn.SequentialCell(
            nn.Conv2dTranspose(nz, ngf * 8, 4, 1, 'valid', weight_init=weight_init),
            nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),
            nn.ReLU(),
            nn.Conv2dTranspose(ngf * 8, ngf * 4, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),
            nn.ReLU(),
            nn.Conv2dTranspose(ngf * 4, ngf * 2, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),
            nn.ReLU(),
            nn.Conv2dTranspose(ngf * 2, ngf, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.BatchNorm2d(ngf, gamma_init=gamma_init),
            nn.ReLU(),
            nn.Conv2dTranspose(ngf, nc, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.Tanh()
            )
​
    def construct(self, x):
        return self.generator(x)
​
generator = Generator()

2. 判别器

判别器D

        二分类网络模型

                Conv2d

                BatchNorm2d

                LeakyReLU

                Sigmoid激活函数

                输出判定图像真实概率

判别器代码:

class Discriminator(nn.Cell):
    """DCGAN网络判别器"""
​
    def __init__(self):
        super(Discriminator, self).__init__()
        self.discriminator = nn.SequentialCell(
            nn.Conv2d(nc, ndf, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf * 8, 1, 4, 1, 'valid', weight_init=weight_init),
            )
        self.adv_layer = nn.Sigmoid()
​
    def construct(self, x):
        out = self.discriminator(x)
        out = out.reshape(out.shape[0], -1)
        return self.adv_layer(out)
​
discriminator = Discriminator()

五、模型训练

1. 损失函数

二进制交叉熵损失函数MindSpore.nn.BCELoss

# 定义损失函数
adversarial_loss = nn.BCELoss(reduction='mean')

2. 优化器

Adam优化器

        lr = 0.0002

        beta1 = 0.5

# 为生成器和判别器设置优化器
optimizer_D = nn.Adam(discriminator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G = nn.Adam(generator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G.update_parameters_name('optim_g.')
optimizer_D.update_parameters_name('optim_d.')

3. 训练模型

训练判别器

        提高判别图像真伪的概率

        Goodfellow方法:提高随机梯度更新判别器

        最大化logD(x)+log(1-D(G(z)))

训练生成器

        最小化log(1−D(G(z)))

        产生更好的虚拟图像

两个部分分别

        获取训练损失

        每个周期结束统计

        批量推送fixed_noise到生成器

        跟踪G的训练进度

模型训练正向逻辑:

def generator_forward(real_imgs, valid):
    # 将噪声采样为发生器的输入
    z = ops.standard_normal((real_imgs.shape[0], nz, 1, 1))
​
    # 生成一批图像
    gen_imgs = generator(z)
​
    # 损失衡量发生器绕过判别器的能力
    g_loss = adversarial_loss(discriminator(gen_imgs), valid)
​
    return g_loss, gen_imgs
​
def discriminator_forward(real_imgs, gen_imgs, valid, fake):
    # 衡量鉴别器从生成的样本中对真实样本进行分类的能力
    real_loss = adversarial_loss(discriminator(real_imgs), valid)
    fake_loss = adversarial_loss(discriminator(gen_imgs), fake)
    d_loss = (real_loss + fake_loss) / 2
    return d_loss
​
grad_generator_fn = ms.value_and_grad(generator_forward, None,
                                      optimizer_G.parameters,
                                      has_aux=True)
grad_discriminator_fn = ms.value_and_grad(discriminator_forward, None,
                                          optimizer_D.parameters)
​
@ms.jit
def train_step(imgs):
    valid = ops.ones((imgs.shape[0], 1), mindspore.float32)
    fake = ops.zeros((imgs.shape[0], 1), mindspore.float32)
​
    (g_loss, gen_imgs), g_grads = grad_generator_fn(imgs, valid)
    optimizer_G(g_grads)
    d_loss, d_grads = grad_discriminator_fn(imgs, gen_imgs, valid, fake)
    optimizer_D(d_grads)
​
    return g_loss, d_loss, gen_imgs

循环训练网络

        迭代50次收集生成器、判别器的损失一次

        绘制损失函数的图像

import mindspore
​
G_losses = []
D_losses = []
image_list = []
​
total = dataset.get_dataset_size()
for epoch in range(num_epochs):
    generator.set_train()
    discriminator.set_train()
    # 为每轮训练读入数据
    for i, (imgs, ) in enumerate(dataset.create_tuple_iterator()):
        g_loss, d_loss, gen_imgs = train_step(imgs)
        if i % 100 == 0 or i == total - 1:
            # 输出训练记录
            print('[%2d/%d][%3d/%d]   Loss_D:%7.4f  Loss_G:%7.4f' % (
                epoch + 1, num_epochs, i + 1, total, d_loss.asnumpy(), g_loss.asnumpy()))
        D_losses.append(d_loss.asnumpy())
        G_losses.append(g_loss.asnumpy())
​
    # 每个epoch结束后,使用生成器生成一组图片
    generator.set_train(False)
    fixed_noise = ops.standard_normal((batch_size, nz, 1, 1))
    img = generator(fixed_noise)
    image_list.append(img.transpose(0, 2, 3, 1).asnumpy())
​
    # 保存网络模型参数为ckpt文件
    mindspore.save_checkpoint(generator, "./generator.ckpt")
    mindspore.save_checkpoint(discriminator, "./discriminator.ckpt")

输出:

[ 1/3][  1/549]   Loss_D: 0.2635  Loss_G: 4.8150
[ 1/3][101/549]   Loss_D: 0.4023  Loss_G: 4.9807
[ 1/3][201/549]   Loss_D: 0.2425  Loss_G: 1.6335
[ 1/3][301/549]   Loss_D: 0.5856  Loss_G: 0.6079
[ 1/3][401/549]   Loss_D: 0.1922  Loss_G: 4.3977
[ 1/3][501/549]   Loss_D: 0.1065  Loss_G: 2.3724
[ 1/3][549/549]   Loss_D: 0.1893  Loss_G: 1.6483
[ 2/3][  1/549]   Loss_D: 0.3370  Loss_G: 4.4347
[ 2/3][101/549]   Loss_D: 0.4681  Loss_G: 0.8623
[ 2/3][201/549]   Loss_D: 0.1856  Loss_G: 3.7501
[ 2/3][301/549]   Loss_D: 0.1932  Loss_G: 2.6333
[ 2/3][401/549]   Loss_D: 0.1310  Loss_G: 2.2524
[ 2/3][501/549]   Loss_D: 0.2531  Loss_G: 1.4690
[ 2/3][549/549]   Loss_D: 0.1192  Loss_G: 5.7166
[ 3/3][  1/549]   Loss_D: 0.0716  Loss_G: 2.9886
[ 3/3][101/549]   Loss_D: 0.1345  Loss_G: 2.6544
[ 3/3][201/549]   Loss_D: 0.1097  Loss_G: 2.8604
[ 3/3][301/549]   Loss_D: 0.2066  Loss_G: 6.1513
[ 3/3][401/549]   Loss_D: 0.0797  Loss_G: 3.2336
[ 3/3][501/549]   Loss_D: 0.2618  Loss_G: 4.0991
[ 3/3][549/549]   Loss_D: 0.5600  Loss_G:10.7509

4. 结果展示

描绘D和G损失与训练迭代的关系图:

plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G", color='blue')
plt.plot(D_losses, label="D", color='orange')
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

输出:

显示隐向量fixed_noise训练生成的图像

import matplotlib.pyplot as plt
import matplotlib.animation as animation
​
def showGif(image_list):
    show_list = []
    fig = plt.figure(figsize=(8, 3), dpi=120)
    for epoch in range(len(image_list)):
        images = []
        for i in range(3):
            row = np.concatenate((image_list[epoch][i * 8:(i + 1) * 8]), axis=1)
            images.append(row)
        img = np.clip(np.concatenate((images[:]), axis=0), 0, 1)
        plt.axis("off")
        show_list.append([plt.imshow(img)])
​
    ani = animation.ArtistAnimation(fig, show_list, interval=1000, repeat_delay=1000, blit=True)
    ani.save('./dcgan.gif', writer='pillow', fps=1)
​
showGif(image_list)

输出:

训练次数增多,图像质量越好

num_epochs达到50以上,生成动漫头像图片与数据集较为相似

加载生成器网络模型参数文件来生成图像代码:

# 从文件中获取模型参数并加载到网络中
mindspore.load_checkpoint("./generator.ckpt", generator)
​
fixed_noise = ops.standard_normal((batch_size, nz, 1, 1))
img64 = generator(fixed_noise).transpose(0, 2, 3, 1).asnumpy()
​
fig = plt.figure(figsize=(8, 3), dpi=120)
images = []
for i in range(3):
    images.append(np.concatenate((img64[i * 8:(i + 1) * 8]), axis=1))
img = np.clip(np.concatenate((images[:]), axis=0), 0, 1)
plt.axis("off")
plt.imshow(img)
plt.show()

输出:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1897077.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

个人引导页+音乐炫酷播放器(附加源码)

个人引导页音乐炫酷播放器 效果图部分源码完整源码领取下期更新内容 效果图 部分源码 //网站动态标题开始 var OriginTitile document.title, titleTime; document.addEventListener("visibilitychange", function() {if (document.hidden) {document.title "…

[作业]10 枚举-排列类

作业&#xff1a; 已做&#xff1a; #include <iostream> using namespace std; int n; int a[100]; void func(int ,int); int main(){cin>>n;func(0,n);return 0; } void func(int k,int m){if(k>m-1){for(int i0;i<m;i){cout<<a[i];}cout<<en…

【高性能服务器】select模型

&#x1f525;博客主页&#xff1a; 我要成为C领域大神&#x1f3a5;系列专栏&#xff1a;【C核心编程】 【计算机网络】 【Linux编程】 【操作系统】 ❤️感谢大家点赞&#x1f44d;收藏⭐评论✍️ 本博客致力于知识分享&#xff0c;与更多的人进行学习交流 IO多路复用就是复用…

一文了解常见DNS问题

当企业的DNS出现故障时&#xff0c;为不影响企业的正常运行&#xff0c;团队需要能够快速确定问题的性质和范围。那么有哪些常见的DNS问题呢&#xff1f; 域名解析失败&#xff1a; 当您输入一个域名&#xff0c;但无法获取到与之对应的IP地址&#xff0c;导致无法访问相应的网…

INTERCONNECT 使用脚本导入 Element Library 的器件

INTERCONNECT 使用脚本导入 Element Library 的器件 正文示例1示例2正文 在 INTERCONNECT 添加自定义器件到 Custom 文件夹下 一文中,我们介绍了如何将器件或者自定义器件添加到用户自定义的库中。那么我们如何从 Element Library 中导入我们需要的器件呢? 最简单的方式就是…

Linux系统(CentOS)安装iptables防火墙

1&#xff0c;先检查是否安装了iptables 检查安装文件-执行命令&#xff1a;rpm -qa|grep iptables 检查安装文件-执行命令&#xff1a;service iptables status 2&#xff0c;如果安装了就卸装(iptables-1.4.21-35.el7.x86_64 是上面命令查出来的版本) 执行命令&#xff1a…

Logstash安装插件失败的问题

Logstash安装插件失败的问题 安装 logstash-output-jdbc 失败 报错为&#xff1a; Unable to download data from https://rubygems.org - Net::OpenTimeout: Failed to open TCP connection to rubygems.org:443 (execution expired) (https://rubygems.org/latest_specs.4.…

正确使用Pytorch Geometric打开Cora(Planetoid)数据集

文章目录 关于报错&#xff08;"Cannot connect to host"&#xff09;解决方法 关于报错&#xff08;“Cannot connect to host”&#xff09; 我们在使用PyG调用Planetoid数据集的时候&#xff0c;常会碰到如下报错&#xff1a; 解决方法就是手动下载这个数据集。…

CentOS 离线安装部署 MySQL 8详细教程

1、简介 MySQL是一个流行的开源关系型数据库管理系统&#xff08;RDBMS&#xff09;&#xff0c;它基于SQL&#xff08;Structured Query Language&#xff0c;结构化查询语言&#xff09;进行操作。MySQL最初由瑞典的MySQL AB公司开发&#xff0c;后来被Sun Microsystems公司…

一个开源的、独立的、可自托管的评论系统,专为现代Web平台设计

大家好&#xff0c;今天给大家分享的是一个开源的、独立的、可自托管的评论系统&#xff0c;专为现代Web平台设计。 Remark42是一个自托管的、轻量级的、简单的&#xff08;但功能强大的&#xff09;评论引擎&#xff0c;它不会监视用户。它可以嵌入到博客、文章或任何其他读者…

Eclipse配置Tomcat时无Apache选项问题

有可能你会遇到&#xff0c;安装最新版本Eclipse&#xff0c;但是 Window——Preferences——Servers——Runtime Environments。发现没有Apache选项。&#xff0c;这是因为&#xff0c;默认没有安装J2EE组件&#xff0c;我们可以通过手动安装&#xff0c;来解决这个问题。 一…

哪些场景下可以更好地使用行列视(RCV)报表工具呢?

行列视产品是我们公司自主研发的一套基于HTML5技术的Excel式web生产报表应用系统&#xff0c;这款产品定位于发电企业生产指标的收集、报表制作和指标报表可视化&#xff0c;是国内首套专业化、自助化、智能化的生产指标管理及分析应用平台。功能强大但是却简单易用。 这款产品…

react_web自定义组件_多类型Modal_搜索栏Search

目录 一、带输入框的Modal 二、提示框Modal 三、搜索栏Search 在做项目时引入一些现成的UI组件&#xff0c;但是如果和设计图冲突太大&#xff0c;更改时很麻烦&#xff0c;如果自己写一个通用组件其实也就几十分钟或者几个小时&#xff0c;而且更具UI设计更改也比较好更改&…

win11如何关闭自动更新,延长暂停更新时间

网上有很多关闭自动更新的方法&#xff0c;今天给大家带来另一种关闭win11自动更新的方法。 1.winR打开运行窗口&#xff0c;输入regedit打开注册表 2.定位到以下位置&#xff1a; 计算机\HKEY_LOCAL_MACHINE\SOFTWARE\Microsoft\WindowsUpdate\UX\Settings 3.右键右边空白&…

使用kali Linux启动盘轻松破解Windows电脑密码

破解分析文章仅限用于学习和研究目的&#xff1b;不得将上述内容用于商业或者非法用途&#xff0c;否则&#xff0c;一切后果请用户自负。谢谢&#xff01;&#xff01; 效果展示&#xff1a; 使用kali Linux可以轻松破解Windows用户及密码 准备阶段&#xff1a; &#xff08…

Dungeonborne卡顿怎么办 快速解决Dungeonborne卡顿问题

随着Dungeonborne游戏剧情的深入&#xff0c;玩家将逐渐解锁更多的地图和副本&#xff0c;每个区域都有其独特的生态和敌人。在探索的过程中&#xff0c;玩家不仅可以获得强大的装备和道具&#xff0c;还能结识到志同道合的伙伴&#xff0c;共同面对更强大的敌人。不过也有玩家…

【项目实践】贪吃蛇

一、游戏效果展示二、博客目标三、使用到的知识四、Win32 API 介绍 4.1 WIn32 API4.2 控制台程序4.3 控制屏幕上的坐标COORD4.4 GetStdHandle4.5 GetConsoleCursorInfo 4.5.1 CONSOLE_CURSOR_INFO 4.6 SetConsoleCursorInfo4.7 SetConsoleCursorPosition4.8 GetAsyncKeyState 五…

vue3+electron项目搭建,遇到的坑

我主要是写后端,所以对前端的vue啊vue-cli只是知其然,不知其所以然 这样也导致了我在开发前端时候遇到了很多的坑 第一个坑, vue2升级vue3始终升级不成功 第二个坑, vue add electron-builder一直卡进度,进度条走完就是不出提示succes 第一个坑的解决办法: 按照网上说的升级v…

如何通过KB知识库系统实现内部知识的管理

“Baklib 通过构建KB知识库系统实现内部知识的管理&#xff0c;构建 CMS 系统实现网站内容管理&#xff0c;构建 DAM 实现对原子化数字内容的管理。” Baklib 从多个维度和深度实现对数字内容的管理。 CMS 系统 CMS 系统(Content Management System 内容管理系统)是一种帮助用…

什么是 HTTP POST 请求?初学者指南与示范

在现代网络开发领域&#xff0c;理解并应用 HTTP 请求 方法是基本的要求&#xff0c;其中 "POST" 方法扮演着关键角色。 理解 POST 方法 POST 方法属于 HTTP 协议的一部分&#xff0c;主旨在于向服务器发送数据以执行资源的创建或更新。它与 GET 方法区分开来&…