生成对抗网络——GAN深度卷积实现(代码+理解)

news2024/11/24 7:05:49

        本篇博客为 上篇博客的 另一个实现版本,训练流程相同,所以只实现代码,感兴趣可以跳转看一下。

  生成对抗网络—GAN(代码+理解)

http://t.csdnimg.cn/HDfLOicon-default.png?t=N7T8http://t.csdnimg.cn/HDfLO


目录

一、GAN深度卷积实现

1. 模型结构

(1)生成器(Generator)

(2)判别器(Discriminator)

2. 代码实现

3. 运行结果展示

二、学习中产生的疑问,及文心一言回答

1. 模型初始化

2. 模型训练时

3. 优化器定义

4. 训练数据

5. 模型结构

(1)生成器        

(2)判别器


一、GAN深度卷积实现

1. 模型结构

(1)生成器(Generator)

(2)判别器(Discriminator)

2. 代码实现

import torch
import torch.nn as nn
import argparse
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import datasets
import numpy as np


parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=20, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
opt = parser.parse_args()
print(opt)

# 加载数据
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "./others/",
        train=False,
        download=False,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02) # 给定均值和标准差的正态分布N(mean,std)中生成值
        torch.nn.init.constant_(m.bias.data, 0.0)

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.init_size = opt.img_size // 4  # 原为28*28,现为32*32,两边各多了2
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),    # 调整数据的分布,使其 更适合于 下一层的 激活函数或学习
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1),
                     nn.LeakyReLU(0.2, inplace=True),
                     nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(opt.channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )
        # 下采样(图片进行 4次卷积操作,变为ds_size * ds_size尺寸大小)
        ds_size = opt.img_size // 2 ** 4
        self.adv_layer = nn.Sequential(
            nn.Linear(128 * ds_size ** 2, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        return validity

# 实例化
generator = Generator()
discriminator = Discriminator()

# 初始化参数
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# 优化器
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

# 交叉熵损失函数
adversarial_loss = torch.nn.BCELoss()

def gen_img_plot(model, epoch, text_input):
    prediction = np.squeeze(model(text_input).detach().cpu().numpy()[:16])
    plt.figure(figsize=(4, 4))
    for i in range(16):
        plt.subplot(4, 4, i + 1)
        plt.imshow((prediction[i] + 1) / 2)
        plt.axis('off')
    plt.show()

# ----------
#  Training
# ----------
D_loss_ = []  # 记录训练过程中判别器的损失
G_loss_ = []  # 记录训练过程中生成器的损失
for epoch in range(opt.n_epochs):
    # 初始化损失值
    D_epoch_loss = 0
    G_epoch_loss = 0
    count = len(dataloader)  # 返回批次数
    for i, (imgs, _) in enumerate(dataloader):
        valid = torch.ones(imgs.shape[0], 1)
        fake = torch.zeros(imgs.shape[0], 1)

        # -----------------
        #  Train Generator
        # -----------------
        optimizer_G.zero_grad()
        z = torch.randn(imgs.shape[0], opt.latent_dim)
        gen_imgs = generator(z)
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
        )

        # batches_done = epoch * len(dataloader) + i
        # if batches_done % opt.sample_interval == 0:
        #     save_image(gen_imgs.data[:25], "others/images/%d.png" % batches_done, nrow=5, normalize=True)

        # 累计每一个批次的loss
        with torch.no_grad():
            D_epoch_loss += d_loss
            G_epoch_loss += g_loss

        # 求平均损失
    with torch.no_grad():
        D_epoch_loss /= count
        G_epoch_loss /= count
        D_loss_.append(D_epoch_loss.item())
        G_loss_.append(G_epoch_loss.item())

        text_input = torch.randn(opt.batch_size, opt.latent_dim)
        gen_img_plot(generator, epoch, text_input)


x = [epoch + 1 for epoch in range(opt.n_epochs)]
plt.figure()
plt.plot(x, G_loss_, 'r')
plt.plot(x, D_loss_, 'b')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['G_loss','D_loss'])
plt.show()

3. 运行结果展示

二、学习中产生的疑问,及文心一言回答

1. 模型初始化

        函数 weights_init_normal 用于初始化 模型参数,为什么要 以 均值和标准差 的正态分布中采样的数 为标准?

2. 模型训练时

        这里“d_loss = (real_loss + fake_loss) / 2” 中的 “/ 2” 操作,在 实际训练中 有什么作用?

        由(real_loss + fake_loss) / 2的 得到 的 d_loss 与(real_loss+fake_loss)得到的 d_loss 进行 回溯,两者结果会 有什么不同吗?

3. 优化器定义

        设置 betas=(opt.b1, opt.b2) 有什么 实际的作用?通俗易懂的讲一下

        betas=(opt.b1, opt.b2) 是怎样 更新学习率的?

4. 训练数据

        这里我们用的data为 MNIST,为什么img_size设置为 32,不是 28?

5. 模型结构

(1)生成器        

        解释一下为什么是“Upsample, Conv2d, BatchNorm2d, LeakyReLU ”这种顺序?

(2)判别器

        模型的 基本 运算步骤是什么?其中为什么需要 “Dropout2d( p=0.25, inplace=False)”这一步?

        关于“ds_size” 和 “128 * ds_size ** 2”的实际意义?


                                后续更新 GAN的其他模型结构。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1835916.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

苍穹外卖---导入接口文档

一、前后端分离开发流程 第一步:定义接口,确定接口的路径、请求方式、传入参数、返回参数。 第二步:前端开发人员和后端开发人员并行开发,同时,也可自测。 第三步:前后端人员进行连调测试。 第四步&…

搭建zookeeper、Kafka集群

搭建zookeeper、Kafka集群 1、绘制kafka的存储结构、副本机制2、搭建zookeeper集群3、搭建kafka集群4、使用kafka创建名为自己姓名汉语拼音的topic5、查看topic的分区和副本策略 1、绘制kafka的存储结构、副本机制 2、搭建zookeeper集群 实验环境准备: 3台服务器&…

界面追踪方法Level Set与VOF在气泡流动模拟的效果比较

对于两相流模拟,模型主要分为两大类:高相分数模型和界面捕捉类模型。当我们关注水中的含气量(气泡界面及气泡形状可忽略),则采用高相分数模型,此模型适用于气泡特别多的流动问题。对于有明确边界的流体&…

AI安全水深流急,黄铁军首谈AGI能力与风险分级,2024智源大会圆满落幕

2024年6月15日,为期 2 天的北京智源大会圆满落下帷幕。本次大会围绕大语言模型、多模态模型、Agent、具身智能、数据新基建、AI系统、AI开源、AI for Science、AI安全等人工智能热门技术方向和焦点议题,召开了20平行论坛,共计百场报告。 过去…

Linux top 命令使用教程

转载请标明出处:https://blog.csdn.net/donkor_/article/details/139775547 文章目录 一、top 是什么二、top的基础语法三、top输出信息解读 一、top 是什么 Linux top 是一个在Linux和其他类 Unix 系统上常用的实时系统监控工具。它提供了一个动态的、交互式的实时…

基于JSP的房屋租赁系统

开头语: 你好,我是专注于计算机科学与技术研究的学长。如果你对房屋租赁系统感兴趣或有相关开发需求,欢迎联系我。 开发语言:Java 数据库:MySQL 技术:JSPJavaBeansServlet 工具:MyEclipse、…

gRPC(Google Remote Procedure Call Protocol)谷歌远程过程调用协议

文章目录 1、gRPC简介2、gRPC核心的设计思路3、gPRC与protobuf关系 1、gRPC简介 gPRC是由google开源的一个高性能的RPC框架。Stubby Google内部的RPC,演化而来的,2015年正式开源。云原生时代是一个RPC标准。 2、gRPC核心的设计思路 网络通信 ---> gPR…

VM4.3 二次开发04 方案输出结果设置

方案输出结果设置,这个设置是为了在二次开发的上位机软件中显示我们想要的数据,和在二开中如何获取这些结果。 打开方案点下如中的图标。 打开如下图。 再点点红色圈出来的图标,打开参数设置界面。 输出设置可以要输出的数据和参数名称。点上…

【Linux】程序地址空间之动态库的加载

我们先进行一个整体轮廓的了解,随后在深入理解细节。 在动态库加载之前还要说一下程序的加载,因为理解了程序的加载对动态库会有更深的理解。 轮廓: 首先,不管是程序还是动态库刚开始都是在磁盘中的,想要执行对应的可…

隧道代理是什么?怎么运作的?

隧道代理作为网络代理的一种形式,已经在现代互联网世界中扮演着重要的角色。无论是保护隐私、访问受限网站还是实现网络流量的安全传输,隧道代理都发挥着重要作用。在本文中,我们将深入探讨隧道代理的概念、运作方式以及在不同场景中的应用。…

ClickHouse 高性能的列式数据库管理系统

ClickHouse是一个高性能的列式数据库管理系统(DBMS),主要用于在线分析处理查询(OLAP)。以下是对ClickHouse的详细介绍: 基本信息: 来源:由俄罗斯的Yandex公司于2016年开源。全称&…

在向量数据库中存储多模态数据,通过文字搜索图片

在向量数据中存储多模态数据,通过文字搜索图片,Chroma 支持文字和图片,通过 OpenClip 模型对文字以及图片做 Embedding。本文通过 Chroma 实现一个文字搜索图片的功能。 OpenClip CLIP(Contrastive Language-Image Pretraining&…

Eigen中 Row-Major 和 Column-Major 存储顺序的区别

Eigen中 Row-Major 和 Column-Major 存储顺序的区别 flyfish Eigen::RowMajor 是 Eigen 库中用于指定矩阵存储顺序的一种选项 理解 Row-Major 和 Column-Major 存储顺序的区别,绘制一个单一的图来显示内存中的元素访问顺序,在图中用箭头表示访问顺序. import nu…

【无重复字符的最长子串】

无重复字符的最长字串 一、题目二、解决方法1.暴力解法2.滑动窗口哈希 三、总结1.es6 new set()的用法添加元素add()删除元素delete()判断元素是否存在has 2.滑动窗口和双指针的联系和特点 一、题目 二、解决方法 1.暴力解法 解题思路:使用两层循环逐个生成子字符串…

Ardupilot开源代码之ExpressLRS性能实测方法

Ardupilot开源代码之ExpressLRS性能实测方法 1. 源由2. 测试效果3. 测试配置4. 总结5. 参考资料6. 补充 1. 源由 之前一直在讨论ExpressLRS性能的问题,有理论、模拟、实测。 始终缺乏完整的同一次测试的测试数据集,本章节将介绍如何在Ardupilot上进行获…

聆思CSK6大模型+AI交互多模态开源SDK介绍

视觉语音大模型 AI 开发套件( CSK6-MIX )是围绕 CSK6011A 芯片设计的具备丰富语音图像功能与硬件外设的开发板,采用具备丰富组件生态的 Zephyr RTOS作为操作系统,官方提供了十几种开源SDK,包含大模型语音交互、大模型拍照识图、文生图、人脸识…

spark常见问题

写文章只是为了学习总结或者工作内容备忘,不保证及时性和准确性,看到的权当个参考哈! 1. 执行Broadcast大表时,等待超时异常(awaitResult) 现象:org.apache.spark.SparkException: Exception…

设置角色运动的动画

(1) 打开Assets-UnityTechnologies-Animation-Animators,Create-Animation-Controller,命名为JohnLemon (2) 打开JohnLemon,出现下图 (3) 依次将Assets-UnityTechnologies-Animation-Animation中的JohnIdle和JohnWalk拖放到Base Layer窗口中 (4) 右击Idl…

整合JavaSSM框架【超详细】

在整合SSM之前我们首先要知道SSM框架指的是哪些框架? Java的SSM指的是Spring、Spring MVC、MyBatis这三个框架 Spring框架 什么是Spring? Spring是一个支持快速开发Java EE应用程序的框架。它提供了一系列底层容器和基础设施,并可以和大量常…

win11右键小工具

开头要说的 在日常使用场景中,大家如果用的是新的笔记本电脑,应该都是安装的win11系统, 当然win11系统是最被诟病的, 因为有很多人觉得很难操作, 就比如一个小小的解压操作, 在win7和win10上&#xff…