PyTorch训练深度卷积生成对抗网络DCGAN

news2024/10/5 13:38:20

文章目录

    • DCGAN介绍
    • 代码
    • 结果
    • 参考

DCGAN介绍

将CNN和GAN结合起来,把监督学习和无监督学习结合起来。具体解释可以参见 深度卷积对抗生成网络(DCGAN)

DCGAN的生成器结构:
在这里插入图片描述
图片来源:https://arxiv.org/abs/1511.06434

代码

model.py

import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self, channels_img, features_d):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            # Input: N x channels_img x 64 x 64
            nn.Conv2d(
                channels_img, features_d, kernel_size=4, stride=2, padding=1
            ), # 32 x 32
            nn.LeakyReLU(0.2),
            self._block(features_d, features_d*2, 4, 2, 1), # 16 x 16
            self._block(features_d*2, features_d*4, 4, 2, 1), # 8 x 8
            self._block(features_d*4, features_d*8, 4, 2, 1), # 4 x 4
            nn.Conv2d(features_d*8, 1, kernel_size=4, stride=2, padding=0), # 1 x 1
            nn.Sigmoid(),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.disc(x)
    
class Generator(nn.Module):
    def __init__(self, z_dim, channels_img, features_g):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            # Input: N x z_dim x 1 x 1
            self._block(z_dim, features_g*16, 4, 1, 0), # N x f_g*16 x 4 x 4
            self._block(features_g*16, features_g*8, 4, 2, 1), # 8x8
            self._block(features_g*8, features_g*4, 4, 2, 1), # 16x16
            self._block(features_g*4, features_g*2, 4, 2, 1), # 32x32
            nn.ConvTranspose2d(
                features_g*2, channels_img, kernel_size=4, stride=2, padding=1,
            ),
            nn.Tanh(),

        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
    
    def forward(self, x):
        return self.gen(x)


def initialize_weights(model):
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)

def test():
    N, in_channels, H, W = 8, 3, 64, 64
    z_dim = 100
    x = torch.randn((N, in_channels, H, W))
    disc = Discriminator(in_channels, 8)
    initialize_weights(disc)
    assert disc(x).shape == (N, 1, 1, 1)

    gen = Generator(z_dim, in_channels, 8)
    initialize_weights(gen)
    z = torch.randn((N, z_dim, 1, 1))
    assert gen(z).shape == (N, in_channels, H, W)
    print("success")
    
if __name__ == "__main__":
    test()

训练使用的数据集:CelebA dataset (Images Only) 总共1.3GB的图片,使用方法,将其解压到当前目录

图片如下图所示:
在这里插入图片描述

train.py

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from model import Discriminator, Generator, initialize_weights

# Hyperparameters etc.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 2e-4  # could also use two lrs, one for gen and one for disc
BATCH_SIZE = 128
IMAGE_SIZE = 64
CHANNELS_IMG = 3 # 1 if MNIST dataset; 3 if celeb dataset
NOISE_DIM = 100
NUM_EPOCHS = 5
FEATURES_DISC = 64
FEATURES_GEN = 64

transforms = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
        ),
    ]
)

# If you train on MNIST, remember to set channels_img to 1
# dataset = datasets.MNIST(
#     root="dataset/", train=True, transform=transforms, download=True
# )

# comment mnist above and uncomment below if train on CelebA

# If you train on celeb dataset, remember to set channels_img to 3
dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
initialize_weights(gen)
initialize_weights(disc)

opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
criterion = nn.BCELoss()

fixed_noise = torch.randn(32, NOISE_DIM, 1, 1).to(device)
writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
step = 0

gen.train()
disc.train()

for epoch in range(NUM_EPOCHS):
    # Target labels not needed! <3 unsupervised
    for batch_idx, (real, _) in enumerate(dataloader):
        real = real.to(device)
        noise = torch.randn(BATCH_SIZE, NOISE_DIM, 1, 1).to(device)
        fake = gen(noise)

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        disc_real = disc(real).reshape(-1)
        loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake.detach()).reshape(-1)
        loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = (loss_disc_real + loss_disc_fake) / 2
        disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))
        output = disc(fake).reshape(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print losses occasionally and print to tensorboard
        if batch_idx % 100 == 0:
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} \
                  Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise)
                # take out (up to) 32 examples
                img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1

结果

训练5个epoch,部分结果如下:

Epoch [3/5] Batch 1500/1583                   Loss D: 0.4996, loss G: 1.1738
Epoch [4/5] Batch 0/1583                   Loss D: 0.4268, loss G: 1.6633
Epoch [4/5] Batch 100/1583                   Loss D: 0.4841, loss G: 1.7475
Epoch [4/5] Batch 200/1583                   Loss D: 0.5094, loss G: 1.2376
Epoch [4/5] Batch 300/1583                   Loss D: 0.4376, loss G: 2.1271
Epoch [4/5] Batch 400/1583                   Loss D: 0.4173, loss G: 1.4380
Epoch [4/5] Batch 500/1583                   Loss D: 0.5213, loss G: 2.1665
Epoch [4/5] Batch 600/1583                   Loss D: 0.5036, loss G: 2.1079
Epoch [4/5] Batch 700/1583                   Loss D: 0.5158, loss G: 1.0579
Epoch [4/5] Batch 800/1583                   Loss D: 0.5426, loss G: 1.9427
Epoch [4/5] Batch 900/1583                   Loss D: 0.4721, loss G: 1.2659
Epoch [4/5] Batch 1000/1583                   Loss D: 0.5662, loss G: 2.4537
Epoch [4/5] Batch 1100/1583                   Loss D: 0.5604, loss G: 0.8978
Epoch [4/5] Batch 1200/1583                   Loss D: 0.4085, loss G: 2.0747
Epoch [4/5] Batch 1300/1583                   Loss D: 1.1894, loss G: 0.1825
Epoch [4/5] Batch 1400/1583                   Loss D: 0.4518, loss G: 2.1509
Epoch [4/5] Batch 1500/1583                   Loss D: 0.3814, loss G: 1.9391

使用

tensorboard --logdir=logs

打开tensorboard

在这里插入图片描述

参考

[1] DCGAN implementation from scratch
[2] https://arxiv.org/abs/1511.06434

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/893960.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

跨境电商平台风控揭秘:如何应对刷单风险?

跨境电商平台内部对比被举报的买家信息时&#xff0c;会进行一系列分析来确认是否存在刷评行为。系统会追溯买家的购买记录和留评记录&#xff0c;根据留评率等信息来判断是否存在刷评的行为。如果系统确认买家存在刷评行为&#xff0c;那么该买家曾经留下的所有评价都有可能被…

我国出租车行业的发展伪历史(依赖倒置)

一、前言 既然是“伪历史”&#xff0c;大家就暂且不要纠结故事的真实性了&#xff0c;因为我们今天主要讲的并非是中国出租车的发展史&#xff0c;而是希望通过这个伪历史的例子来用日常生活中的例子&#xff0c;来深入理解一下什么叫依赖倒置。 还是按照惯例&#xff0c;我…

【从零开始学习Linux】常用命令及操作

哈喽&#xff0c;哈喽&#xff0c;大家好~ 我是你们的老朋友&#xff1a;保护小周ღ 本期给大家带来的是 Linux 常用命令及操作&#xff0c;主要有三个分类&#xff1a;文件操作&#xff0c;目录操作&#xff0c;网络操作&#xff0c;创建文件 touch , 创建目录 mkdir , 删除…

【了解一下常见的设计模式】

文章目录 了解一下常用的设计模式(工厂、包装、关系)导语设计模式辨析系列 工厂篇工厂什么是工厂简单工厂「模式」&#xff08;Simple Factory「Pattern」&#xff09;简单工厂代码示例&#xff1a;简单计算器优点&#xff1a;缺点&#xff1a; 静态工厂模式特点&#xff1a; 工…

基于Spring Boot的社区诊所就医管理系统的设计与实现(Java+spring boot+MySQL)

获取源码或者论文请私信博主 演示视频&#xff1a; 基于Spring Boot的社区诊所就医管理系统的设计与实现&#xff08;Javaspring bootMySQL&#xff09; 使用技术&#xff1a; 前端&#xff1a;html css javascript jQuery ajax thymeleaf 微信小程序 后端&#xff1a;Java …

改进YOLO系列:2.添加ShuffleAttention注意力机制

添加ShuffleAttention注意力机制 1. ShuffleAttention注意力机制论文2. ShuffleAttention注意力机制原理3. ShuffleAttention注意力机制的配置3.1common.py配置3.2yolo.py配置3.3yaml文件配置1. ShuffleAttention注意力机制论文 论文题目:SA-NET: SHUFFLE ATTENTION …

教育行业选择CRM的四大要求

随着互联网教育的发展和变迁&#xff0c;越来越多的教育机构开始意识到管理客户关系的重要性。然而&#xff0c;对于教育行业来说&#xff0c;选择一款适合自己的CRM系统也不轻松。下面就来说说&#xff0c;教育行业crm要如何来选择&#xff1f; 一、明确使用需求 在进行CRM选…

如何使用 ChatGPT 将文本转换为 PowerPoint 演示文稿

推荐&#xff1a;使用 NSDT场景编辑器 助你快速搭建可二次编辑的3D应用场景 步骤 1&#xff1a;将文本转换为幻灯片演示文稿 第一步涉及指示 ChatGPT 根据给定的文本生成具有特定数量幻灯片的演示文稿。首先&#xff0c;您必须向 ChatGPT 提供要转换的文本。 使用以下提示指示…

Gitlab服务部署及应用

目录 Gitlab简介 Gitlab工作原理 Gitlab服务构成 Gitlab环境部署 安装依赖包 启动postfix&#xff0c;并设置开机自启 设置防火墙 下载安装gitlab rpm包 修改配置文件/etc/gitlab/gitlab.rb&#xff0c;生产环境下可以根据需求修改 重新加载配置文件 浏览器登录Gitlab输…

Azure CLI 进行磁盘加密

什么是磁盘加密 磁盘加密是指在Azure中对虚拟机的磁盘进行加密保护的一种机制。它使用Azure Key Vault来保护磁盘上的数据&#xff0c;以防止未经授权的访问和数据泄露。使用磁盘加密&#xff0c;可以保护磁盘上的数据以满足安全和合规性要求。 参考文档&#xff1a;https://l…

职场修炼:性格内向的程序员如何突破自己

性格内向&#xff0c;不是缺点 社会常识中的看法&#xff1a;性格内心&#xff0c;是成功的障碍。 实际情况&#xff1a;内向和外向各有优缺点 忌讳&#xff1a; 强行改变自己的性格。内心不接受自己的性格。 内向者的优点 善于研究&#xff0c;能够较长时间研究技术问题…

仿写一个tomcat(含线程池配置)超详细!!

目录 工作原理 整体项目结构 自定义注解 创建servlet类 创建启动类 线程池配置 测试阶段 工作原理 首先看流程图&#xff0c;搞清楚tomcat的工作原理 工作原理如下&#xff1a; Tomcat使用一个叫作Catalina的核心组件来处理HTTP请求和响应。Catalina包含了一个HTTP连接…

匈牙利算法相关介绍

重要说明&#xff1a;本文从网上资料整理而来&#xff0c;仅记录博主学习相关知识点的过程&#xff0c;侵删。 一、参考资料 匈牙利算法匹配问题? Exactly how the Hungarian Algorithm works 多目标跟踪数据关联之匈牙利算法 五分钟小知识&#xff1a;什么是匈牙利算法 论文…

改进YOLO系列:3.添加SOCA注意力机制

添加SOCA注意力机制 1. SOCA注意力机制论文2. SOCA注意力机制原理3. SOCA注意力机制的配置3.1common.py配置3.2yolo.py配置3.3yaml文件配置1. SOCA注意力机制论文 暂未找到 2. SOCA注意力机制原理 3. SOCA注意力机制的配置 3.1common.py配置 ./models/common.p…

SpringBoot部署到腾讯云

SpringBoot部署到腾讯云 此处默认已经申请到腾讯云服务器&#xff0c;因为本人还没有申请域名&#xff0c;所以就直接使用的ip地址 XShell连接到腾讯云 主机中填写腾讯云的公网ip地址 公网ip地址在下图中找到 接下来填写服务器的用户名与密码 一般centOS用户名为root&#xff…

ZLMediakit-method ANNOUNCE failed: 401 Unauthorized

使用ffmpeg推流&#xff1a; nohup ffmpeg -stream_loop -1 -re -i "/usr/local/mp4/test.mp4" -vcodec h264 -acodec aac -f rtsp -rtsp_transport tcp rtsp://10.55.134.12/live/test &[rootlocalhost ~]# ffmpeg -stream_loop -1 -re -i "/usr/local/mp…

HCIP的交换机实验

题目 拓扑图 PC1/3接口用access 创建WLAN LSW1 创建WLAN [lsw1]vlan batch 2 to 6[lsw1-Ethernet0/0/1]p [lsw1-Ethernet0/0/1]port l [lsw1-Ethernet0/0/1]port link- [lsw1-Ethernet0/0/1]port link-flap [lsw1-Ethernet0/0/1]port link-type acc [lsw1-Ethernet0/0…

Java实现微信小程序V3支付 (完整demo)

1. 微信小程序支付-开发者文档https://pay.weixin.qq.com/wiki/doc/apiv3/apis/chapter3_5_1.shtml 2. 导入依赖 <!--小程序支付 v3--> <dependency><groupId>com.github.wechatpay-apiv3</groupId><artifactId>wechatpay-apache-httpclient<…

精密、CMOS、轨到轨输入/输出、宽带运算放大器MS8601/8602/8604

产品简述 MS8601/MS8602/MS8604 分别是单 / 双 / 四通道、轨到轨输入和输出、 单电源放大器&#xff0c;具有极低的失调电压和宽信号带宽。它采用 1.8V 至 5V 单电 源&#xff08; 0.9 V 至 2.5 V 双电源&#xff09;供电。 MS8601/MS8602/MS8604 低失调、极低的输入偏置…

02__models

LangChain提供两种封装的模型接口 1.大规模语言模型&#xff08;LLM&#xff09;&#xff1a;输入文本字符串&#xff0c;返回文本字符串 2.聊天模型&#xff1a;基于一个语言模型&#xff0c;输入聊天消息列表&#xff0c;返回聊天消息 Langchain的支持OpenAI、ChatGLM、Hu…