ACGAN

news2024/12/23 1:55:27

CGAN通过在生成器和判别器中均使用标签信息进行训练,不仅能产生特定标签的数据,还能够提高生成数据的质量;SGAN(Semi-Supervised GAN)通过使判别器/分类器重建标签信息来提高生成数据的质量。既然这两种思路都可以提高生成数据的质量,于是ACGAN综合了以上两种思路,既使用标签信息进行训练,同时也重建标签信息,结合CGAN和SGAN的优点,从而进一步提升生成样本的质量,并且还能根据指定的标签相应的样本。

1. ACGAN的网络结构为:

ACGAN的网络结构框图

        生成器输入包含C_vector和Noise_data两个部分,其中C_vector为训练数据标签信息的One-hot编码张量,其形状为:(batch_size, num_class) ;Noise_data的形状为:(batch_size, latent_dim)。然后将两者进行拼接,拼接完成后,得到的输入张量为:(batch_size, num_class + latent_dim)。生成器的的输出张量为:(batch_size, channel, Height, Width)。

        判别器输入为:(batch_size, channel, Height, Width); 判别的器的输出为两部分,一部分是源数据真假的判断,形状为:(batch_size, 1),一部分是输入数据的分类结果,形状为:(batch_size, class_num)。因此判别器的最后一层有两个并列的全连接层,分别得到这两部分的输出结果,即判别器的输出有两个张量(真假判断张量和分类结果张量)。

2. ACGAN的损失函数:

        对于判别器而言,既希望分类正确,又希望能正确分辨数据的真假;对于生成器而言,也希望能够分类正确,当时希望判别器不能正确分辨假数据。

D_real, C_real = Discriminator( real_imgs)         # real_img 为输入的真实训练图片

D_real_loss = torch.nn.BCELoss(D_real, Y_real)          #  Y_real为真实数据的标签,真数据都为-1,假数据都为+1

C_real_loss = torch.nn.CrossEntropyLoss(C_real, Y_vec)        # Y_vec为训练数据One-hot编码的标签张量

gen_imgs = Generator(noise, Y_vec)

D_fake, C_fake = Discriminator(gen_imgs)

D_fake_loss = torch.nn.BCELoss(D_fake, Y_fake)

C_fake_loss = torch.nn.CrossEntropyLoss(C_fake, Y_vec)

D_loss = D_real_loss + C_real_loss + D_fake_loss + C_fake_loss

生成器的损失函数:  

gen_imgs = Generator(noise, Y_vec)

D_fake, C_fake = Discriminator(gen_imgs)

D_fake_loss = torch.nn.BCELoss(D_fake, Y_real)

C_fake_loss = torch.nn.CrossEntropyLoss(C_fake, Y_vec)

G_loss = D_fake_loss + C_fake_loss

class Discriminator(nn.Module):  # 定义判别器
    def __init__(self, img_size=(64, 64), num_classes=2):  # 初始化方法
        super(Discriminator, self).__init__()  # 继承初始化方法
 
        self.img_size = img_size  # 图片尺寸,默认为(64.64)三通道图片
        self.num_classes = num_classes  # 类别数
 
        self.conv1 = nn.Conv2d(3, 128, 4, 2, 1)  # conv操作
        self.conv2 = nn.Conv2d(128, 256, 4, 2, 1)  # conv操作
        self.bn2 = nn.BatchNorm2d(256)  # bn操作
        self.conv3 = nn.Conv2d(256, 512, 4, 2, 1)  # conv操作
        self.bn3 = nn.BatchNorm2d(512)  # bn操作
        self.conv4 = nn.Conv2d(512, 1024, 4, 2, 1)  # conv操作
        self.bn4 = nn.BatchNorm2d(1024)  # bn操作
        self.leakyrelu = nn.LeakyReLU(0.2)  # leakyrelu激活函数
        self.linear1 = nn.Linear(int(1024 * (self.img_size[0] / 2 ** 4) * (self.img_size[1] / 2 ** 4)), 1)  # linear映射
        self.linear2 = nn.Linear(int(1024 * (self.img_size[0] / 2 ** 4) * (self.img_size[1] / 2 ** 4)),
                                 self.num_classes)  # linear映射
        self.sigmoid = nn.Sigmoid()  # sigmoid激活函数
        self.softmax = nn.Softmax(dim=1)  # softmax激活函数
 
        self._init_weitghts()  # 模型权重初始化
 
    def _init_weitghts(self):  # 定义模型权重初始化方法
        for m in self.modules():  # 遍历模型结构
            if isinstance(m, nn.Conv2d):  # 如果当前结构是conv
                nn.init.normal_(m.weight, 0, 0.02)  # w采用正态分布初始化
                nn.init.constant_(m.bias, 0)  # b设为0
            elif isinstance(m, nn.BatchNorm2d):  # 如果当前结构是bn
                nn.init.constant_(m.weight, 1)  # w设为1
                nn.init.constant_(m.bias, 0)  # b设为0
            elif isinstance(m, nn.Linear):  # 如果当前结构是linear
                nn.init.normal_(m.weight, 0, 0.02)  # w采用正态分布初始化
                nn.init.constant_(m.bias, 0)  # b设为0
 
    def forward(self, x):  # 前传函数
        x = self.conv1(x)  # conv,(n,3,64,64)-->(n,128,32,32)
        x = self.leakyrelu(x)  # leakyrelu激活函数
        x = self.conv2(x)  # conv,(n,128,32,32)-->(n,256,16,16)
        x = self.bn2(x)  # bn操作
        x = self.leakyrelu(x)  # leakyrelu激活函数
        x = self.conv3(x)  # conv,(n,256,16,16)-->(n,512,8,8)
        x = self.bn3(x)  # bn操作
        x = self.leakyrelu(x)  # leakyrelu激活函数
        x = self.conv4(x)  # conv,(n,512,8,8)-->(n,1024,4,4)
        x = self.bn4(x)  # bn操作
        x = self.leakyrelu(x)  # leakyrelu激活函数
        x = torch.flatten(x, 1)  # 三维特征压缩至一位特征向量,(n,1024,4,4)-->(n,1024*4*4)
        # 根据特征向量x,计算图片真假的得分
        validity = self.linear1(x)  # linear映射,(n,1024*4*4)-->(n,1)
        validity = self.sigmoid(validity)  # sigmoid激活函数,将输出压缩至(0,1)
        # 根据特征向量x,计算图片分类的标签
        label = self.linear2(x)  # linear映射,(n,1024*4*4)-->(n,2)
        label = self.softmax(label)  # softmax激活函数,将输出压缩至(0,1)
 
        return (validity, label)  # 返回(图像真假的得分,图片分类的标签)
 
 
class Generator(nn.Module):  # 定义生成器
    def __init__(self, img_size=(64, 64), num_classes=2, latent_dim=100):  # 初始化方法
        super(Generator, self).__init__()  # 继承初始化方法
        self.img_size = img_size  # 图片尺寸,默认为(64.64)三通道图片
        self.num_classes = num_classes  # 类别数
        self.latent_dim = latent_dim  # 输入噪声长度,默认为100
 
        self.linear = nn.Linear(self.latent_dim, 4 * 4 * 1024)  # linear映射
        self.bn0 = nn.BatchNorm2d(1024)  # bn操作
        self.deconv1 = nn.ConvTranspose2d(1024, 512, 4, 2, 1)  # transconv操作
        self.bn1 = nn.BatchNorm2d(512)  # bn操作
        self.deconv2 = nn.ConvTranspose2d(512, 256, 4, 2, 1)  # transconv操作
        self.bn2 = nn.BatchNorm2d(256)  # bn操作
        self.deconv3 = nn.ConvTranspose2d(256, 128, 4, 2, 1)  # transconv操作
        self.bn3 = nn.BatchNorm2d(128)  # bn操作
        self.deconv4 = nn.ConvTranspose2d(128, 3, 4, 2, 1)  # transconv操作
        self.relu = nn.ReLU(inplace=True)  # relu激活函数
        self.tanh = nn.Tanh()  # tanh激活函数
        self.embedding = nn.Embedding(self.num_classes, self.latent_dim)  # embedding操作
 
        self._init_weitghts()  # 模型权重初始化
 
    def _init_weitghts(self):  # 定义模型权重初始化方法
        for m in self.modules():  # 遍历模型结构
            if isinstance(m, nn.ConvTranspose2d):  # 如果当前结构是transconv
                nn.init.normal_(m.weight, 0, 0.02)  # w采用正态分布初始化
                nn.init.constant_(m.bias, 0)  # b设为0
            elif isinstance(m, nn.BatchNorm2d):  # 如果当前结构是bn
                nn.init.constant_(m.weight, 1)  # w设为1
                nn.init.constant_(m.bias, 0)  # b设为0
            elif isinstance(m, nn.Linear):  # 如果当前结构是linear
                nn.init.normal_(m.weight, 0, 0.02)  # w采用正态分布初始化
                nn.init.constant_(m.bias, 0)  # b设为0
 
    def forward(self, input: tuple):  # 前传函数
        noise, label = input  # 从输入的元组中获取噪声向量和标签信息
        label = self.embedding(label)  # 标签信息经过embedding操作,变成与噪声向量尺寸相同的稠密向量
        z = torch.multiply(noise, label)  # 噪声向量与标签稠密向量相乘,得到带有标签信息的噪声向量
        z = self.linear(z)  # linear映射,(n,100)-->(n,1024*4*4)
        z = z.view((-1, 1024, int(self.img_size[0] / 2 ** 4),
                    int(self.img_size[1] / 2 ** 4)))  # 一维特征向量扩展至三维特征,(n,1024*4*4)-->(n,1024,4,4)
        z = self.bn0(z)  # bn操作
        z = self.relu(z)  # relu激活函数
        z = self.deconv1(z)  # trainsconv操作,(n,1024,4,4)-->(n,512,8,8)
        z = self.bn1(z)  # bn操作
        z = self.relu(z)  # relu激活函数
        z = self.deconv2(z)  # trainsconv操作,(n,512,8,8)-->(n,256,16,16)
        z = self.bn2(z)  # bn操作
        z = self.relu(z)  # relu激活函数
        z = self.deconv3(z)  # trainsconv操作,(n,256,16,16)-->(n,128,32,32)
        z = self.bn3(z)  # bn操作
        z = self.relu(z)  # relu激活函数
        z = self.deconv4(z)  # trainsconv操作,(n,128,32,32)-->(n,3,64,64)
        z = self.tanh(z)  # tanh激活函数,将输出压缩至(-1,1)
 
        return z  # 返回生成图像

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1050443.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Oracle - 多区间按权重取值逻辑

啰嗦: 其实很早就遇到过类似问题,也设想过,不过一致没实际业务需求,也就耽搁了;最近有业务提到了,和同事讨论,各有想法,所以先把逻辑整理出来,希望有更好更优的解决方案;…

传统遗产与技术相遇,古彝文的数字化与保护

古彝文是中国彝族的传统文字,具有悠久的历史和文化价值。然而,由于古彝文的形状复杂且没有标准化的字符集,对其进行文字识别一直是一项具有挑战性的任务。本文介绍了古彝文合合信息的文字识别技术,旨在提高古彝文的自动识别准确性…

linux 和 windows的換行符不兼容問題

linux 和 windows的換行符: 1.vim 模式下,執行命令: :set ffunix idea中設置code style

从零手搓一个【消息队列】项目设计、需求分析、模块划分、目录结构

文章目录 一、需求分析1, 项目简介2, BrokerServer 核心概念3, BrokerServer 提供的核心 API4, 交换机类型5, 持久化存储6, 网络通信7, TCP 连接的复用8, 需求分析小结 二、模块划分三、目录结构 提示:是正在努力进步的小菜鸟一只,如有大佬发现文章欠佳之…

Linux账户组管理及权限练习

1.使用id命令查看root账户信息 [rootserver ~]# id root 用户id0(root) 组id0(root) 组0(root) 2.使用id命令查看自己的普通账户信息 [rootserver ~]# id kxy 用户id1000(kxy) 组id1000(kxy) 组1000(kxy),10(wheel) 3.新建账户test1,并查看账户信息: [ro…

安装python扩展库

博主:命运之光 专栏:Python程序设计 Python扩展库安装 Python提供了丰富的标准库(不需要安装) ,还支持大量的第三方扩展库,它们数量众多、功能强大、涉及面广、使用方便,得到各行业领域工程师的…

千问的大模型KnowHow

卷友们好,我是rumor。 通义千问昨天放出了14b参数的模型,还有一份比较详尽的技术报告,包含作者们训练8个模型的宝贵经验。 同时他们开源的13B比起开源的SOTA也有不少提升: 今天我们就来一起白嫖,更多细节请移步原文&am…

Mybatis 日志(Apache Commons Logging)

之前我们介绍了使用JDK Log打印Mybatis运行时的日志;本篇我们介绍使用Apache Commons Logging打印Mybatis运行时的日志。 如何您对Mybatis中使用JDK Log不太了解,可以参考: Mybatis 日志(JDK Log)https://blog.csdn.net/m1729339749/articl…

上海市小机灵数学比赛回顾和五年级1-15届真题和答案学习资料

从2017年到现在,之前卷得非常厉害的上海市一系列与升学挂钩的竞赛如“小机灵杯、走美杯、希望杯、中环杯”等比赛都成为了竞赛历史的一部分。 尽管教育部门明确规定,学校不得将相关比赛的成绩作为学生评价和选拔的要素,但是许多家长仍按对于…

【STM32基础 CubeMX】从0带你点灯

文章目录 前言一、GPIO的概念二、CubeMX配置GPIO2.1 基础配置2.2 GPIO配置 三、点灯代码讲解3.1 cubemx生成的代码3.2 1个库函数 四、LED闪烁总结 前言 一、GPIO的概念 STM32是一系列微控制器芯片的品牌,它们用于控制各种电子设备。其中的GPIO是通用输入/输出端口的…

Spring IOC(控制反转)与DI(依赖注入)

定义 IOC(Inversion of Control),即控制反转:对象的创建控制权不再由程序来执行,而是交由给Spring容器处理。简单的说程序不需要进行new操作,对象直接由Spring容器自动创建。 DI(Dependency Injection),即依赖注入&am…

窗口类介绍

目录 Qwidget QDialog QMessageBox QFileDialog QFontDialog QColorDialog QInputDialog QProgressDialog QMainWindow 菜单栏 工具栏 状态栏 停靠窗口 窗口布局 Qwidget 常用的一些函数包括: 设置窗口的大小,尺寸,得到对应的…

蓝海彤翔亮相2023新疆网络文化节重点项目“新疆动漫节”

9月22日上午,2023新疆网络文化节重点项目“新疆动漫节”(以下简称“2023新疆动漫节”)在克拉玛依科学技术馆隆重开幕,蓝海彤翔作为国内知名的文化科技产业集团应邀参与此次活动,并在美好新疆e起向未来动漫展映区设置展…

Ubuntu为什么键盘会出现乱字符

今天上午起来只是要简单打一个命令,需要输入一个"双引号,但是总是显示,我一开始以为是中了病毒,把键盘给改了,后来发现虚惊一场:出现这个原因是因为ubuntu的键盘设置有问题。 我把键盘设置为英国英语…

C++简单实现红黑树

目录 一、概念 二、红黑树的性质 三、红黑树的定义 四、红黑树的插入操作 情况一(叔叔节点存在且为红色)——变色向上调整: 情况二(叔叔节点不存在或为黑色)——旋转变色: 2.1叔叔节点不存在 2.2叔叔…

在 SDXL 上用 T2I-Adapter 实现高效可控的文生图

T2I-Adapter 是一种高效的即插即用模型,其能对冻结的预训练大型文生图模型提供额外引导。T2I-Adapter 将 T2I 模型中的内部知识与外部控制信号结合起来。我们可以根据不同的情况训练各种适配器,实现丰富的控制和编辑效果。 同期的 ControlNet 也有类似的…

Windows Server 2012 R2 安装 .NET Framework 4.6.1

服务器操作系统是 Windows Server 2012 R2 版本,在安装 .NET Framework 4.6.1 过程中出现报错,报错截图如下: 通过上报报错可以发现是缺少对应的 KB2919355 更新,只有安装了此依赖才能在 Windows 8.1 或 Windows Server 2012 R2 …

中秋海报设计技巧大公开

中秋节即将来临,为了帮助大家设计出完美的海报,本文将提供详细的步骤和技巧,让你轻松打造出令人满意的海报作品。 步骤一:注册并登录乔拓云后台,进入海报中心页面。 在制作海报之前,你需要先注册并登录乔拓…

yolov5-6.0使用改进

代码版本V6.0 源码 YOLOv5 v6.0 release 改动速览 推出了新的 P5 和 P6 ‘Nano’ 模型: YOLOV5n和YOLOV5n6。 Nano 将 YOLOv5s 的深度倍数保持为 0.33,但将 YOLOv5 的宽度倍数从 0.50 降低到 0.25,从而将参数从 7.5M 降低到 1.9M&#xff0…

Linux shell 脚本中, $@ 和$# 分别是什么意思

Linux shell 脚本中, 和 和 和# 分别是什么意思? $:表示所有脚本参数的内容 $#:表示返回所有脚本参数的个数。 示例:编写如下shell脚本,保存为test.sh #!/bin/sh echo “number:$#” echo “argume:$” 执行…