【Week-G7】Semi-Supervised GAN 实践,使用MNIST数据集

news2024/11/25 11:03:12

文章目录

  • 一、基础知识
  • 二、代码实现
    • 2.1 导入所需模块 & 设置网络初始参数
    • 2.2 初始化权重
    • 2.3 定义算法模型
    • 2.4 配置模型
    • 2.5 训练模型
    • 2.6 训练结果

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

本次学习进行Semi-Supervised GAN的实践,数据集为MNIST
主要为了解惑:加入生成的图像本身就携带标签,比如数字1~9,那么:为什么还需要鉴别器判断输入图像的真假,而不直接判断图像属于0-9中的哪一个数字?

一、基础知识

本次学习使用到的SGAN将GAN扩展到半监督学习方式,通过强制判别器D来输出类别标签。具体结构如下图:
在这里插入图片描述

输入数据集:N类中某一个
生成器G:输出第N+1个类
判别器D:充当分类器C的效果
训练时:判别器D被用于预测输入时属于N+1类中的哪一个

SGAN可以用于训练效果更好的判别器D,并且比普通的GAN产生更加高质量的样本。
在这里插入图片描述

二、代码实现

2.1 导入所需模块 & 设置网络初始参数

import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

os.makedirs("images", exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=50, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--num_classes", type=int, default=10, help="number of classes for dataset")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
opt = parser.parse_args()
print(opt)

cuda = True if torch.cuda.is_available() else False

2.2 初始化权重

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

2.3 定义算法模型


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.label_emb = nn.Embedding(opt.num_classes, opt.latent_dim)

        self.init_size = opt.img_size // 4  # Initial size before upsampling
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, noise):
        out = self.l1(noise)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            """Returns layers of each discriminator block"""
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.conv_blocks = nn.Sequential(
            *discriminator_block(opt.channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # The height and width of downsampled image
        ds_size = opt.img_size // 2 ** 4

        # Output layers
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())
        self.aux_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, opt.num_classes + 1), nn.Softmax())

    def forward(self, img):
        out = self.conv_blocks(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        label = self.aux_layer(out)

        return validity, label

2.4 配置模型


# Loss functions
adversarial_loss = torch.nn.BCELoss()
auxiliary_loss = torch.nn.CrossEntropyLoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()
    auxiliary_loss.cuda()

# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor

2.5 训练模型


# ----------
#  Training
# ----------

for epoch in range(opt.n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):

        batch_size = imgs.shape[0]

        # Adversarial ground truths
        valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)
        fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)
        fake_aux_gt = Variable(LongTensor(batch_size).fill_(opt.num_classes), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(FloatTensor))
        labels = Variable(labels.type(LongTensor))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise and labels as generator input
        z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        validity, _ = discriminator(gen_imgs)
        g_loss = adversarial_loss(validity, valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Loss for real images
        real_pred, real_aux = discriminator(real_imgs)
        d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2

        # Loss for fake images
        fake_pred, fake_aux = discriminator(gen_imgs.detach())
        d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, fake_aux_gt)) / 2

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2

        # Calculate discriminator accuracy
        pred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)
        gt = np.concatenate([labels.data.cpu().numpy(), fake_aux_gt.data.cpu().numpy()], axis=0)
        d_acc = np.mean(np.argmax(pred, axis=1) == gt)

        d_loss.backward()
        optimizer_D.step()

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25], "images/GAN/sgan/%d.png" % batches_done, nrow=5, normalize=True)

    print(
        "[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]"
        % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), 100 * d_acc, g_loss.item())
    )

2.6 训练结果

下载MNIST数据集:
在这里插入图片描述
训练过程:
在这里插入图片描述
训练输出的图像:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1992622.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用 Claude3.5 只需 2 分钟快速构建仪表盘

这是用 claude 生成的图表,只花了 2 分钟 目录 Claude何时使用Artifacts?我如何使用Artifacts?我的例子让claude导出本地部署结尾 关键还可以分享 这是分享之后的链接: https://claude.site/artifacts/1cf37377-1d00-4ab2-b8dd-a…

PMP考试一定要考到3A吗?怎么备考?

PMP(Project Management Professional)认证是全球公认的项目管理专业人士资格认证,它代表着项目管理领域的高水平标准。 在备考PMP考试时,有些赛宝关心是否需要考到3A(即三个领域均为Above Target,超出目标…

【中项】系统集成项目管理工程师-第10章 项目整合管理-10.7结束项目或阶段

前言:系统集成项目管理工程师专业,现分享一些教材知识点。觉得文章还不错的喜欢点赞收藏的同时帮忙点点关注。 软考同样是国家人社部和工信部组织的国家级考试,全称为“全国计算机与软件专业技术资格(水平)考试”&…

[Python学习日记-3] 编程前选择一个好用的编程工具

[Python学习日记-3] 编程前选择一个好用的编程工具 简介 PyCharm IDE的安装 PyCharm IDE安装后的一些常规使用 简介 在踏上 Python 编程的精彩旅程之前,选择一款得心应手的编程工具无疑是至关重要的一步。这就如同战士在出征前精心挑选趁手的武器,它将…

Unity补完计划 之 音效

本文仅作笔记学习和分享,不用做任何商业用途 本文包括但不限于unity官方手册,unity唐老狮等教程知识,如有不足还请斧正 首先,音频这块组件较少,但是内容很重要,因为对于任何一款非特殊面向人群的游戏来说&a…

SQLiteStudio 连接sqlite3数据库(真机数据库可视化调试)

SQLiteStudio安装 官网链接:https://sqlitestudio.pl/ 下载后,直接按部就班,打开即可使用 用户手册(工具如何使用直接看这份就可以了):https://github.com/pawelsalawa/sqlitestudio/wiki/User_Manual 其…

GoFly快速开发框架代码市场使用说明

说明 我们框架坚持开源的项目绝不能存在收费项目,所以我们gofly快速开发开源版没有内置代码仓插件,因此需要使用代码市场中的代码包需要再企业版中使用,代码市场插件如下: 图1、社区-代码市场​​​​ 他和企业版管理后台的代码仓…

Component和Loader

文章目录 文章内容效果图代码 文章内容 效果图 代码 import QtQuick 2.15 import QtQuick.Window 2.15 import FluentUI import QtQuick.Controls 2.5Window {visible: truewidth: 320height: 240// 自定义组件:需要手动加载Component{id:comRectangle{id:rectwidth: 80heigh…

关闭Windows安全中心

打开Windows安全中心的病毒和威胁防护。 打开该选项的管理设置。 关闭实时保护。

【RTOS面试题】RTOS和Linux的区别

实时操作系统(RTOS, Real-Time Operating System)与Linux操作系统(一种典型的普通操作系统,General-Purpose Operating System, GPOS)之间存在一些显著的区别。这两种操作系统各有侧重,适用于不同的应用场景…

循环执行时数据的同步方式

在dataX-web中循环执行时数据的同步方式 解决中文comment中文乱码 在mysql中 # (0)修改库注释 alter table DBS modify column desc varchar(256) character set utf8; alter table DATABASE_PARAMS modify column PARAM_VALUE varchar(256) characte…

用python创建极坐标平面

极坐标的介绍 http://t.csdnimg.cn/ucau3http://t.csdnimg.cn/ucau3这个文章里可以知道极坐标的基本知识,接下来实现极坐标的绘制 PolarPlane 是 Manim(一个用于数学动画的Python库)中的一个类,用于创建极坐标平面。与笛卡尔…

汇昌联信数字做拼多多运营怎么做?

在当今电商竞争激烈的环境下,如何有效地在拼多多这样的平台上进行运营,是许多商家和品牌都在思考的问题。汇昌联信数字作为一家致力于提供数字化解决方案的公司,其在拼多多上的运营策略值得深入探讨。本文将详细分析汇昌联信数字在拼多多上的…

【HBZ分享】Spring启动时核心refresh方法流程

refresh核心代码所在位置 在AbstractApplicationContext类中的refresh方法中 refresh的业务流程编排 调用obtainFreshBeanFactory()去创建一个全新的BeanFactory工厂,类型为DefaultListableBeanFctory,其功能为【解析xml】将里面bean标签内容解析成【…

信息学奥林匹克竞赛详解-CSP、NOIP、NOI、IOI是什么

近年来,随着计算机在教育领域的影响力越来越大,信息学奥林匹克竞赛也越来越受关注。 山东省在2017年秋季正式出版了《小学信息技术》,大幅度引入了Scratch、Python等编程语言。 浙江省在2018年的高考选考科目中新增了信息技术,包…

【Qt】图形化和纯代码实现Hello world的比较

本篇文章使用俩种方式实现Qt上的Hello world: 通过图形化的方式,在界面上创建出一个控件,显式Hello world通过纯代码的方式,通过编写代码,在界面上创建控件,显示Hello world 图形化方式 双击Forms文件中的…

CTFHUB-web-RCE-读取源代码

开启题目 网页发现了源代码,还是和前几题一样是 php:// ,提示说 flag 在代码中,并且在 /flag 文件夹中,题目名字也叫读取源代码。 php://filter 是一种元封装器,专门用于数据流的过滤和筛选。与传统的文件操作函数相比…

selenium的UI自动化框架入门

环境准备 python、pycharme、chromedriver google下载的官网地址 https://google.cn/chrome/ chromedriver chromedriver的下载 https://chromedriver.storage.googleapis.com/index.html chromedriver配置环境变量 C:\Users\Administrator\.cache\selenium\chromedrive…

Python的安装环境以及应用

1.环境python2,Python 最新安装3.12可以使用源码安装 查看安装包 [rootpython001 ~]# yum list installed | grep epel 3[rootpython001 ~]# yum list installed | grep python [rootpython001 ~]# yum -y install python3 安装python3 查看版本 [root…

【LLM大模型】中国人工智能系列白皮书--大模型技术

近期,中国人工智能学会发布了 《2023 中国人工智能系列白皮书–大模型技术(2023版)》,涵盖了大模型发展历程、技术概述、风险与挑战以及未来发展展望等。 👉CSDN大礼包🎁:全网最全《LLM大模型入…