G7 - Semi-Supervised GAN 理论与实战

news2024/11/28 20:42:02
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

目录

  • 理论知识
  • 模型实现
    • 引用、配置参数
    • 初始化权重
    • 定义算法模型
    • 模型配置
    • 模型训练
    • 训练模型
  • 模型效果
  • 总结与心得体会


理论知识

在条件GAN中,判别器只用来判断图像的真和假,到了条件GAN中,图像本身其实是有标签的,这时候我们可能会想,为什么不直接让判别器输出图像的标签呢?本节要探究的SGAN就实现了这样一个GAN网络。

SGAN将GAN拓展到半监督学习,通过强制判别器D来输出类别标签来实现。

SGAN在一个数据集上训练一个生成器G和一个判别器D,输入是N类中的一个,在训练的时候,判别器D也被用于预测输入是属于N+1类中的哪一个,这个N+1是对应了生成器G的输出,这里的判别器D同时也充当起了分类器C的效果。

经过实验发现,这种方法可以用于训练效果更好的判别器D,并且可以比普通的GAN产生更加高质量的样本。

Semi-Supervised GAN有如下成果:

  • 作者对GANs做了一个新的扩展,允许它同时学习一个生成模型和一个分类器,我们把这个拓展称为半监督GAN或者SGAN
  • 实验结果表明,SGAN在有限数据集上比没有生成部分的基准分类器提升了分类性能
  • 实验结果表明,SGAN可以显著地提升生成样本的质量并降低生成器的训练时间
    模型判别器工作示意
    效果对比
    对比生成效果发现,SGAN比普通的DCGAN算法的结果更好。

模型实现

引用、配置参数

import argparse
import os
import numpy as np
import math

from torchvision import datasets, transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

# 创建结果输出目录,没有就新增,有就跳过
os.makedirs('images', exist_ok=True)

# 参数
n_epochs = 50 # 训练轮数
batch_size = 64 # 每个批次的样本数量
lr = 0.0002 # 学习率
b1 = 0.5 # Adam优化器的第一个动量衰减参数
b2 = 0.999 # Adam 优化器的第二个动量衰减参数
n_cpu = 8 # 用于批次生成的CPU线程数
latent_dim = 100 # 潜在空间的维度
num_classes = 10 # 数据集的类别数
img_size = 32 # 每个图像的尺寸(高度和宽度相等)
channels = 1 # 图像的通道数(灰度图像通道数为1)
sample_interval = 400 # 图像采样间隔

# 全局设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

初始化权重

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv')!= -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.normal_(m.bias.data, 0.0)

定义算法模型

# 生成器
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        
        # 创建一个标签嵌入层,用于将条件标签映射到潜在空间
        self.label_emb = nn.Embedding(num_classes, latent_dim)

        # 初始化图像尺寸, 用于上采样之前
        self.init_size = img_size //4

        # 第一个全连接层,将随机噪声映射到合适的维度
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 128*self.init_size**2))

        # 生成器的卷积块
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, noise):
        out = self.l1(noise)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img
# 判别器,一个分类网络
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        def discriminator_block(in_filters, out_filters, bn=True):
            """返回每个鉴别器块的层"""
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block
            
        # 鉴别器的卷积块
        self.conv_blocks = nn.Sequential(
            *discriminator_block(channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # 下采样图像的高度和宽度
        ds_size = img_size // 2 ** 4

        # 输出层
        self.adv_layer = nn.Sequential(nn.Linear(128*ds_size**2, 1),nn.Sigmoid())
        self.aux_layer = nn.Sequential(nn.Linear(128*ds_size**2, num_classes + 1), nn.Softmax())

    def forward(self, img):
        out = self.conv_blocks(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        label = self.aux_layer(out)
        return validity, label

模型配置

# 定义损失函数

# 二元交叉熵损失,用于对抗训练
adversarial_loss = nn.BCELoss().to(device)
# 交叉熵损失,用于辅助分类
auxiliary_loss = nn.CrossEntropyLoss().to(device)

# 初始化生成器和鉴别器
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# 初始化模型权重
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# 配置数据加载器
os.makedirs('data/mnist', exist_ok=True)

dataloader = DataLoader(
    datasets.MNIST(
        'data/mnist',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.Resize(img_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ]),
    ),
    batch_size=batch_size,
    shuffle=True)

# 创建优化器
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

模型训练

for epoch in range(opt._epochs):
	for i, (imgs, labels) in enumerate(dataloader):
		batch_size = imgs.shape[0]
		# 生成对抗训练的标签
		valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False
		fake = Variable(FloatTensor(batch_size, 2).fill_(0,0)
		fake_aux_gt = Variable(LongTensor(batch_size).fill_(opt.num_classes, requires_grad)

训练模型

for epoch in range(n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):
        batch_size = imgs.shape[0]

        # 定义对抗训练的标签
        valid = torch.ones((batch_size, 1), requires_grad=False, device=device)
        fake = torch.zeros((batch_size, 1), requires_grad=False, device=device)
        fake_aux_gt= torch.ones((batch_size), dtype=torch.int64, device=device, requires_grad=False)*num_classes

        # 配置输入数据
        real_imgs = imgs.to(device)
        labels = labels.to(device)

        # ** 训练生成器 **

        optimizer_G.zero_grad()

        # 采样噪声和类别标签作为生成器的输入
        z = torch.rand([batch_size, latent_dim], device=device)
        
        # 生成一批图像
        gen_imgs = generator(z)

        # 计算生成器的损失衡量生成器欺骗鉴别器的能力
        validity, _ = discriminator(gen_imgs)
        g_loss = adversarial_loss(validity, valid)
        
        g_loss.backward()
        optimizer_G.step()

        # ** 训练鉴别器 **
        optimizer_D.zero_grad()

        # 真实图像的损失
        real_pred, real_aux = discriminator(real_imgs)
        d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2

        # 生成图像的损失
        fake_pred, fake_aux = discriminator(gen_imgs.detach())
        d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, fake_aux_gt)) /2

        # 总的鉴别器损失
        d_loss = (d_real_loss + d_fake_loss) / 2

        # 计算鉴别器准确率
        pred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)
        gt = np.concatenate([labels.data.cpu().numpy(), fake_aux_gt.data.cpu().numpy()], axis=0)
        d_acc = np.mean(np.argmax(pred, axis=1) == gt)

        d_loss.backward()
        optimizer_D.step()

        batches_done = epoch*len(dataloader) + i
        if batches_done % sample_interval == 0:
            save_image(gen_imgs.data[:25], 'images/%d.png' % batches_done, nrow=5, normalize=True)

    print('[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]' % (epoch, n_epochs, i, len(dataloader), d_loss.item(), 100*d_acc, g_loss.item()))

训练过程

[Epoch 0/50] [Batch 937/938] [D loss: 1.361139, acc: 50%] [G loss: 0.692496]
[Epoch 1/50] [Batch 937/938] [D loss: 1.339481, acc: 50%] [G loss: 0.743710]
[Epoch 2/50] [Batch 937/938] [D loss: 1.324114, acc: 50%] [G loss: 0.876165]
[Epoch 3/50] [Batch 937/938] [D loss: 1.163164, acc: 50%] [G loss: 1.803553]
[Epoch 4/50] [Batch 937/938] [D loss: 1.115438, acc: 50%] [G loss: 3.103937]
[Epoch 5/50] [Batch 937/938] [D loss: 1.045782, acc: 50%] [G loss: 4.149418]
[Epoch 6/50] [Batch 937/938] [D loss: 0.996207, acc: 56%] [G loss: 5.366407]
[Epoch 7/50] [Batch 937/938] [D loss: 0.944309, acc: 67%] [G loss: 5.629416]
[Epoch 8/50] [Batch 937/938] [D loss: 0.923516, acc: 71%] [G loss: 5.145832]
[Epoch 9/50] [Batch 937/938] [D loss: 0.941332, acc: 78%] [G loss: 3.469946]
[Epoch 10/50] [Batch 937/938] [D loss: 0.945412, acc: 75%] [G loss: 3.282685]
[Epoch 11/50] [Batch 937/938] [D loss: 0.862699, acc: 84%] [G loss: 3.509322]
[Epoch 12/50] [Batch 937/938] [D loss: 0.880701, acc: 87%] [G loss: 2.907838]
[Epoch 13/50] [Batch 937/938] [D loss: 0.853650, acc: 92%] [G loss: 4.008491]
[Epoch 14/50] [Batch 937/938] [D loss: 0.814380, acc: 93%] [G loss: 4.354833]
[Epoch 15/50] [Batch 937/938] [D loss: 0.907486, acc: 89%] [G loss: 4.128651]
[Epoch 16/50] [Batch 937/938] [D loss: 0.839670, acc: 87%] [G loss: 3.847980]
[Epoch 17/50] [Batch 937/938] [D loss: 1.118082, acc: 75%] [G loss: 3.573672]
[Epoch 18/50] [Batch 937/938] [D loss: 0.877845, acc: 87%] [G loss: 3.234770]
[Epoch 19/50] [Batch 937/938] [D loss: 1.176042, acc: 75%] [G loss: 4.499653]
[Epoch 20/50] [Batch 937/938] [D loss: 0.942495, acc: 84%] [G loss: 4.823555]
[Epoch 21/50] [Batch 937/938] [D loss: 0.874024, acc: 93%] [G loss: 3.880158]
[Epoch 22/50] [Batch 937/938] [D loss: 0.887224, acc: 90%] [G loss: 3.924105]
[Epoch 23/50] [Batch 937/938] [D loss: 0.876955, acc: 89%] [G loss: 4.332885]
[Epoch 24/50] [Batch 937/938] [D loss: 1.164700, acc: 79%] [G loss: 5.855463]
[Epoch 25/50] [Batch 937/938] [D loss: 0.824182, acc: 100%] [G loss: 3.745309]
[Epoch 26/50] [Batch 937/938] [D loss: 0.991236, acc: 87%] [G loss: 4.963309]
[Epoch 27/50] [Batch 937/938] [D loss: 0.906700, acc: 92%] [G loss: 5.675440]
[Epoch 28/50] [Batch 937/938] [D loss: 0.864558, acc: 93%] [G loss: 5.964598]
[Epoch 29/50] [Batch 937/938] [D loss: 0.788707, acc: 96%] [G loss: 7.074716]
[Epoch 30/50] [Batch 937/938] [D loss: 1.044333, acc: 84%] [G loss: 4.304685]
[Epoch 31/50] [Batch 937/938] [D loss: 0.797054, acc: 100%] [G loss: 5.197765]
[Epoch 32/50] [Batch 937/938] [D loss: 0.824380, acc: 100%] [G loss: 5.913801]
[Epoch 33/50] [Batch 937/938] [D loss: 0.978360, acc: 87%] [G loss: 3.314190]
[Epoch 34/50] [Batch 937/938] [D loss: 1.014248, acc: 78%] [G loss: 8.149563]
[Epoch 35/50] [Batch 937/938] [D loss: 1.352330, acc: 68%] [G loss: 8.068608]
[Epoch 36/50] [Batch 937/938] [D loss: 0.906131, acc: 89%] [G loss: 7.385222]
[Epoch 37/50] [Batch 937/938] [D loss: 0.813954, acc: 98%] [G loss: 5.816649]
[Epoch 38/50] [Batch 937/938] [D loss: 0.840815, acc: 98%] [G loss: 6.768866]
[Epoch 39/50] [Batch 937/938] [D loss: 0.864865, acc: 90%] [G loss: 2.277655]
[Epoch 40/50] [Batch 937/938] [D loss: 0.810660, acc: 93%] [G loss: 6.076533]
[Epoch 41/50] [Batch 937/938] [D loss: 1.189352, acc: 78%] [G loss: 4.746247]
[Epoch 42/50] [Batch 937/938] [D loss: 0.823831, acc: 90%] [G loss: 9.117917]
[Epoch 43/50] [Batch 937/938] [D loss: 0.975088, acc: 85%] [G loss: 2.690667]
[Epoch 44/50] [Batch 937/938] [D loss: 0.911645, acc: 89%] [G loss: 6.431296]
[Epoch 45/50] [Batch 937/938] [D loss: 1.214794, acc: 65%] [G loss: 5.860756]
[Epoch 46/50] [Batch 937/938] [D loss: 0.849733, acc: 98%] [G loss: 4.305855]
[Epoch 47/50] [Batch 937/938] [D loss: 0.910819, acc: 90%] [G loss: 6.148373]
[Epoch 48/50] [Batch 937/938] [D loss: 0.828892, acc: 96%] [G loss: 9.507065]
[Epoch 49/50] [Batch 937/938] [D loss: 1.086049, acc: 84%] [G loss: 5.026798]

模型效果

第一次输出的图像
第一次输出图像
25轮输出的图像
25轮输出的图像
最后输出的图像
最后输出的图像
通过图像可以发现,不同类型的数字间有很大区别,SGAN可以生成的很好

总结与心得体会

SGAN对于GAN的改进,更像是一个拥有着部分共同权重的一组小模型,可以让每个分类的图像生成的更加精确,避免生成的图像同时拥有着几种手势的特点,有点不伦不类。
目前的判别器中,无法对生成的图像打上准确的标签,这样应该会影响生成的精度,如果可以结合CGAN,直接让生成器也学习不同分类的特点,然后让判别器精确的区分,应该可以得到一个更精确的条件生成网络。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1848185.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

没有超头、最低价的视频号618战况如何?有何趋势变化?| 视频号618观察

转眼618大促已接近尾声,今年的你有剁手哪些好物吗?对618的整体感觉又是如何呢? 这是12年来,第一个电商平台没有预售付定金的618,当然或许此后的双11、每一次大促也将逐渐回归传统,回归本质。 而对于视频号来…

【八股系列】为什么组件中的 data 必须是一个函数,然后 return 一个对象,而 new Vue 实例里,data 可以直接是一个对象?

🎉 博客主页:【剑九 六千里-CSDN博客】 🎨 上一篇文章:【点击一个按钮,浏览器会做些什么事情【呈现效果时流程】?(js)】 🎠 系列专栏:【面试题-八股系列】 💖 感谢大家点…

深度学习前10节

1.机器学习的流程 (1)数据获取 (2)特征工程 (3)建立模型 (4)评估与应用 2.特征工程的作用 (1)数据特征决定了模型的上限 (2)预处理和特征提取是最核心的 &…

【盘点】8大电商选品思路,实操策略大公开!

1、以人选品 顾名思义,先确定想做的目标人群,再挖掘人群的需求。比如,小个子,这种细分市场,这里的人代表的是一个群体,可以是职业,可以是年龄段可以是一种称呼。如果未能明确目标市场和消费者需…

Linux CentoS安装RabbitMQ:一键安装指南

有两种安装方法,官方推荐使用 docker安装RabbitMQ 一、Docker安装RabbitMQ 1、安装docker 参考我之前的文章:Centos7.5搭建docker并且部署Lnmp环境(小白入门docoker)_centos7.5安装docker和docker-compose-CSDN博客 2、安装Ra…

【C++】数据类型、函数、头文件、断点调试、输入输出、条件与分支、VS项目设置

四、基本概念 这部分和C语言重复的部分就简写速过,因为我之前写过一个C语言的系列,非常详细。C和C这些都是一样的,所以这里不再一遍遍重复码字了。感兴趣的同学可以翻看我之前的C语言系列文章。 1、数据类型 编程的本质就是操作数据。 操…

容器之对齐构件

代码&#xff1a; #include <gtk-2.0/gtk/gtk.h> #include <glib-2.0/glib.h> #include <gtk-2.0/gdk/gdkkeysyms.h> #include <stdio.h>int main(int argc, char *argv[]) {gtk_init(&argc, &argv);GtkWidget *window;window gtk_window_ne…

Docker基本使用和认识

目录 基本使用 镜像仓库 镜像操作 Docker 如何实现镜像 1) namespace 2) cgroup 3) LXC Docker常见的网络类型 bridge网络如何实现 基本使用 镜像仓库 镜像仓库登录 1)docker login 后面不指定IP地址&#xff0c;则默认登录到 docker hub 上 退出 2)docker logo…

Latex的参考文献中显示三个问号???——解决办法

1、问题描述 在使用spring模板&#xff0c;并引用book时&#xff0c;末尾的引文地方出现三个???由于使用的bibtex是直接从谷歌学术中导出来的&#xff0c;其中仅包含作者&#xff0c;书名&#xff0c;出版社&#xff0c;年份等&#xff0c;缺少了重要的信息。结果导致在出版…

容器之布局容器的演示

代码; #include <gtk-2.0/gtk/gtk.h> #include <glib-2.0/glib.h> #include <gtk-2.0/gdk/gdkkeysyms.h> #include <stdio.h>void change_image(GtkFileChooserButton *filebutton, // GdkEvent *event,GtkImage *image) {gtk_image_set_from_file(im…

如何一键下载整个城市路网?

我们在《200城市的CAD建筑与路网下载》一文中&#xff0c;为你分享了下载CAD建筑与路网的方法。 现在&#xff0c;再为你分享一键下载整个城市路网地图的方法&#xff0c;并为你分享已经下载好的北京、上海、广州和深圳等几个城市的路网地图图片&#xff0c;请在文末查看获取该…

Linux开发讲课7---Linux sysfs文件系统

一、sysfs文件系统介绍 Sysfs&#xff08;System Filesystem&#xff09;是Linux内核提供的一种虚拟文件系统&#xff0c;用于向用户空间公开有关设备和驱动程序的信息。它类似于/proc文件系统&#xff0c;但是专注于设备和驱动程序信息&#xff0c;而非进程信息。 Sysfs通过文…

phar反序列化及绕过

目录 一、什么是phar phar://伪协议格式&#xff1a; 二、phar结构 1.stub phar&#xff1a;文件标识。 格式为 xxx; *2、manifest&#xff1a;压缩文件属性等信息&#xff0c;以序列化存 3、contents&#xff1a;压缩文件的内容。 4、signature&#xff1a;签名&#…

Android开发系列(五)Jetpack Compose之Icon Image

Icon是用于在界面上显示矢量图标的组件。它提供了很多内置的矢量图标&#xff0c;也支持自定义图标。要使用Icon组件&#xff0c;可以通过指定图标资源的名称或引用来创建一个Icon对象。例如&#xff0c;使用Icons.Default.Home来创建一个默认风格的首页图标。可以通过设置图标…

免费体验软件开发生产线 CodeArts

软件开发生产线 CodeArts 一站式、全流程、安全可信的软件开发生产线&#xff0c;开箱即用&#xff0c;内置华为多年研发最佳实践&#xff0c;助力效能倍增和数字化转型 免费试用体验版套餐&#xff0c;50人内免费试用 功能特性 Scrum和看板需求模型 代码托管 代码检查&am…

DN-DETR

可以看到&#xff0c;与 DAB-DETR 相比&#xff0c;最大的差别仍然在 decoder 处&#xff0c;主要是 query 的输入。DN-DETR 认为可以把对 offsets 的学习&#xff0c;看作一种对噪声学习的过程&#xff0c;因此&#xff0c;可以直接在 GT 周围生成一些 noised boxes&#xff0…

手写方法实现整型例如:123与字符串例如:“123“相互转化(下篇)

目录 一、前言 二、整型转化为字符串 1. 初始化变量 2.数字1转字符1 3.取出value中的每一项数字 4.将字符放入字符数组中 5.最终代码 三、最后 一、前言 本篇文章紧跟上篇文章&#xff0c;本片内容为整型转化为字符串类型。至于我为什么要分两篇文章&#xff0c;主要…

ATA-4051C高压功率放大器在压电电机中的作用是什么

压电电机是一种特殊的电机&#xff0c;其工作原理基于压电效应&#xff0c;这是一种将电能转化为机械振动的现象。压电电机通常用于精密定位、振动控制和声波生成等应用。为了驱动和控制压电电机&#xff0c;需要高压功率放大器。下面将介绍高压功率放大器在压电电机中的作用&a…

信创CPU秘史(上):大厂销售的路子有多野?

最近接到一份金融行业粉丝的投稿&#xff0c;内容之奇令人咋舌&#xff0c;尽是些闻所未闻的新知识。无论是内容本身&#xff0c;还是获取内容的渠道&#xff0c;都非常有意思。今年我们把舞台交给老金&#xff0c;一起来听听信创大厂间的那些小秘密。 大家好&#xff0c;我叫老…

你知道什么是微调吗?大模型为什么要微调?以及大模型微调的原理是什么?

“ 预训练(pretrain)微调(finetuning)&#xff0c;是目前主流的范式**”** 在学习大模型的过程中&#xff0c;怎么设计神经网络和怎么训练模型是一个重要又基础的操作。 但与之对应的微调也是一个非常重要的手段&#xff0c;这里就着重讲一下为什么要微调&#xff0c;其优点是…