人工智能(pytorch)搭建模型11-pytorch搭建DCGAN模型,一种生成对抗网络GAN的变体实际应用

news2024/9/20 20:25:12

大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型11-pytorch搭建DCGAN模型,一种生成对抗网络GAN的变体实际应用,本文将具体介绍DCGAN模型的原理,并使用PyTorch搭建一个简单的DCGAN模型。我们将提供模型代码,并使用一些数据样例进行训练和测试。最后,我们将展示训练过程中的损失值和准确率。

文章目录:

  1. DCGAN模型简介
  2. DCGAN模型原理
  3. 使用PyTorch搭建DCGAN模型
  4. 数据样例
  5. 训练模型
  6. 测试模型
  7. 总结

1. DCGAN模型简介

DCGAN全称:Deep Convolutional Generative Adversarial Networks,它是一种生成对抗网络(GAN)的变体,它使用卷积神经网络(CNN)作为生成器和判别器。DCGAN在图像生成任务中表现出色,能够生成具有高分辨率和清晰度的图像。

2. DCGAN模型原理

DCGAN模型由两个部分组成:生成器(Generator)和判别器(Discriminator)。生成器负责生成图像,而判别器负责判断图像是否为真实图像。在训练过程中,生成器和判别器相互竞争,生成器试图生成越来越逼真的图像,而判别器试图更准确地识别生成的图像是否为真实图像。这个过程持续进行,直到生成器生成的图像足够逼真,以至于判别器无法区分生成的图像和真实图像。

DCGAN模型的数学原理表示:

生成器(Generator):

G ( z ) = x G(z) = x G(z)=x

其中, z z z是输入的随机噪声向量, x x x是生成的图像。

判别器(Discriminator):

D ( x ) = y D(x) = y D(x)=y

其中, x x x是输入的图像, y y y是判别器对图像的判断结果,表示图像是否为真实图像。

GAN的损失函数:

min ⁡ G max ⁡ D V ( D , G ) = E x ∼ p d a t a ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \min_G \max_D V(D,G) = \mathbb{E}{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1-D(G(z)))] GminDmaxV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]

其中, p d a t a ( x ) p_{data}(x) pdata(x)表示真实数据的分, p z ( z ) p_z(z) pz(z)表示噪声向量的分布, D ( x ) D(x) D(x)表示判别器对图像 x x x的判断结果, G ( z ) G(z) G(z)表示生成器生成的图像, log ⁡ D ( x ) \log D(x) logD(x)表示判别器将真实图像判断为真实图像的概率, log ⁡ ( 1 − D ( G ( z ) ) ) \log(1-D(G(z))) log(1D(G(z)))表示判别器将生成图像判断为真实图像的概率。

在这里插入图片描述

3. 使用PyTorch搭建DCGAN模型

首先,我们需要导入所需的库:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as dset
from torch.autograd import Variable

接下来,我们定义生成器和判别器的网络结构:

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # 输入是一个100维的向量
            nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # 输出为(512, 4, 4)
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # 输出为(256, 8, 8)
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # 输出为(128, 16, 16)
            nn.ConvTranspose2d(128, 3, 4, 2, 1, bias=False),
            nn.Tanh()
            # 输出为(3, 32, 32)
        )

    def forward(self, input):
        return self.main(input)

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # 输入为(3, 32, 32)
            nn.Conv2d(3, 128, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出为(128, 16, 16)
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出为(256, 8, 8)
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出为(512, 4, 4)
            nn.Conv2d(512, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input).view(-1)

4. 数据样例

我们将使用CIFAR-10数据集进行训练。首先,我们需要对数据进行预处理:

if __name__ =="__main__":
    transform = transforms.Compose([
        transforms.Resize(32),
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    
    trainset = dset.CIFAR10(root='./data', train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)

5. 训练模型

接下来,我们将训练DCGAN模型:

# 初始化生成器和判别器
netG = Generator()
netD = Discriminator()

# 设置损失函数和优化器
criterion = nn.BCELoss()
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))

# 训练模型
num_epochs = 10

for epoch in range(num_epochs):
    for i, data in enumerate(trainloader, 0):
        # 更新判别器
        netD.zero_grad()
        real, _ = data
        batch_size = real.size(0)
        label = torch.full((batch_size,), 1)
        output = netD(real)
        errD_real = criterion(output, label)
        errD_real.backward()
        noise = torch.randn(batch_size, 100, 1, 1)
        fake = netG(noise)
        label.fill_(0)
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        errD_fake.backward()
        errD = errD_real + errD_fake
        optimizerD.step()

        # 更新生成器
        netG.zero_grad()
        label.fill_(1)
        output = netD(fake)
        errG = criterion(output, label)
        errG.backward()
        optimizerG.step()

        if i%5==0:
           # 打印损失值
           print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f' % (epoch, num_epochs, i, len(trainloader), errD.item(), errG.item()))

6. 测试模型

训练完成后,我们可以使用生成器生成一些图像进行测试:

import matplotlib.pyplot as plt
import numpy as np

def imshow(img):
    img = img / 2 + 0.5
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

noise = torch.randn(64, 100, 1, 1)
fake = netG(noise)
imshow(torchvision.utils.make_grid(fake.detach()))

7. 总结

本文详细介绍了DCGAN模型的原理,并使用PyTorch搭建了一个简单的DCGAN模型。我们提供了模型代码,并使用CIFAR-10数据集进行训练和测试。最后,我们展示了训练过程中的损失值和生成的图像。希望本文能帮助您更好地理解DCGAN模型,并在实际项目中应用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/638853.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

java+openlayer实现大气污染扩散模拟反演

一、模拟参数及效果 二、应用背景 大气污染是当今社会面临的一个重要问题。随着工业化和城市化的进程,大气污染问题变得越来越严重。为了更好地应对这个问题,许多科学家和研究人员开始探索大气污染扩散反演技术。 大气污染扩散反演技术是一种通过数学模…

给软件测试人的一封信,全网最佳“指路明灯“

一、一招鲜吃遍天下 你需要有一个核心技能。这个技能至少达到远超你的同事(包括开发岗位的同事的)平均水平。最好达到业界领先水平,且这个核心技能需要不断打磨提高。比如,我选择的核心技能是使用Python写代码。这个核心技能可以…

3.2 基于Java配置类整合SSM框架实现用户登录

一、基于Java配置类整合SSM框架实现用户登录 1、创建Maven项目 Maven项目 - SSMLoginNew 单击【Finish】按钮 2、添加相关依赖 在pom.xml文件里添加相关依赖 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache…

Kubernetes 1.27 加快 Pod 启动速度

如何在大型集群中加快节点上的 Pod 启动&#xff1f;这是企业集群管理员常常会面临的问题。 这篇博文重点介绍了从 kubelet 一侧加快 Pod 启动的方法。此方法不涉及通过 kube-apiserver 由 controller-manager 创建 Pod 所用的时间段&#xff0c;也不包含 Pod 的调度时间或在其…

电脑最牛逼的截图方式

1.电脑桌面上空白的地方新建一个文本文档&#xff0c;将后缀名修改为bat&#xff0c;截图如下&#xff1a; 2.右键点击该文档编辑&#xff0c;在编辑界面输入start snippingtool&#xff0c;点击保存之后关闭该文档。 3.双击该文档&#xff0c;在模式里面选择响应的截图方式即可…

MySQL IDE与pymysql模块

一、IDE工具介绍 生产环境还是推荐使用mysql命令行&#xff0c;但为了方便我们测试&#xff0c;可以使用IDE工具 在此我们推荐使用Navicat软件或pycharm来连接数据库,这样就能更详细直观地查询数据 掌握&#xff1a; #1. 测试链接数据库 #2. 新建库 #3. 新建表&#xff0c;新增…

2023 年程序员高考试卷!你能答对几个?

又是一年高考季&#xff0c;一起来做做“程序员们的高考试卷”&#xff0c;压压惊吧~ 2023年普通高等学校招生全国统一考试 程序员的高考试卷&#xff08;A卷&#xff09; 考生类别&#xff1a;码农 1、程序员A&#xff1a;借我1000元吧。 程序员B&#xff1a;给你凑个整数…

Linux基础知识点2

Linux基础知识 适合有Linux基础的人群进行复习。 禁止转载&#xff01; 文件管理与常用命令 Linux的文件的组成部分&#xff1a; 文件名、inode(i节点)和block(真正存数据的区域)。 查看某个文件的属性&#xff1a; ls -lh #可看到有类似”-rw-r--r--”的属性符号 …

轻松来自实力,亚马逊云科技助力边界智能应对业务高峰值数据考验

边界智能&#xff08;Bianjie.AI&#xff09;是2016年创立于上海的国家高新技术企业和专精特新企业&#xff0c;同时也是以香港为全球总部、服务全球的区块链技术创新团队。公司专注于区块链技术支持的下一代互联网应用服务&#xff0c;自主研发了跨多条联盟链的分布式应用服务…

STL入门 + 刷题(下)

&#x1f442; Raindrops (Intl. Version) - Katja Krasavice/Leony - 单曲 - 网易云音乐 &#x1f442; Rush E (Playable Version) - Sheet Music Boss - 单曲 - 网易云音乐 &#x1f442; 最美的瞬间 - 真瑞 - 单曲 - 网易云音乐 &#x1f442; 你可别卷了 - SipSu小口酥…

CTF Crypto --- orz!

文章目录 题目解题过程 题目 from Crypto.Util.number import * from gmpy2 import *flag bxxx t len(flag)//3 part1 bytes_to_long(flag[:t]) part2 bytes_to_long(flag[t:2*t]) part3 bytes_to_long(flag[2*t:]) q getPrime(1024) p next_prime(q) n p * qo getPr…

面试必备,29个Java面试必考点、1000多道Java面试题

马上金九银十招聘旺季就到了&#xff0c;不知道大家是否准备好了&#xff0c;面对金九银十的招聘旺季&#xff0c;如果没有精心准备那笔者认为那是对自己不负责任&#xff1b;就我们Java程序员来说&#xff0c;多数的公司总体上面试都是以自我介绍项目介绍项目细节/难点提问基础…

文献下载神器:文献党下载器使用方法

文献党下载器是一款文献资源整合平台&#xff0c;把知网、万方、维普、超星/读秀、Web of Science、Elsevier&#xff08;ScienceDirect&#xff09;、Wiley 、SpringerLink、EI&#xff08;工程索引&#xff09;、IEEE&#xff08;电气电子工程师学会&#xff09;、Taylor &am…

ESP32-C3系列模组简介

ESP32-C3是一款安全稳定、低功耗、低成本的物联网芯片&#xff0c;搭载RISC-V 32位单核处理器&#xff0c;为物联网产品提供行业领先的射频性能、完善的安全机制和丰富的内存资源。 嵌入式智能终端、无线WIFI技术以及Internet的广泛应用必将使家居控制变得更加自动化、智能化和…

基础软件加速自主创新,openGauss成就业务“新箭头”

不久前&#xff0c;想必业界都注意到了MetaERP横空出世的消息。作为企业经营的核心系统&#xff0c;MetaERP突破外部封锁&#xff0c;实现完全自研替代&#xff0c;是华为有史以来牵涉面较广、复杂性较高的项目。这其实是国产基础软件迅速崛起的一个缩影。 基础软件产业是关系…

element-plus vue 错误汇总

input 无法输入0.01 element ui input 无法输入0.01 一输出0.0就报错&#xff0c;是因为写成了v-model.number&#xff0c;改成v-model即可。 <el-input v-model.number"formData.reduceMoney"class"input200"type"number"focus"discou…

C#开发的OpenRA游戏之建造物品的窗口1

C#开发的OpenRA游戏之建造物品的窗口1 前面已经分析了基地工程车的创建和移动,当玩家把基地工程车移动到合适的位置,就会进行部署基地,也即是选择一个离矿场比较近的位置,因为这样做可以提高采矿的速度,减少采矿车的运输时间。 接着下来,虽然基地是建立了,但是还需要创…

HybridCLR 最佳实践,老项目集成热更(战棋项目)

文本介绍了老项目使用HybridCLR 集成热更的过程 从项目结构调整&#xff0c;代码调整&#xff0c;打包&#xff0c;热更测试&#xff0c;跑完HybridCLR所有流程 先看效果&#xff08;安卓&#xff09; 源码及资料领取方式私信&#xff1a;领取资料&#xff1a;HybridCLR战棋热更…

2023亚马逊云科技中国峰会引领无服务器架构新潮流:Serverlesspresso Workshop

序言 在今年3月&#xff0c;我有幸接触了一个项目&#xff0c;也因此结识了 亚马逊云科技无服务器架构 Serverless。在陆续了解 Amazon 产品的过程中&#xff0c;我逐渐发现它所带给我的惊喜远远超出了最初的预期。 今天&#xff0c;想向大家介绍一个名为 Serverlesspresso Wor…

京东数据分析:2023年Q1京东奶粉品牌销量排行榜

近几年我国新生人口数量不断下降。尽管国家大力推进多胎政策&#xff0c;但奶粉的市场需求量依然有明显下滑&#xff0c;导致国内奶粉行业的发展低迷&#xff0c;今年Q1依然没有回弹的迹象。 根据鲸参谋数据显示&#xff0c;今年Q1奶粉在京东平台销量2000万件&#xff0c;同比下…