人工智能算法工程师(高级)课程5-图像生成项目之对抗生成模型与代码详解

news2024/11/23 10:29:07

大家好,我是微学AI,今天给大家介绍一下人工智能算法工程师(高级)课程5-图像生成项目之对抗生成模型与代码详解。本文将介绍对抗生成模型(GAN)及其变体CGAN、DCGAN的数学原理,并通过PyTorch框架搭建完整可运行的代码,帮助读者掌握图像生成的原理和技术。

文章目录

  • 一、GAN模型
    • 1. GAN模型的数学原理
    • 2. GAN模型的代码实现
  • 二、CGAN模型
    • 1. CGAN模型的数学原理
    • 2. CGAN模型的代码实现
  • 三、DCGAN模型
    • 1. DCGAN模型的数学原理
    • 2. DCGAN模型的代码实现
  • 四、总结

一、GAN模型

1. GAN模型的数学原理

生成对抗网络(GAN)由Goodfellow等人在2014年提出,主要由生成器(Generator)和判别器(Discriminator)两部分组成。生成器的任务是生成尽可能接近真实数据的样本,而判别器的任务是将生成器生成的样本与真实样本区分开来。
GAN的目标函数如下:
V ( D , G ) = E x ∼ p d a t a ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_{z}(z)}[\log(1-D(G(z)))] V(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]
其中, D ( x ) D(x) D(x)表示判别器判断 x x x为真实样本的概率, G ( z ) G(z) G(z)表示生成器生成的样本, z z z为随机噪声。
在这里插入图片描述

2. GAN模型的代码实现

首先,导入所需库:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

定义生成器:

class Generator(nn.Module):
    def __init__(self, z_dim, img_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, img_dim),
            nn.Tanh()
        )
    def forward(self, z):
        return self.model(z)

定义判别器:

class Discriminator(nn.Module):
    def __init__(self, img_dim):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(img_dim, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    def forward(self, x):
        return self.model(x)

训练GAN模型:

# 设定超参数
z_dim = 100
img_dim = 28 * 28
batch_size = 64
lr = 0.0002
epochs = 50
# 加载数据集
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
# 初始化生成器和判别器
G = Generator(z_dim, img_dim)
D = Discriminator(img_dim)
# 定义优化器
optimizer_G = optim.Adam(G.parameters(), lr=lr)
optimizer_D = optim.Adam(D.parameters(), lr=lr)
# 训练过程
for epoch in range(epochs):
    for i, (imgs, _) in enumerate(dataloader):
        # 训练判别器
        real_imgs = imgs.view(-1, img_dim)
        z = torch.randn(batch_size, z_dim)
        fake_imgs = G(z)
        D_real = D(real_imgs)
        D_fake = D(fake_imgs)
        D_loss = -torch.mean(torch.log(D_real) + torch.log(1 - D_fake))
        optimizer_D.zero_grad()
        D_loss.backward()
        optimizer_D.step()
        # 训练生成器
        z = torch.randn(batch_size, z_dim)
        fake_imgs = G(z)
        D_fake = D(fake_imgs)
        G_loss = -torch.mean(torch.log(D_fake))
        optimizer_G.zero_grad()
        G_loss.backward()
        optimizer_G.step()
        if (i + 1) % 100 == 0:
            print(f'Epoch [{epoch + 1}/{epochs}], Step [{i + 1}/{len(dataloader)}], D_loss: {D_loss.item():.4f}, G_loss: {G_loss.item():.4f}')

二、CGAN模型

1. CGAN模型的数学原理

条件生成对抗网络(CGAN)在GAN的基础上,为生成器和判别器引入了额外的条件信息 y y y。CGAN的目标函数如下:

V ( D , G ) = E x ∼ p d a t a ( x ) [ log ⁡ D ( x ∣ y ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ∣ y ) ) ) ] V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x|y)] + \mathbb{E}_{z \sim p_{z}(z)}[\log(1-D(G(z|y)))] V(D,G)=Expdata(x)[logD(xy)]+Ezpz(z)[log(1D(G(zy)))]
其中, D ( x ∣ y ) D(x|y) D(xy)表示在给定条件 y y y的情况下,判别器判断 x x x为真实样本的概率, G ( z ∣ y ) G(z|y) G(zy)表示在给定条件 y y y的情况下,生成器生成的样本。

在这里插入图片描述

2. CGAN模型的代码实现

CGAN的生成器和判别器需要在原有的基础上加入条件信息 y y y。以下是CGAN的代码实现:

class CGenerator(nn.Module):
    def __init__(self, z_dim, condition_dim, img_dim):
        super(CGenerator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(z_dim + condition_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, img_dim),
            nn.Tanh()
        )
    def forward(self, z, y):
        z_y = torch.cat([z, y], 1)
        return self.model(z_y)
class CDiscriminator(nn.Module):
    def __init__(self, img_dim, condition_dim):
        super(CDiscriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(img_dim + condition_dim, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    def forward(self, x, y):
        x_y = torch.cat([x, y], 1)
        return self.model(x_y)

训练CGAN模型时,需要将条件信息 y y y传递给生成器和判别器:

# 假设条件信息y的维度为10
condition_dim = 10
# 初始化条件生成器和条件判别器
CG = CGenerator(z_dim, condition_dim, img_dim)
CD = CDiscriminator(img_dim, condition_dim)
# 定义优化器
optimizer.CG = optim.Adam(CG.parameters(), lr=lr)
optimizer.CD = optim.Adam(CD.parameters(), lr=lr)
# 训练过程
for epoch in range(epochs):
    for i, (imgs, labels) in enumerate(dataloader):
        # 将标签转换为one-hot编码
        y = torch.nn.functional.one_hot(labels, num_classes=condition_dim).float()
        real_imgs = imgs.view(-1, img_dim)
        z = torch.randn(batch_size, z_dim)
        fake_imgs = CG(z, y)
        CD_real = CD(real_imgs, y)
        CD_fake = CD(fake_imgs, y)
        CD_loss = -torch.mean(torch.log(CD_real) + torch.log(1 - CD_fake))
        optimizer.CD.zero_grad()
        CD_loss.backward()
        optimizer.CD.step()
        z = torch.randn(batch_size, z_dim)
        fake_imgs = CG(z, y)
        CD_fake = CD(fake_imgs, y)
        CG_loss = -torch.mean(torch.log(CD_fake))
        optimizer.CG.zero_grad()
        CG_loss.backward()
        optimizer.CG.step()
        if (i + 1) % 100 == 0:
            print(f'Epoch [{epoch + 1}/{epochs}], Step [{i + 1}/{len(dataloader)}], CD_loss: {CD_loss.item():.4f}, CG_loss: {CG_loss.item():.4f}')

三、DCGAN模型

1. DCGAN模型的数学原理

深度卷积生成对抗网络(DCGAN)是GAN的一种变体,它将卷积神经网络(CNN)应用于生成器和判别器。DCGAN的目标函数与GAN相同,但其网络结构有所不同,使得生成器和判别器能够更好地处理图像数据。
在这里插入图片描述

2. DCGAN模型的代码实现

以下是DCGAN的生成器和判别器的代码实现:

class DCGenerator(nn.Module):
    def __init__(self, z_dim, img_channels):
        super(DCGenerator, self).__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(z_dim, 256, 4, 1, 0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, img_channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )
    def forward(self, z):
        z = z.view(z.size(0), z.size(1), 1, 1)
        return self.model(z)
class DCDiscriminator(nn.Module):
    def __init__(self, img_channels):
        super(DCDiscriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(img_channels, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )
    def forward(self, img):
        return self.model(img).view(img.size(0), -1)
# 设定超参数
z_dim = 100
img_channels = 1
img_size = 64
batch_size = 64
lr = 0.0002
epochs = 50
# 加载数据集
transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
# 初始化生成器和判别器
DG = DCGenerator(z_dim, img_channels)
DD = DCDiscriminator(img_channels)
# 定义优化器
optimizer_DG = optim.Adam(DG.parameters(), lr=lr)
optimizer_DD = optim.Adam(DD.parameters(), lr=lr)
# 训练过程
for epoch in range(epochs):
    for i, (imgs, _) in enumerate(dataloader):
        # 训练判别器
        real_imgs = imgs.view(-1, img_channels, img_size, img_size)
        z = torch.randn(batch_size, z_dim, 1, 1)
        fake_imgs = DG(z)
        DD_real = DD(real_imgs)
        DD_fake = DD(fake_imgs)
        DD_loss = -torch.mean(torch.log(DD_real) + torch.log(1 - DD_fake))
        optimizer_DD.zero_grad()
        DD_loss.backward()
        optimizer_DD.step()
        # 训练生成器
        z = torch.randn(batch_size, z_dim, 1, 1)
        fake_imgs = DG(z)
        DD_fake = DD(fake_imgs)
        DG_loss = -torch.mean(torch.log(DD_fake))
        optimizer_DG.zero_grad()
        DG_loss.backward()
        optimizer_DG.step()
        if (i + 1) % 100 == 0:
            print(f'Epoch [{epoch + 1}/{epochs}], Step [{i + 1}/{len(dataloader)}], DD_loss: {DD_loss.item():.4f}, DG_loss: {DG_loss.item():.4f}')

在上述代码中,DCGenerator使用了多个ConvTranspose2d层来逐步增加图像的尺寸,而DCDiscriminator则使用了多个Conv2d层来逐步减小图像的尺寸。每个卷积层后面都跟有批量归一化(BatchNorm)和激活函数(ReLU或LeakyReLU),这些是DCGAN的关键组成部分,有助于稳定训练过程。

四、总结

本文介绍了GAN、CGAN和DCGAN的数学原理,并通过PyTorch框架提供了完整的代码实现。通过这些代码,读者可以深入了解对抗生成模型的训练过程,并掌握图像生成的原理和技术。在实际应用中,这些模型可以用于图像合成、风格迁移、数据增强等多个领域。需要注意的是,训练GAN模型可能会遇到稳定性问题,因此在实际操作中可能需要调整超参数和模型结构以获得最佳效果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1942408.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Android 15 之如何快速适配 16K Page Size

在此之前,我们通过 《Android 15 上 16K Page Size 为什么是最坑》 介绍了: 什么是16K Page Size为什么它对于 Android 很坑如何测试 如果你还没了解,建议先去了解下前文,然后本篇主要是提供适配的思路,因为这类适配…

算法——滑动窗口(day7)

904.水果成篮 904. 水果成篮 - 力扣(LeetCode) 题目解析: 根据题意我们可以看出给了我们两个篮子说明我们在开始采摘到结束的过程中只能有两种水果的种类,又要求让我们返回收集水果的最大数目,这不难让我们联想到题目…

Java 面试相关问题(中)——并发编程相关问题

这里只会写Java相关的问题,包括Java基础问题、JVM问题、线程问题等。全文所使用图片,部分是自己画的,部分是自己百度的。如果发现雷同图片,联系作者,侵权立删。 1 基础问题1.1 什么是并发,什么是并行&#…

Python爬虫知识体系-----Urllib库的使用

数据科学、数据分析、人工智能必备知识汇总-----Python爬虫-----持续更新:https://blog.csdn.net/grd_java/article/details/140574349 文章目录 1. 基本使用2. 请求对象的定制3. 编解码1. get请求方式:urllib.parse.quote()2. ur…

数驭未来,景联文科技构建高质大模型数据库

国内应用层面的需求推动AI产业的加速发展。根据IDC数据预测,预计2026年中国人工智能软件及应用市场规模会达到211亿美元。 数据、算法、算力是AI发展的驱动力,其中数据是AI发展的基石,中国的数据规模增长速度预期将领跑全球。 2024年《政府工…

【WAF剖析】10种XSS某狗waf绕过姿势,以及思路分析

原文:【WAF 剖析】10 种 XSS 绕过姿势,以及思路分析 xss基础教程参考:https://mp.weixin.qq.com/s/RJcOZuscU07BEPgK89LSrQ sql注入waf绕过文章参考: https://mp.weixin.qq.com/s/Dhtc-8I2lBp95cqSwr0YQw 复现 网站安全狗最新…

[数据集][目标检测]野猪检测数据集VOC+YOLO格式1000张1类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):1000 标注数量(xml文件个数):1000 标注数量(txt文件个数):1000 标注…

如何查看jvm资源占用情况

如何设置jar的内存 java -XX:MetaspaceSize256M -XX:MaxMetaspaceSize256M -XX:AlwaysPreTouch -XX:ReservedCodeCacheSize128m -XX:InitialCodeCacheSize128m -Xss512k -Xmx2g -Xms2g -XX:UseG1GC -XX:G1HeapRegionSize4M -jar your-application.jar以上配置为堆内存4G jar项…

Web前端:HTML篇(二)元素属性

HTML 属性 属性是 HTML 元素提供的附加信息。 HTML 元素可以设置属性属性可以在元素中添加附加信息属性一般描述于开始标签属性总是以名称/值对的形式出现&#xff0c;比如&#xff1a;name"value"。 属性实例 HTML 链接由 <a> 标签定义。链接的地址在 href …

如何开启或者关闭 Windows 安全登录?

什么是安全登录 什么是 Windows 安全登录呢&#xff1f;安全登录是 Windows 附加的一个组件&#xff0c;它可以在用户需要登录的之前先将登录界面隐藏&#xff0c;只有当用户按下 CtrlAltDelete 之后才出现登录屏幕&#xff0c;这样可以防止那些模拟登录界面的程序获取密码信息…

来聊聊redis集群数据迁移

写在文章开头 本文将是笔者对于redis源码分析的一个阶段的最后一篇&#xff0c;将从源码分析的角度让读者深入了解redis节点迁移的工作流程&#xff0c;希望对你有帮助。 Hi&#xff0c;我是 sharkChili &#xff0c;是个不断在硬核技术上作死的 java coder &#xff0c;是 CS…

JavaScript青少年简明教程:赋值语句

JavaScript青少年简明教程&#xff1a;赋值语句 赋值语句&#xff08;assignment statement&#xff09; JavaScript的赋值语句用于给变量、对象属性或数组元素赋值。赋值语句的基本语法是使用符号 () 将右侧的值&#xff08;称为“源操作数”&#xff09;赋给左侧的变量、属…

Docker Minio rclone数据迁移

docker minio进行数据迁移 使用rclone进行数据迁移是一种非常灵活且强大的方式&#xff0c;特别是在处理大规模数据集或跨云平台迁移时。rclone是一款开源的命令行工具&#xff0c;用于同步文件和目录到多种云存储服务&#xff0c;包括MinIO。下面是使用rclone进行数据迁移至Mi…

【RT摩拳擦掌】RT600 4路音频同步输入1路TDM输出方案

【RT摩拳擦掌】RT600 4路音频同步输入1路TDM输出方案 一&#xff0c; 文章简介二&#xff0c;硬件平台构建2.1 音频源板2.2 音频收发板2.3 双板硬件连接 三&#xff0c;软件方案与软件实现3.1 方案实现3.2 软件代码实现3.2.1 4路I2S接收3.2.2 I2S DMA pingpong配置3.2.3 音频数…

卧室激光投影仪推荐一下哪款效果最好?当贝X5S亮度卧室开灯照样清晰

现在家庭卧室装投影仪也不是什么稀奇的事情了&#xff0c;外面客厅看电视机&#xff0c;里面卧室投影仪直接投白墙各有各的优势。躺在卧室的床上&#xff0c;看超大屏投影真的很惬意。卧室投影的品类比较多&#xff0c;有些价格便宜的投影宣传说卧室看很适合&#xff0c;其实不…

设计模式12-构建器

设计模式12-构建器 由来和动机原理思想构建器模式的C代码实现构建器模式中的各个组件详解1. 产品类&#xff08;Product&#xff09;2. 构建类&#xff08;Builder&#xff09;3. 具体构建类&#xff08;ConcreteBuilder&#xff09;4. 指挥者类&#xff08;Director&#xff0…

实战:OpenFeign使用以及易踩坑说明

OpenFeign是SpringCloud中的重要组件&#xff0c;它是一种声明式的HTTP客户端。使用OpenFeign调用远程服务就像调用本地方法一样&#xff0c;但是如果使用不当&#xff0c;很容易踩到坑。 Feign 和OpenFeign Feign Feign是Spring Cloud组件中的一个轻量级RESTful的HTTP服务客…

rabbitmq生产与消费

一、rabbitmq发送消息 一、简单模式 概述 一个生产者一个消费者模型 代码 //没有交换机&#xff0c;两个参数为routingKey和消息内容 rabbitTemplate.convertAndSend("test1_Queue","haha");二、工作队列模式 概述 一个生产者&#xff0c;多个消费者&a…

C4D2024软件下载+自学C4D 从入门到精通【学习视频教程全集】+【素材笔记】

软件介绍与下载&#xff1a; 链接&#xff1a; 链接&#xff1a;https://pan.baidu.com/s/1n8cripcv6ZTx4TBNj5N04g?pwdhfg5 提取码&#xff1a;hfg5 基础命令的讲解&#xff1a; 掌握软件界面和基础操作界面。学习常用的基础命令&#xff0c;如建模、材质、灯光、摄像机…

设计模式-领域逻辑模式-结构映射模式

对象和关系之间的映射&#xff0c;关键问题在于二者处理连接的方式不同。 表现出两个问题&#xff1a; 表现方法不同。对象是通过在运行时&#xff08;内存管理环境或内存地址&#xff09;中保存引用的方式来处理连接的&#xff0c;关系数据库则通过创建到另外一个表的键值来处…