第G9周:ACGAN理论与实战

news2025/4/5 17:06:12
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制
  • 🚀 文章来源:K同学的学习圈子

上一周已经给出代码,需要可以跳转上一周的任务
第G8周:ACGAN任务

import argparse
import os
import numpy as np

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch

# 创建用于存储生成图像的目录
os.makedirs("images", exist_ok=True)

# 解析命令行参数
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=50, help="训练的总轮数")
parser.add_argument("--batch_size", type=int, default=64, help="每个批次的大小")
parser.add_argument("--lr", type=float, default=0.0002, help="Adam优化器的学习率")
parser.add_argument("--b1", type=float, default=0.5, help="Adam优化器的一阶动量衰减")
parser.add_argument("--b2", type=float, default=0.999, help="Adam优化器的二阶动量衰减")
parser.add_argument("--n_cpu", type=int, default=4, help="用于批次生成的CPU线程数")
parser.add_argument("--latent_dim", type=int, default=100, help="潜在空间的维度")
parser.add_argument("--n_classes", type=int, default=10, help="数据集的类别数")
parser.add_argument("--img_size", type=int, default=32, help="每个图像的尺寸")
parser.add_argument("--channels", type=int, default=1, help="图像通道数")
parser.add_argument("--sample_interval", type=int, default=400, help="图像采样间隔")
opt = parser.parse_args()
print(opt)

# 检查是否支持GPU加速
cuda = True if torch.cuda.is_available() else False


# 初始化神经网络权重的函数
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


# 生成器网络类
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        # 为类别标签创建嵌入层
        self.label_emb = nn.Embedding(opt.n_classes, opt.latent_dim)

        # 计算上采样前的初始大小
        self.init_size = opt.img_size // 4  # Initial size before upsampling

        # 第一层线性层
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))

        # 卷积层块
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, noise, labels):
        # 将标签嵌入到噪声中
        gen_input = torch.mul(self.label_emb(labels), noise)

        # 通过第一层线性层
        out = self.l1(gen_input)

        # 重新整形为合适的形状
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)

        # 通过卷积层块生成图像
        img = self.conv_blocks(out)
        return img


# 判别器网络类
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        # 定义判别器块的函数
        def discriminator_block(in_filters, out_filters, bn=True):
            """返回每个判别器块的层"""
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        # 判别器的卷积层块
        self.conv_blocks = nn.Sequential(
            *discriminator_block(opt.channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # 下采样后图像的高度和宽度
        ds_size = opt.img_size // 2 ** 4

        # 输出层
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())
        self.aux_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, opt.n_classes), nn.Softmax())

    def forward(self, img):
        out = self.conv_blocks(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        label = self.aux_layer(out)

        return validity, label


# 损失函数
adversarial_loss = torch.nn.BCELoss()
auxiliary_loss = torch.nn.CrossEntropyLoss()

# 初始化生成器和判别器
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()
    auxiliary_loss.cuda()

# 初始化权重
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# 配置数据加载器
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

# 优化器
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor


# 保存生成图像的函数
def sample_image(n_row, batches_done):
    """保存从0到n_classes的生成数字的图像网格"""
    # 采样噪声
    z = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))
    # 为n行生成标签从0到n_classes
    labels = np.array([num for _ in range(n_row) for num in range(n_row)])
    labels = Variable(LongTensor(labels))
    gen_imgs = generator(z, labels)
    save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)


# ----------
# 训练
# ----------

for epoch in range(opt.n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):

        batch_size = imgs.shape[0]

        # 真实数据的标签
        valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)
        # 生成数据的标签
        fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)

        # 配置输入
        real_imgs = Variable(imgs.type(FloatTensor))
        labels = Variable(labels.type(LongTensor))

        # -----------------
        # 训练生成器
        # -----------------

        optimizer_G.zero_grad()

        # 采样噪声和标签作为生成器的输入
        z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))
        gen_labels = Variable(LongTensor(np.random.randint(0, opt.n_classes, batch_size)))

        # 生成一批图像
        gen_imgs = generator(z, gen_labels)

        # 损失度量生成器的欺骗判别器的能力
        validity, pred_label = discriminator(gen_imgs)
        g_loss = 0.5 * (adversarial_loss(validity, valid) + auxiliary_loss(pred_label, gen_labels))

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        # 训练判别器
        # ---------------------

        optimizer_D.zero_grad()

        # 真实图像的损失
        real_pred, real_aux = discriminator(real_imgs)
        d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2

        # 生成图像的损失
        fake_pred, fake_aux = discriminator(gen_imgs.detach())
        d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, gen_labels)) / 2

        # 判别器的总损失
        d_loss = (d_real_loss + d_fake_loss) / 2

        # 计算判别器的准确率
        pred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)
        gt = np.concatenate([labels.data.cpu().numpy(), gen_labels.data.cpu().numpy()], axis=0)
        d_acc = np.mean(np.argmax(pred, axis=1) == gt)

        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), 100 * d_acc, g_loss.item())
        )
        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            sample_image(n_row=10, batches_done=batches_done)

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1632911.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

什么是视频号小店?小店怎么做?详细玩法流程来了

大家好,我是电商笨笨熊 视频号小店成了今年电商市场又一热门项目; 作为腾讯推出的电商,不少人曾说过,视频号小店会成为下一个风口; 那么视频号小店到底是什么,值得投入吗,又该怎么做呢&#…

[华为OD]C卷 给定一个数组,数组中的每个元素代表该位置的海拔高度 山脉的个数 200

题目: 给定一个数组,数组中的每个元素代表该位置的海拔高度。0表示平地,>1时表示属于某个 山峰,山峰的定义为当某个位置的左右海拔均小于自己的海拔时,该位置为山峰。数组起始位 置计算时可只满足一边…

AI助力后厨可视化智慧监管,让“舌尖安全”看得见

一、背景与需求分析 夏天是食物易腐败的季节,高温容易引发食品安全问题。在后厨环境中,食品安全问题可能涉及食品加工、后厨环境、食品是否被污染等方面,而不合格的食品安全管理可能导致食品中毒事件等风险,损害消费者的健康和餐…

偏微分方程算法之五点菱形差分法

目录 一、研究目标 二、理论推导 三、算例实现 四、结论 一、研究目标 上个专栏我们介绍了双曲型偏微分方程的主要算法及实现。从今天开始,我们在新的专栏介绍另一种形式偏微分方程-椭圆型的解法。 研究目标选取经典的二维椭圆型方程(也称泊松Poisso…

半导体制造工艺之分类浅述

半导体制造工艺分为逻辑制程(也叫逻辑工艺)和特殊制程(也叫特色工艺)。 1、逻辑工艺概述 随着集成电路行业沿着摩尔定律不断发展,晶体管数量增加的同时,工艺节点不断缩小。先进逻辑工艺是相对的概念,2005年全球先进逻辑工艺的工艺节点在65/55纳米,现在则变为3纳米。中…

好好聊一聊:Agent AI智能体的未来|TodayAI

​​​​​​​ 一、 引言 在当今时代,人工智能(AI)技术的快速发展正不断改变着我们的生活与工作方式。尤其是Agent AI智能体,作为AI技术中的一种重要形式,它们通过模拟人类智能行为来执行各种复杂任务,从…

sCrypt全新上线RUNES功能

sCrypt智能合约平台全新上线一键etch/mint RUNES功能! 请访问 https://runes.scrypt.io/ 或点击阅读原文体验! 关于sCrypt sCrypt是BSV区块链上的一种智能合约高级语言。比特币使用基于堆栈的Script语言来支持智能合约,但是用原生Script编…

多猫家庭吐血总结!这样选冻干真不踩雷?这几款主食冻干喂出貌美小猫

315中国之声的报道曝光了河北省邢台市南和区某宠粮代工厂的“行业秘密”,这让许多宠物主人感到震惊和不安。配料表上标明的鸡肉含量和新鲜鸡小胸含量看似可观,但背后却是用鸡肉粉替代的真相。我们养宠物是为了增添生活的乐趣,然而这些行业乱象…

实验案例二:配置Trunk,实现相同VLAN的跨交换机通信

1.实验环境 公司的员工人数已达到100人,其网络设备如图12.13所示。现在的网络环境导致广播较多 网速慢,并且也不安全。公司希望按照部门划分网络,并且能够保证一定的网络安全性 其网络规划如下: PC1和 PC3为财务部,属于 VLAN 2&…

Linux驱动开发——(九)platform设备驱动

目录 一、Linux驱动的分离 二、Linux驱动的分层 三、platform平台驱动模型简介 3.1 platform_driver结构体 3.2 device_driver结构体 3.3 platform驱动API函数 四、驱动代码 一、Linux驱动的分离 对于Linux这种庞大而复杂的系统,需要非常注重代码的重用性&a…

Docker-容器的前世今生

文章目录 Docker为什么产生?硬件虚拟化硬件虚拟化解决的问题硬件虚拟化定义硬件虚拟化技术虚拟机的优点虚拟机的缺点 操作系统虚拟化即容器容器化解决的问题容器化定义容器化技术历史 容器和虚拟机对比 Docker的发展历史Docker架构客户端服务端仓库Registry Docker重…

JavaEE 初阶篇-深入了解特殊文件(Properties 属性文件、XML)

🔥博客主页: 【小扳_-CSDN博客】 ❤感谢大家点赞👍收藏⭐评论✍ 文章目录 1.0 Properties 属性文件概述 1.1 Properties 属性文件特性与作用 1.2 使用 Properties 把键值对数据写出到属性文件中 1.3 使用 Properties 读取属性文件里的键值对数…

JMeter 请求头信息配置详解

在进行 Web 测试和 API 测试时,正确配置 HTTP 请求头是关键步骤之一,尤其当使用诸如 JMeter 这样的强大工具时。在本文中,我将详细介绍如何在 JMeter 中有效地配置和管理HTTP请求头。 在 JMeter 中添加和配置 HTTP 请求头 步骤 1: 打开 HTT…

【Redis 开发】多级缓存,本地进程缓存Caffeine

多级缓存 多级缓存本地进程缓存CaffeineCaffeine三种缓存驱逐策略 多级缓存 Redis处理并发的能力是非常强大的,但是tomcat的支持并发的能力跟不上Redis的性能,导致整体性能的下降 Redis缓存失效时,会对数据库产生冲击,之间再无屏…

0425DormAJAX项目

0425DormAJAX项目包-CSDN博客 数据库字段 添加界面: 初始状态: 点击性别,宿舍号使用ajax动态添加: 学生主界面: 实现分页查询: 点击修改学生宿舍,查看换寝记录,ajax动态显示列表&…

如何在WordPress中设置网站的SEO标题和描述

在WordPress中,想要让你的网站在搜索引擎结果中脱颖而出,设置优秀的SEO标题和描述至关重要。这不仅可以帮助搜索引擎更好地理解你的网站内容,还可以吸引更多的点击率和流量。而选择一款合适的SEO插件是实现这一目标的关键之一。让我们来看看两…

速成AWD并获奖的学习方法和思考记录

前言 这是一个市赛。之前没有怎么打过AWD,所以进入决赛后只有三天的准备时间,期间我不停的请教大佬,阅读各类文章,受益颇深,做此纪录,奉献给掌控的各位同学学习。 在AWD中本是三人一队,可惜我…

Babel 原理浅析

Babel 原理浅析 Babel 是什么Babel 的作用及常用场景Babel 执行过程原理Babel的基本原理解析过程插件系统 Babel 是什么 官方解释:Babel 是一个 JavaScript 编译器,也是一个工具链,主要用于将 ECMAScript 2015 代码转换为当前和旧版浏览器或环…

一篇文章 学会Qt 样式表(qss)

QML 中风格和主题的设计可以通过配置文件选择现有几种中的一种,或者直接在控件定义时,指定其属性,如背景颜色或者字体大小。在QWidget框架中,则通过了一种叫做qss样式表的东西来进行描述,跟CSS逻辑上类似。 这个qss抽…

ThreeJs模拟工厂生产过程八

这节算是给这个车间场景收个尾,等了几天并没有人发设备模型给我,只能自己找了一个凑合用了。加载模型之前,首先要把货架上的料箱合并,以防加载模型之后因模型数量多出现卡顿,方法和之前介绍的合并传送带方法相同&#…