使用生成式对抗网络(GAN)生成动漫人物图像

news2024/9/20 7:47:45

【图书推荐】《PyTorch深度学习与企业级项目实战》-CSDN博客

《PyTorch深度学习与企业级项目实战(人工智能技术丛书)》(宋立桓,宋立林)【摘要 书评 试读】- 京东图书 (jd.com)

如今AI艺术创作能力越来越强大,Google发布的ImageGen项目基于文本提示作画的结果和真实艺术家的成品难辨真假。本项目将使用PyTorch实现生成式对抗网络生成式对抗网络来完成AI生成动漫人物图像。

本项目中使用的数据集是一个由63 632个高质量动画人脸组成的数据集,从www.getchu.com中抓取,然后使用https://github.com/nagadomi/lbpcascade_animeface中的动画人脸检测算法进行裁剪。图像大小从90×90到120×120不等。该数据集包含高质量的动漫角色图像,具有干净的背景和丰富的颜色。数据集下载链接:https://github.com/bchao1/Anime-Face-Dataset。

我们知道在生成式对抗网络中有两个模型——生成模型(Generative Model,G)和判别模型(Discriminative Model,D)。G就是一个生成图片的网络,它接收一个随机的噪声z,然后通过这个噪声生成图片,生成的数据记作G(z)。D是一个判别网络,判别一幅图片是不是“真实的”(是不是捏造的)。它的输入参数是x,x代表一幅图片,输出D(x)代表x为真实图片的概率,如果为1,就代表是真实的图片,而输出为0,就代表不可能是真实的图片。

  1. 定义生成器Generator:生成器的输入为100维的高斯噪声,生成器会利用这个噪声生成指定大小的图片,关于最初的噪声,可以看成10011的特征图,然后利用转置卷积来进行尺寸还原操作,标准的卷积操作是不断缩小尺寸,转置卷积就可以理解为它的逆操作,这样就可以不断放大图像。
  2. 定义判别器Discriminator:判别器就是一个典型的二分类网络,首先它的输入是我们输入的图片,我们会利用一系列卷积操作来形成一维特征图进行分类操作,这里可以发现判别器的网络和生成器的相关操作是可逆的,唯独不一样的是激活函数。

模型训练的步骤如下:

   步骤1:首先固定生成器,训练判别器,提高真实样本被判别为真的概率,同时降低生成器生成的假图像被判别为真的概率,目标是判别器能准确进行分类。

   步骤2:固定判别器,训练生成器,生成器生成图像,尽可能提高该图像被判别器判别为真的概率,目标是生成器的结果能够骗过判别器。

   步骤3:重复,循环交替训练,最终生成器生成的样本足够逼真,使得鉴别器只有大约50%的判断正确率(相当于乱猜)。

完整代码如下:

#####################GANDEMO.py####################
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset
import torchvision
from torchvision import transforms, datasets
from tqdm import tqdm

class Config(object):
    data_path = './gandata/data/'
    image_size = 96
    batch_size = 32
    epochs = 200
    lr1 = 2e-3
    lr2 = 2e-4
    beta1 = 0.5
    gpu = False
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    nz = 100
    ngf = 64
    ndf = 64
    save_path = './gandata/images'
    generator_path = './gandata/generator.pkl' 			#模型保存路径
    discriminator_path = './gandata/discriminator.pkl' 	#模型保存路径
    gen_img = './gandata/result.png'
    gen_num = 64
    gen_search_num = 5000
    gen_mean = 0
    gen_std = 1

config = Config()

# 1.数据转换
data_transform = transforms.Compose([
    transforms.Resize(config.image_size),
    transforms.CenterCrop(config.image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

# 2.形成训练集
train_dataset = datasets.ImageFolder(root=os.path.join(config.data_path),
                                     transform=data_transform)

# 3.形成迭代器
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           config.batch_size,
                                           True,
                                           drop_last=True)
print('using {} images for training.'.format(len(train_dataset)))

class Generator(nn.Module):
    def __init__(self, config):
        super().__init__()

        ngf = config.ngf

        self.model = nn.Sequential(
            nn.ConvTranspose2d(config.nz, ngf * 8, 4, 1, 0),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf, 3, 5, 3, 1),
            nn.Tanh()
        )

    def forward(self, x):
        output = self.model(x)
        return output


class Discriminator(nn.Module):
    def __init__(self, config):
        super().__init__()

        ndf = config.ndf

        self.model = nn.Sequential(
            nn.Conv2d(3, ndf, 5, 3, 1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ndf, ndf * 2, 4, 2, 1),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ndf * 8, 1, 4, 1, 0)
        )

    def forward(self, x):
        output = self.model(x)
        return output.view(-1)

generator = Generator(config)
discriminator = Discriminator(config)

optimizer_generator = torch.optim.Adam(generator.parameters(),
                                       config.lr1,
                                       betas=(config.beta1, 0.999))
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(),
                                           config.lr2,
                                           betas=(config.beta1, 0.999))

true_labels = torch.ones(config.batch_size)
fake_labels = torch.zeros(config.batch_size)
fix_noises = torch.randn(config.batch_size, config.nz, 1, 1)
noises = torch.randn(config.batch_size, config.nz, 1, 1)

for epoch in range(config.epochs):
    for ii, (img, _) in tqdm(enumerate(train_loader)):
        real_img = img.to(config.device)

        if ii % 2 == 0:
            optimizer_discriminator.zero_grad()

            r_preds = discriminator(real_img)
            noises.data.copy_(torch.randn(config.batch_size, config.nz, 1, 1))
            fake_img = generator(noises).detach()
            f_preds = discriminator(fake_img)

            r_f_diff = (r_preds - f_preds.mean()).clamp(max=1)
            f_r_diff = (f_preds - r_preds.mean()).clamp(min=-1)
            loss_d_real = (1 - r_f_diff).mean()
            loss_d_fake = (1 + f_r_diff).mean()
            loss_d = loss_d_real + loss_d_fake

            loss_d.backward()
            optimizer_discriminator.step()

        else:
            optimizer_generator.zero_grad()
            noises.data.copy_(torch.randn(config.batch_size, config.nz, 1, 1))
            fake_img = generator(noises)
            f_preds = discriminator(fake_img)
            r_preds = discriminator(real_img)
            r_f_diff = r_preds - torch.mean(f_preds)
            f_r_diff = f_preds - torch.mean(r_preds)
            loss_g = torch.mean(F.relu(1 + r_f_diff)) \
                     + torch.mean(F.relu(1 - f_r_diff))
            loss_g.backward()
            optimizer_generator.step()

    if epoch == config.epochs - 1:
        # 保存模型
        torch.save(discriminator.state_dict(), config.discriminator_path)
        torch.save(generator.state_dict(), config.generator_path)

print('Finished Training')

generator = Generator(config)
discriminator = Discriminator(config)

noises = torch.randn(config.gen_search_num,
                     config.nz, 1, 1).normal_(config.gen_mean,
                                                                     config.gen_std)
noises = noises.to(config.device)

generator.load_state_dict(torch.load(config.generator_path,
                                     map_location='cpu'))
discriminator.load_state_dict(torch.load(config.discriminator_path,
                                         map_location='cpu'))
generator.to(config.device)
discriminator.to(config.device)

fake_img = generator(noises)
scores = discriminator(fake_img).detach()

indexs = scores.topk(config.gen_num)[1]
result = []
for ii in indexs:
    result.append(fake_img.data[ii])

torchvision.utils.save_image(torch.stack(result), config.gen_img,
                             normalize=True, value_range=(-1, 1))

代码运行结果如下:

using 900 images for training.
28it [00:20,  1.40it/s]
28it [00:20,  1.33it/s]
28it [00:21,  1.29it/s]
…
28it [00:26,  1.06it/s]
Finished Training

效果图如图13-9所示,由于只训练了100个Epoch,因此图像生成的纹理还不算太清楚,大家计算资源允许的话,可以多训练一些Epoch来生成更多的图像细节。

图13-9

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1973746.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

算法板子:BFS(广度优先搜索)——迷宫问题,求从迷宫的起点到终点的最短路径; 八数码问题,求从初始布局到最终布局x最少移动多少次

目录 1. 核心思想在于bfs函数2. 代码中用到的数组的含义解释3. 迷宫问题(1)求从(0,0)点到(4,4)点的最短路径是多少——bfs函数(2)打印最短路径——在bfs函数的基础上多了一个print函数a. 思想b. 代码 4. 八数码问题——bfs函数 1.…

科普文:微服务之Spring Cloud Alibaba消息队列组件RocketMQ工作原理

概叙 本文探讨 RocketMQ 的事务消息原理,并从源码角度进行分析,以及事务消息适合什么场景,使用事务消息需要注意哪些事项。 同时详细介绍RocketMQ 事务消息的基本流程,并通过源码分析揭示了其内部实现原理,尽管事务消…

【avue+vue2+elementui】删除、rules、页面跳转、列表数据过长、日期dayjs

这里写目录标题 一、删除二、rules三、页面跳转四、列表数据过长截断五、日期 dayjs一、删除 🍃API/*** 删除.* @param {*} data * @returns 返参*/ export const deleteOrder = (data) => {return request({url: /api/Order/deleteOrder,method: post,data}) }HTML🍃左…

常见病症之中医药草一枝黄花

常见病症之中医药草一枝黄花 1. 源由2. 一枝黄花植物描述药用部分主要成分药理作用使用方法注意事项 3. 常用方剂3.1 一枝黄花汤3.2 一枝黄花解毒汤 4. 着凉感冒主要方剂加味处方使用方法注意事项 5. 补充资料 1. 源由 注:仅供参考,建议在中医师指导下使…

Unity【入门】小项目坦克大战

文章目录 1、开始场景1、场景装饰RotateObj 2、开始界面BasePanelBeginPanel 3、设置界面GameDataMgrSettingPanel 4、音效数据逻辑MusicData 5、排行榜界面RankPanel 6、排行榜数据逻辑RankInfo 7、背景音乐BKMusic 2、游戏场景1、游戏界面GamePanel 2、基础场景搭建CubeObjQu…

如何使用极狐GitLab CI/CD Component Catalog?【上】

极狐GitLab 是 GitLab 在中国的发行版,专门面向中国程序员和企业提供企业级一体化 DevOps 平台,用来帮助用户实现需求管理、源代码托管、CI/CD、安全合规,而且所有的操作都是在一个平台上进行,省事省心省钱。可以一键安装极狐GitL…

SQL进阶技巧:Hive如何巧解和差计算的递归问题?【应用案例2】

目录 0 问题描述 1 数据准备 2 问题分析 3 小结 0 问题描述 有如下数据:反应了每月的页面浏览量 现需要按照如下规则计算每月的累计阅读量,具体计算规则如下: 最终结果如下: 1 数据准备 with data as( select 2024-01 as month ,2 as pv union all select 2024-02 …

使用MongoDB构建AI:Jina AI将突破性开源嵌入模型变为现实

Jina AI创立于2020年,总部位于德国柏林,主要从事提示工程和嵌入模型业务,已迅速成长为多模态AI领导者。Jina AI积极推动开源和开放研究,致力于弥合先进AI理论与开发者及数据科学家构建的AI驱动型真实世界应用程序之间的差距。目前…

卷积神经网络 - 池化(Pooling)篇

序言 在深度学习的广阔领域中,卷积神经网络( CNN \text{CNN} CNN)以其卓越的特征提取能力,在图像识别、视频处理及自然语言处理等多个领域展现出非凡的潜力。而池化( Pooling \text{Pooling} Pooling)作为…

智慧水务项目(四)django(drf)+angular 18 配置REST_FRAMEWORK

一、说明 建立了几个文件 二、一步一步来 1、建立json_response.py 继承了 Response, 一共三个函数,成功、详情,错误 from rest_framework.response import Responseclass SuccessResponse(Response):"""标准响应成功的返回…

springboot共享汽车租赁管理系统-计算机毕业设计源码99204

目 录 第 1 章 引 言 1.1 选题背景及意义 1.2 研究前期调研 1.3 论文结构安排 第 2 章 系统的需求分析 2.1 系统可行性分析 2.1.1 技术方面可行性分析 2.1.2 经济方面可行性分析 2.1.3 法律方面可行性分析 2.1.4 操作方面可行性分析 2.2 系统功能需求分析 2.3 系统…

机械学习—零基础学习日志(高数18——无穷小与无穷大)

零基础为了学人工智能,真的开始复习高数 学习速度加快! 无穷小定义 这里可以记住,无穷小有一个特殊,那就是零。 零是最高阶的无穷小,且零是唯一一个常数无穷小。 张宇老师还是使用了超实数概念来讲解无穷小。其实是…

《马拉松名将手记:42.195公里的孤独之旅》大迫杰之舞

《马拉松名将手记:42.195公里的孤独之旅》大迫杰之舞 大迫杰,日本田径长跑选手。2020年3月1日,在东京马拉松赛上,以2小时5分29秒获得日本本土冠军,刷新自己保持的日本国家记录,并拿下东京奥运会马拉松项目入…

UE5 从零开始制作跟随的鸭子

文章目录 二、绑定骨骼三、创建 ControlRig四、创建动画五、创建动画蓝图六、自动寻路七、生成 goose八、碰撞 和 Physics Asset缺点 # 一、下载模型 首先我们需要下载一个静态网格体,这里我们可以从 Sketchfab 中下载:Goose Low Poly - Download Free …

黑暗之魂和艾尔登法环有什么联系吗 黑暗之魂和艾尔登法环哪一个好玩 苹果电脑怎么玩Windows游戏 apple电脑可以玩游戏吗

有不少游戏爱好者对于艾尔登法环与经典游戏黑魂之间是否存在关系产生了疑问。在新旧元素的融合中,艾尔登法环注定成为一场别具匠心的冒险之旅。在实机演示中类魂的玩法以及黑魂相似的画风让不少玩家想要了解本作与黑魂是否有联系,今天,我们将…

SAP生产版本维护以及注意事项

事务代码:C223/MM02,CS01,CA01 步骤: CS01创建BOM CA01创建工艺路线 C223/MM02创建生产版本 选择BOM清单 注意: 1、该生产版本与BOM清单绑定在一起,后续如果BOM有多个,需要添加或修改这个生产版本 2、…

进化版:一个C++模板工厂的编译问题的解决。针对第三方库的构造函数以及追加了的对象构造函数。牵扯到重载、特化等

原始版本在这里 一个C模板工厂的编译问题的解决。针对第三方库的构造函数以及追加了的对象构造函数。牵扯到重载、特化等-CSDN博客 问题 1、关于类型的判断&#xff0c;适应性不强 比如template <typename T>IsFarElementId<>&#xff0c;目前只能判断FarElemen…

达梦数据库的系统视图v$cacheitem

达梦数据库的系统视图v$cacheitem 达梦数据库的系统视图V$CACHEITEM的作用是显示缓存中的项信息&#xff0c;在 ini 参数 USE_PLN_POOL !0 时才统计。这个视图帮助数据库管理员监控和分析缓存的使用情况&#xff0c;优化数据库性能。通过查询V$CACHEITEM视图&#xff0c;可以获…

ai web 1.0靶机漏洞渗透详解

一、导入靶机 解压下载好的靶机&#xff0c;然后打开VMware&#xff0c;点击文件》打开》找到刚刚解压的靶机点击下面的文件》打开 确认是靶机的网络连接模式是NAT模式 二、信息收集 1、主机发现 在本机的命令窗口输入ipconfig查看VMnet8这块网卡&#xff0c;这块网卡就是虚…

案例分享—国外优秀ui设计作品赏析

国外UI设计创意迭出&#xff0c;融合多元文化元素&#xff0c;以极简风搭配动态交互&#xff0c;打造沉浸式体验&#xff0c;色彩运用大胆前卫&#xff0c;引领界面设计新风尚 同时注重用户体验的深度挖掘&#xff0c;通过个性化定制与智能算法结合&#xff0c;让界面不仅美观且…