使用pytorch构建一个无监督的深度卷积GAN网络模型

news2024/11/15 10:48:52

本文为此系列的第二篇DCGAN,上一篇为初级的GAN。普通GAN有训练不稳定、容易陷入局部最优等问题,DCGAN相对于普通GAN的优点是能够生成更加逼真、清晰的图像。
因为DCGAN是在GAN的基础上的改造,所以本篇只针对GAN的改造点进行讲解,其他还有不太了解的原理可以返回上一篇进行观看。

本文仍然使用MNIST手写数字数据集来构建一个深度卷积GAN(Deep Convolutional GAN)DCGAN,将使用卷积来替代全连接层,点击查看论文,generator的网络结构图如下:
在这里插入图片描述
DCGAN模型有以下特点:

  1. 判别器模型使用卷积步长取代了空间池化,生成器模型中使用反卷积操作扩大数据维度。
  2. 除了生成器模型的输出层和判别器模型的输入层,在整个对抗网络的其它层上都使用了Batch Normalization,原因是Batch Normalization可以稳定学习,有助于优化初始化参数值不良而导致的训练问题。
  3. 整个网络去除了全连接层,直接使用卷积层连接生成器和判别器的输入层以及输出层。
  4. 在生成器的输出层使用Tanh激活函数以控制输出范围,而在其它层中均使用了ReLU激活函数;在判别器上使用Leaky ReLU激活函数。

代码

model.py:

from torch import nn

class Generator(nn.Module):
    def __init__(self, z_dim=10, im_chan=1, hidden_dim=64):
        super(Generator, self).__init__()
        self.z_dim = z_dim
        # Build the neural network
        self.gen = nn.Sequential(
            self.make_gen_block(z_dim, hidden_dim * 4),
            self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),
            self.make_gen_block(hidden_dim * 2, hidden_dim),
            self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
        )
    def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        if not final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size=kernel_size, stride=stride),
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True)
            )
        else: # Final Layer
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.Tanh()
            )
    def unsqueeze_noise(self, noise):
        return noise.view(len(noise), self.z_dim, 1, 1)    # [b,c,h,w]
    def forward(self, noise):
        x = self.unsqueeze_noise(noise)
        return self.gen(x)

class Discriminator(nn.Module):
    def __init__(self, im_chan=1, hidden_dim=16):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            self.make_disc_block(im_chan, hidden_dim),
            self.make_disc_block(hidden_dim, hidden_dim * 2),
            self.make_disc_block(hidden_dim * 2, 1, final_layer=True),
        )
    def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
        if not final_layer:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(0.2, inplace=True)
            )
        else:  # Final Layer
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride)
            )
    def forward(self, image):
        disc_pred = self.disc(image)
        return disc_pred.view(len(disc_pred), -1)

train.py:

import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from model import *
torch.manual_seed(0) # Set for testing purposes, please do not change!


def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    image_tensor = (image_tensor + 1) / 2
    image_unflat = image_tensor.detach().cpu()
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

def get_noise(n_samples, z_dim, device='cpu'):
    return torch.randn(n_samples, z_dim, device=device)

criterion = nn.BCEWithLogitsLoss()
z_dim = 64
display_step = 500
batch_size = 1280
# A learning rate of 0.0002 works well on DCGAN
lr = 0.0002

beta_1 = 0.5
beta_2 = 0.999
device = 'cuda'

# You can tranform the image values to be between -1 and 1 (the range of the tanh activation)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

dataloader = DataLoader(
    MNIST('.', download=False, transform=transform),
    batch_size=batch_size,
    shuffle=True)

gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))
disc = Discriminator().to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr, betas=(beta_1, beta_2))

def weights_init(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)
gen = gen.apply(weights_init)
disc = disc.apply(weights_init)

n_epochs = 500
cur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0
for epoch in range(n_epochs):
    # Dataloader returns the batches
    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real)
        real = real.to(device)

        ## Update discriminator ##
        disc_opt.zero_grad()
        fake_noise = get_noise(cur_batch_size, z_dim, device=device)
        fake = gen(fake_noise)
        disc_fake_pred = disc(fake.detach())
        disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
        disc_real_pred = disc(real)
        disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))
        disc_loss = (disc_fake_loss + disc_real_loss) / 2

        # Keep track of the average discriminator loss
        mean_discriminator_loss += disc_loss.item() / display_step
        # Update gradients
        disc_loss.backward(retain_graph=True)
        # Update optimizer
        disc_opt.step()

        ## Update generator ##
        gen_opt.zero_grad()
        fake_noise_2 = get_noise(cur_batch_size, z_dim, device=device)
        fake_2 = gen(fake_noise_2)
        disc_fake_pred = disc(fake_2)
        gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
        gen_loss.backward()
        gen_opt.step()

        # Keep track of the average generator loss
        mean_generator_loss += gen_loss.item() / display_step

        ## Visualization code ##
        if cur_step % display_step == 0 and cur_step > 0:
            print(f"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")
            show_tensor_images(fake)
            show_tensor_images(real)
            mean_generator_loss = 0
            mean_discriminator_loss = 0
        cur_step += 1

每500个batch展示一次
每500个batch展示一次。
在这里插入图片描述
可以看到生成器的网络模型不再使用全连接,使用反卷积操作扩大数据维度;在输出层使用Tanh激活函数以控制输出范围,而在其它层中均使用了ReLU激活函数;在隐藏层中每层都使用BN来讲输出归到一定的范围内来稳定学习,使得后层的隐藏单元不过分依赖本层的隐藏单元,减弱内部协变量偏移,从而加速对特征的学习。
因为不再使用全连接而是使用卷积,所以输入的dimension变为channel,所以输入之前先改变noise的shape为(batch_size,channel,high,width)。
在这里插入图片描述
判别器的网络模型使用卷积代替的全连接,使用卷积操作减小数据维度;隐藏层中每层在激活之前使用BN。
在这里插入图片描述
对生成器和鉴别器的权重进行初始化,对于卷积层和转置卷积层(也就是反卷积层)使用正态分布来初始化权重(均值为0,标准差为0.02)的原因是为了确保权重的初始值具有适当的大小,并且不会过大或过小,从而避免梯度消失或梯度爆炸的问题。
对于BN化层,同样使用正态分布来初始化权重,同时将偏置项初始化为0。这是因为批归一化层在训练中通过调整均值和方差来规范化输入数据,因此初始的权重和偏置项都设置为较小的值,有助于加速网络的收敛。

下一篇构建WGAN_GP。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1552619.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Day53:WEB攻防-XSS跨站SVGPDFFlashMXSSUXSS配合上传文件添加脚本

目录 MXSS UXSS:Universal Cross-Site Scripting HTML&SVG&PDF&SWF-XSS&上传&反编译(有几率碰到) SVG-XSS PDF-XSS Python生成XSS Flash-XSS 知识点: 1、XSS跨站-MXSS&UXSS 2、XSS跨站-SVG制作&配合上传 3、XSS跨站-…

第18次修改了可删除可持久保存的前端html备忘录

第17次修改了可删除可持久保存的前端html备忘录&#xff1a;增加年月日星期&#xff0c;增加倒计时&#xff0c;更改保存区名称可以多个备忘录保存不一样的信息&#xff0c;匹配背景主题&#xff1a;现代深色 <!DOCTYPE html> <html lang"zh"> <head&…

Radio Silence for mac 好用的防火墙软件

Radio Silence for Mac是一款功能强大的网络防火墙软件&#xff0c;专为Mac用户设计&#xff0c;旨在保护用户的隐私和网络安全。它具备实时网络监视和控制功能&#xff0c;可以精确显示每个网络连接的状态&#xff0c;让用户轻松掌握网络活动情况。 软件下载&#xff1a;Radio…

低功耗、低成本 NAS/公共文件夹 的可能性

使用现状&#xff1a;多台工作电脑&#xff0c;家里人手一台&#xff0c;还在两个住处 有好几台工作电脑&#xff0c;不同电脑不同OS有不同的用途&#xff0c;最大的问题就是各个电脑上文件的同步问题&#xff0c;这里当然就需要局域网里的公共文件夹&#xff0c;在NAS的问题上…

探索数据库mysql--------------mysql主从复制和读写分离

目录 前言 为什么要主从复制&#xff1f; 主从复制谁复制谁&#xff1f; 数据放在什么地方&#xff1f; 一、mysql支持的复制类型 1.1STATEMENT&#xff1a;基于语句的复制 1.2ROW&#xff1a;基于行的复制 1.3MIXED&#xff1a;混合类型的复制 二、主从复制的工作过程 三个重…

掌握Flutter底部导航栏:畅游导航之旅

1. 引言 在移动应用开发中&#xff0c;底部导航栏是一种常见且非常实用的用户界面元素。它提供了快速导航至不同功能模块或页面的便捷方式&#xff0c;使用户可以轻松访问应用程序的各个部分。在Flutter中&#xff0c;底部导航栏也是一项强大的功能&#xff0c;开发者可以利用…

Linux用户及用户组权限

一、用户和用户组 功能项命令实例作用用户组cat /etc/group查看当前系统存在的用户组groupadd testing添加一个新的用户组testingcat /etc/group查看组是否被新增成功groupmod -n test testing将testing重命名成testgroupdel test删除组testgroups root查看用户root所在的所有…

tdesign坑之EnhancedTable树形结构默认展开所有行

⚠️在官方实例中&#xff0c;树形结构的表格提供了2种方法控制展开全部节点&#xff1a; 一是通过配置属性tree.defaultExpandAll为true代表默认展开全部节点&#xff08;仅默认情况有效&#xff09;&#xff1b; 二是使用组件实例方法expandAll()可以自由控制树形结构的展开…

“直播曝光“有哪些媒体直播分流资源?

传媒如春雨&#xff0c;润物细无声&#xff0c;大家好&#xff0c;我是51媒体网胡老师。 我们线下举办活动时&#xff0c;往往希望活动进行更大的曝光&#xff0c;随着视频直播越来越被大众认可&#xff0c;甚至成了活动的标配&#xff0c;那么做活动视频直播的时候&#xff0…

R语言批量计算t检验,输出pvalue和均值

1.输入数据如下&#xff1a; 2.代码如下 setwd("E:/R/Rscripts/rG4相关绘图") # 读取CSV文件 data <- read.csv("box-cds-ABD-不同类型rg4-2.csv", stringsAsFactors FALSE)# 筛选出Type2列为指定五种类型的数据 filtered_data <- subset(data, …

LockSupport与线程中断机制

中断机制是个协商机制 Interrupt(): 将中断状态设置为true Interrupted():&#xff08;静态方法&#xff09; 1.返回当前线程的中断状态 2.将中断状态清零并设置为false is Interrupted(): 判断当前线程是否被中断 如何停止中断运行中的线程&#xff1f; 一个线程不应该由…

提取html工具封装和应用

提取html工具封装和应用 BeautifulSoup库和介绍BeautifulSoup使用BeautifulSoup重点方法BeautifulSoup其他方法 认证参数化实现创建json文件导包&#xff08;参数化&#xff09;编写测试用例技术难点--判断验证码不同 BeautifulSoup库和介绍 BeautifulSoup使用 1、导包 2、实例…

C# OpenCvSharp 轮廓检测

目录 效果 代码 下载 效果 代码 using System; using System.Collections.Generic; using System.ComponentModel; using System.Data; using System.Drawing; using System.Linq; using System.Text; using System.Windows.Forms; using OpenCvSharp; using OpenCvSharp.…

kubernetes K8s的监控系统Prometheus升级Grafana,来一个酷炫的Node监控界面(二)

上一篇文章《kubernetes K8s的监控系统Prometheus安装使用(一)》中使用的监控界面总感觉监控的节点数据太少&#xff0c;不能快算精准的判断出数据节点运行的状况。 今天我找一款非常酷炫的多维度数据监控界面&#xff0c;能够非常有把握的了解到各节点的数据&#xff0c;以及运…

【js刷题:数据结构数组篇之长度最小的子数组】

长度最小的子数组 一、题目二、方法1.暴力解法2.滑动窗口是什么 滑动窗口的起始位置滑动窗口的结束位置 一、题目 给定一个含有 n 个正整数的数组和一个正整数 s &#xff0c;找出该数组中满足其和 ≥ s 的长度最小的 连续 子数组&#xff0c;并返回其长度。如果不存在符合条件…

移动端开发思考:Uniapp的上位替代选择

文章目录 前言跨平台开发技术需求技术选型uniappFlutterMAUIAvalonia安卓原生 Flutter开发尝试Avalonia开发测试测试项目新建项目代码MainViewMainViewModel 发布/存档 MAUI实战&#xff0c;简单略过打包和Avalonia差不多 总结 前言 作为C# .NET程序员&#xff0c;我有一些移动…

虚拟机-从头配置Ubuntu18.04(包括anaconda,cuda,cudnn,pycharm,ros,vscode)

最好先安装anaconda后cuda和cudnn&#xff0c;因为配置环境的时候可能conda会覆盖cuda的路径&#xff08;不确定这种说法对不对&#xff0c;这里只是给大家的建议&#xff09; 准备工作&#xff1a; 1.Ubuntu18.04&#xff0c;x86_64&#xff0c;amd64 虚拟机下载和虚拟机Ubu…

鸿蒙OS开发实例:【手撸服务卡片】

介绍 服务卡片指导文档位于“开发/应用模型/Stage模型开发指导/Stage模型应用组件”路径下&#xff0c;说明其极其重要。 本篇文章将分享实现服务卡片的过程和代码 准备 请参照[官方指导]&#xff0c;创建一个Demo工程&#xff0c;选择Stage模型 鸿蒙OS开发更多内容↓点击…

cookie,sessionStorage,localStorage的区别及应用场景、http状态码含义、使用token登录、无感登录

浏览器的缓存机制提供了可以将用户数据存储在客户端上的方式&#xff0c;可以利用cookie&#xff0c;session等跟服务端进行数据交互。 浏览器的存储方式有哪些&#xff1f; 1.cookiesH5标准前的本地存储方式兼容性好&#xff0c;请求头自带cookie存储量小&#xff0c;资源浪费…

Apache Hive的基本使用语法(二)

Hive SQL操作 7、修改表 表重命名 alter table score4 rename to score5;修改表属性值 # 修改内外表属性 ALTER TABLE table_name SET TBLPROPERTIES("EXTERNAL""TRUE"); # 修改表注释 ALTER TABLE table_name SET TBLPROPERTIES (comment new_commen…