第G7周:Semi-Supervised GAN 理论与实战

news2024/12/24 12:51:55

 🍨 本文为🔗365天深度学习训练营 中的学习记录博客

  🍦 参考文章:365天深度学习训练营-第G7周:Semi-Supervised GAN 理论与实战(训练营内部成员可读)

  🍖 原作者:K同学啊|接辅导、项目定制

 🏡 运行环境:
电脑系统:Windows 10
语言环境:python 3.10
编译器:Pycharm 2022.1.1
深度学习环境:Pytorch  


目录

一、理论知识讲解

二、代码实现

1、配置代码 

 2、初始化权重

3、定义算法模型

4、配置模型

 5、训练模型


一、理论知识讲解

该算法将产生式对抗网络(GAN) 拓展到半监督学习,通过强制判别器D来输出类别标签。我们
在一个数据集上训练一个生成器G以及一个判别器D,输入是N类当中的一个。在训练的时候,判别器D被用于预测输入是属于N+1类中的哪一个,这个N+1是对应了生成器G的输出,这里的判别器
D同时也充当起了分类器C的效果。这种方法可以用于训练效果更好的判别器D,并且可以比普通的GAN产性更加高质量的样本。Semi-Supervised GAN有如下优点:
(1)作者对GANs做了一个新的扩展,允许它同时学习一个生成模型和一个分类器。我们把这个 扩展叫做半监督GAN或SGAN
(2)论文实验结果表明,SGAN在有限数据集比没有生成部分的基准分类器提升了分类性能
(3)论文实验结果表明,SGAN可以显著地提升生成样本的质量并降低生成器的训练时间。 

二、代码实现

1、配置代码 
import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

os.makedirs("images", exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=2, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=2, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--num_classes", type=int, default=10, help="number of classes for dataset")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
opt = parser.parse_args(args=[])
print(opt)

cuda = True if torch.cuda.is_available() else False
Namespace(n_epochs=2, batch_size=64, lr=0.0002, b1=0.5, b2=0.999, n_cpu=2, latent_dim=100, num_classes=10, img_size=32, channels=1, sample_interval=400)
 2、初始化权重
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
3、定义算法模型
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.label_emb = nn.Embedding(opt.num_classes, opt.latent_dim)

        self.init_size = opt.img_size // 4  # Initial size before upsampling
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, noise):
        out = self.l1(noise)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            """Returns layers of each discriminator block"""
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.conv_blocks = nn.Sequential(
            *discriminator_block(opt.channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # The height and width of downsampled image
        ds_size = opt.img_size // 2 ** 4

        # Output layers
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())
        self.aux_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, opt.num_classes + 1), nn.Softmax())

    def forward(self, img):
        out = self.conv_blocks(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        label = self.aux_layer(out)

        return validity, label
4、配置模型
# Loss functions
adversarial_loss = torch.nn.BCELoss()
auxiliary_loss = torch.nn.CrossEntropyLoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()
    auxiliary_loss.cuda()

# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../../data/mnist\MNIST\raw\train-images-idx3-ubyte.gz
Extracting ../../data/mnist\MNIST\raw\train-images-idx3-ubyte.gz to ../../data/mnist\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../../data/mnist\MNIST\raw\train-labels-idx1-ubyte.gz
Extracting ../../data/mnist\MNIST\raw\train-labels-idx1-ubyte.gz to ../../data/mnist\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../../data/mnist\MNIST\raw\t10k-images-idx3-ubyte.gz
Extracting ../../data/mnist\MNIST\raw\t10k-images-idx3-ubyte.gz to ../../data/mnist\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../../data/mnist\MNIST\raw\t10k-labels-idx1-ubyte.gz
Extracting ../../data/mnist\MNIST\raw\t10k-labels-idx1-ubyte.gz to ../../data/mnist\MNIST\raw
 5、训练模型
# ----------
#  Training
# ----------

for epoch in range(opt.n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):

        batch_size = imgs.shape[0]

        # Adversarial ground truths
        valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)
        fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)
        fake_aux_gt = Variable(LongTensor(batch_size).fill_(opt.num_classes), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(FloatTensor))
        labels = Variable(labels.type(LongTensor))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise and labels as generator input
        z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        validity, _ = discriminator(gen_imgs)
        g_loss = adversarial_loss(validity, valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Loss for real images
        real_pred, real_aux = discriminator(real_imgs)
        d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2

        # Loss for fake images
        fake_pred, fake_aux = discriminator(gen_imgs.detach())
        d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, fake_aux_gt)) / 2

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2

        # Calculate discriminator accuracy
        pred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)
        gt = np.concatenate([labels.data.cpu().numpy(), fake_aux_gt.data.cpu().numpy()], axis=0)
        d_acc = np.mean(np.argmax(pred, axis=1) == gt)

        d_loss.backward()
        optimizer_D.step()

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)

    print(
        "[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]"
        % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), 100 * d_acc, g_loss.item())
    )
[Epoch 0/2] [Batch 937/938] [D loss: 1.358861, acc: 50%] [G loss: 0.671799]
[Epoch 1/2] [Batch 937/938] [D loss: 1.343094, acc: 50%] [G loss: 0.681119]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1168161.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

非递归方法实现二叉树前、中、后序遍历

文章目录 非递归实现二叉树前、中、后序遍历一、非递归实现前序遍历1.思路2.代码 二、非递归实现二叉树的中序遍历1.思路2.代码 三、非递归实现二叉树的后序遍历1.思路2.代码 非递归实现二叉树前、中、后序遍历 一、非递归实现前序遍历 1.思路 前序遍历的顺序是 :根…

【Linux】僵尸进程、孤儿进程的理解与验证

僵尸进程 概念 僵尸进程(Zombie Process)是指一个已经终止执行的子进程,但其父进程尚未调用 wait() 或 waitpid() 函数来获取子进程的退出状态。 Linux 中,僵尸进程会保留一些资源,如进程 ID、进程表项和一些系统资源…

C++11 initializer_list 轻量级初始化列表的使用场景(让自定义类可以用初始化列表的形式来实例化对象)

initializer_list 是 C11 中的一个特性&#xff0c;它允许你使用花括号 {} 中的值列表来初始化容器或数组。通常用于初始化标准库容器&#xff0c;比如 std::vector、std::set、std::map 以及数组。 场景一&#xff1a;用初始化列表初始化容器 std::vector<int> arr {…

JavaScript基础知识点速通

0 前言 本文是近期我学习JavaScript网课的笔记&#xff0c;一是方便自己速查回忆&#xff0c;二是希望帮到同样有需求的朋友们。 1 介绍 1.1 基本情况 JavaScript是一种编程语言&#xff0c;运行在客户端&#xff08;浏览器&#xff09;上&#xff0c;实现人机交互效果&…

【扩散模型】不同组件搭积木,获得新模型

学习地址&#xff1a; https://github.com/huggingface/diffusion-models-class/tree/main/unit3 VAE The Tokenizer and Text Encoder UNet In-Painting 例如&#xff1a;基于contrlnet做的校徽转图片

vue项目使用vite设置proxy代理,vite.config.js配置,解决本地跨域问题

vue3vite4项目&#xff0c;配置代理实现本地开发跨域问题 非同源请求&#xff0c;也就是协议(protocol)、端口(port)、主机(host)其中一项不相同的时候&#xff0c;这时候就会产生跨域 vite的proxy代理和vue-cli的proxy大致相同&#xff0c;需要在vite.config.js文件中配置&…

一、Linux开机、重启、关机和用户登录注销

1.【关机】 shutdown shutdown now 表示立即关机 shutdown -h now 表示立即关机 shutdown -h 1 表示1分钟后关机 halt 用来关闭正在运行的Linux操作系统 2.【重启】 shutdown -r now 表示立即重启 reboot 重启系统 sync …

设计模式之装饰模式--优雅的增强

目录 概述什么是装饰模式为什么使用装饰模式关键角色基本代码应用场景 版本迭代版本一版本二版本三—装饰模式 装饰模式中的巧妙之处1、被装饰对象和装饰对象共享相同的接口或父类2、当调用装饰器类的装饰方法时&#xff0c;会先调用被装饰对象的同名方法3、子类方法与父类方法…

中国人民大学与加拿大女王大学金融硕士—重要的是,你一直在努力

人虽然生下来就分三六九等&#xff0c;不同的人过着不同的生活&#xff0c;我的生活没办法选择&#xff0c;我只能尽我所能的让自己变得优秀。中国人民大学与加拿大女王大学金融硕士是我们无论怎样都可以变优秀的优质渠道。V13146152701 那么我们为什么要读研&#xff0c;读研…

串口通信代码整合1

本文为博主 日月同辉&#xff0c;与我共生&#xff0c;csdn原创首发。希望看完后能对你有所帮助&#xff0c;不足之处请指正&#xff01;一起交流学习&#xff0c;共同进步&#xff01; > 发布人&#xff1a;日月同辉,与我共生_单片机-CSDN博客 > 欢迎你为独创博主日月同…

OLE DB 访问接口所需的(最大)数据长度为 18,但返回的数据长度为 6。

sqlserver查询oracle链接服务器视图,报错 给最终返回的字符串进行类型转换,字符串大小按返回值最大的那个oracle源本字段类型长度 aaaaaa AS yljgbmcast(aaaaaa AS varchar(10)) AS yljgbm

Proteus仿真--1602LCD显示仿手机键盘按键字符(仿真文件+程序)

本文主要介绍基于51单片机的1602LCD显示仿手机键盘按键字符&#xff08;完整仿真源文件及代码见文末链接&#xff09; 仿真图如下 其中左下角12个按键模拟仿真手机键盘&#xff0c;使用方法同手机键一样&#xff0c;长按自动跳动切换键值&#xff0c;松手后确认选择&#xff…

API接口加密,解决自动化中登录问题

一、加密方式 AES&#xff1a;对称加密&#xff0c;快RAS&#xff1a;非对称加密&#xff0c;慢AESRAS&#xff1a;安全高效 加密过程&#xff1a;字符串》字节流》加密的字节流&#xff08;算法&#xff09;&#xff0c;解密有可能出现乱码&#xff0c;所以不能直接转成字符…

python+selenium自动化测试--鼠标悬停浮窗定位

页面上有些元素会隐藏起来&#xff0c;要鼠标放到某个位置才会显示出来&#xff0c;例如百度首页https://www.baidu.com/设置下面的隐藏按钮&#xff0c;如下图所示 定位鼠标悬停才显示的元素&#xff0c;要引入新模块&#xff0c;如下所示 from selenium.webdriver.common.ac…

12.JavaScript(WebAPI) - JS api文献精解

文章目录 1.WebAPI 背景知识1.1什么是 WebAPI1.2什么是 API1.3API 参考文档 2.DOM 基本概念2.1什么是 DOM2.2DOM 树 3.获取元素3.1querySelector3.2querySelectorAll 4.事件初识4.1基本概念4.2事件三要素4.3简单示例 5.操作元素5.1获取/修改元素内容5.1.1innerText5.1.2innerHT…

AD9371 官方例程裸机SW 和 HDL配置概述(二)

AD9371 系列快速入口 AD9371ZCU102 移植到 ZCU106 &#xff1a; AD9371 官方例程构建及单音信号收发 ad9371_tx_jesd -->util_ad9371_xcvr接口映射&#xff1a; AD9371 官方例程之 tx_jesd 与 xcvr接口映射 AD9371 官方例程 时钟间的关系与生成 &#xff1a; AD9371 官方…

javafaker测试数据生成实战

javafaker测试数据生成实战 1.背景2.介绍2.1 特点 3. 使用3.1 基础使用3.1.1 maven依赖3.1.1 使用示例 3.2 进阶使用3.1 生成中文信息3.2 根据姓名生成账号3.2.1 maven依赖3.2.2 中文转拼音工具类 3.3 高级使用3.3.1 中文性名重复处理方案1: 偷懒方式方案2: 较真模式 1.背景 最…

ChatGPT 被爆重大隐私泄露!在回答时突然蹦出陌生男子自拍照,你的数据都将被偷走训练模型!

ChatGPT 被爆重大隐私泄露 &#xff01; 一位用户在向 ChatGPT 询问 Python 中的代码格式化包 black 的用法时&#xff0c;没有一点点防备&#xff0c;ChatGPT 在回答中插入了一个陌生男子的自拍照&#xff08;手动捂脸.jpg&#xff09; 可以看到刚开始 ChatGPT 还相当正常&am…

智慧灯杆网关智能化选择(网关助力城市完整项目方案)

在当代城市发展中&#xff0c;智慧照明作为一项重要的技术创新&#xff0c;正逐渐改变着我们的城市生活。作为城市智慧照明的核心设备&#xff0c;智慧灯杆网关SG600凭借出色的性能和创新的解决方案&#xff0c;成为了引领城市智慧照明的完美选择。本文将详细介绍SG600的特点和…

linux centos7安装colmap

centos安装colmap 一、安装依赖 sudo yum install \gflags-devel \glog-devel \glew-devel \atlas \atlas-devel \lapack-devel \blas-devel \flann-devel \lz4-devel \sqlite-devel \metis-devel \qt5-qtbase-devel二、编译安装colmap git clone https://github.com/colmap/…