MATLAB环境下生成对抗网络系列(11种)

news2024/11/15 17:58:41

为了构建有效的图像深度学习模型,数据增强是一个非常行之有效的方法。图像的数据增强是一套使用有限数据来提高训练数据集质量和规模的数据空间解决方案。广义的图像数据增强算法包括:几何变换、颜色空间增强、核滤波器、混合图像、随机擦除、特征空间增强、对抗训练、生成对抗网络和风格迁移等内容。增强的数据代表一个分布覆盖性更广、可靠性更高的数据点集,使用增强数据能够有效增加训练样本的多样性,最小化训练集和验证集以及测试集之间的距离。使用数据增强后的数据集训练模型,可以达到提升模型稳定性、泛化能力的效果。

使用生成对抗网络GAN提取原数据集特征,对抗生成新的目标域图像,已成为众多学者在数据增强技术研究中的优选方法。相比于传统的图像数据增强方法,通过基于GAN的生成式建模技术进行数据增强的思路来源于博弈论中的二人零和博弈,由网络中包含的生成器和判别器利用对抗学习的方法来指导网络训练,在两个网络对抗过程中估计原始数据样本的分布并生成与之相似的新数据。

近期的研究在原始生成对抗网络框架的基础上又提出了多种不同的改进方案,通过设计不同的神经网络架构和损失函数等手段不断提升生成对抗网络的性能。生成对抗网络应用在图像数据增强任务上的思想主要是其通过生成新的训练数据来扩充模型的训练数据,通过样本空间的扩充实现图像分类等任务效果的提升。但目前基于GAN的图像数据增强技术普遍存在模型收敛不稳定、生成图像质量低等问题,如何正确引入高频信息,提升图像数据质量是破解这一系列问题的关键。

MATLAB环境配置如下:

  • MATLAB 2021b
  • Deep Learning Toolbox
  • Parallel Computing Toolbox (optional for GPU usage)

目录如下

  • Generative Adversarial Network (GAN) [paper]
  • Least Squares Generative Adversarial Network (LSGAN) [paper]
  • Deep Convolutional Generative Adversarial Network (DCGAN) [paper]
  • Conditional Generative Adversarial Network (CGAN)[paper]
  • Auxiliary Classifier Generative Adversarial Network (ACGAN) [paper]
  • InfoGAN [paper]
  • Adversarial AutoEncoder (AAE)[paper]
  • Pix2Pix[paper]
  • Wasserstein Generative Adversarial Network (WGAN) [paper]
  • Semi-Supervised Generative Adversarial Network (SGAN) [paper]
  • CycleGAN [paper]
  • DiscoGAN [paper]

部分代码如下:

首先,导入相关的mnist手写数字图

load('mnistAll.mat')

然后对训练、测试图像进行预处理

trainX = preprocess(mnist.train_images); 
trainY = mnist.train_labels;%训练标签
testX = preprocess(mnist.test_images); 
testY = mnist.test_labels;%测试标签

preprocess为归一化函数,如下

function x = preprocess(x)
x = double(x)/255;
x = (x-.5)/.5;
x = reshape(x,28*28,[]);
end

然后进行参数设置,包括潜变量空间维度,batch_size大小,学习率,最大迭代次数等等

settings.latent_dim = 10;
settings.batch_size = 32; settings.image_size = [28,28,1]; 
settings.lrD = 0.0002; settings.lrG = 0.0002; settings.beta1 = 0.5;
settings.beta2 = 0.999; settings.maxepochs = 50;

下面进行编码器初始化,代码还是很容易看懂的

paramsEn.FCW1 = dlarray(initializeGaussian([512,...
     prod(settings.image_size)],.02));
paramsEn.FCb1 = dlarray(zeros(512,1,'single'));
paramsEn.FCW2 = dlarray(initializeGaussian([512,512]));
paramsEn.FCb2 = dlarray(zeros(512,1,'single'));
paramsEn.FCW3 = dlarray(initializeGaussian([2*settings.latent_dim,512]));
paramsEn.FCb3 = dlarray(zeros(2*settings.latent_dim,1,'single'));

解码器初始化

paramsDe.FCW1 = dlarray(initializeGaussian([512,settings.latent_dim],.02));
paramsDe.FCb1 = dlarray(zeros(512,1,'single'));
paramsDe.FCW2 = dlarray(initializeGaussian([512,512]));
paramsDe.FCb2 = dlarray(zeros(512,1,'single'));
paramsDe.FCW3 = dlarray(initializeGaussian([prod(settings.image_size),512]));
paramsDe.FCb3 = dlarray(zeros(prod(settings.image_size),1,'single'));

判别器初始化

paramsDis.FCW1 = dlarray(initializeGaussian([512,settings.latent_dim],.02));
paramsDis.FCb1 = dlarray(zeros(512,1,'single'));
paramsDis.FCW2 = dlarray(initializeGaussian([256,512]));
paramsDis.FCb2 = dlarray(zeros(256,1,'single'));
paramsDis.FCW3 = dlarray(initializeGaussian([1,256]));
paramsDis.FCb3 = dlarray(zeros(1,1,'single'));

%平均梯度和平均梯度平方数组
avgG.Dis = []; avgGS.Dis = []; avgG.En = []; avgGS.En = [];
avgG.De = []; avgGS.De = [];

开始训练

dlx = gpdl(trainX(:,1),'CB');
dly = Encoder(dlx,paramsEn);
numIterations = floor(size(trainX,2)/settings.batch_size);
out = false; epoch = 0; global_iter = 0;
while ~out
    tic; 
    shuffleid = randperm(size(trainX,2));
    trainXshuffle = trainX(:,shuffleid);
    fprintf('Epoch %d\n',epoch) 
    for i=1:numIterations
        global_iter = global_iter+1;
        idx = (i-1)*settings.batch_size+1:i*settings.batch_size;
        XBatch=gpdl(single(trainXshuffle(:,idx)),'CB');

        [GradEn,GradDe,GradDis] = ...
                dlfeval(@modelGradients,XBatch,...
                paramsEn,paramsDe,paramsDis,settings);

        % 更新判别器网络参数
        [paramsDis,avgG.Dis,avgGS.Dis] = ...
            adamupdate(paramsDis, GradDis, ...
            avgG.Dis, avgGS.Dis, global_iter, ...
            settings.lrD, settings.beta1, settings.beta2);

        % 更新编码器网络参数
        [paramsEn,avgG.En,avgGS.En] = ...
            adamupdate(paramsEn, GradEn, ...
            avgG.En, avgGS.En, global_iter, ...
            settings.lrG, settings.beta1, settings.beta2);
        
        % 更新解码器网络参数
        [paramsDe,avgG.De,avgGS.De] = ...
            adamupdate(paramsDe, GradDe, ...
            avgG.De, avgGS.De, global_iter, ...
            settings.lrG, settings.beta1, settings.beta2);
        
        if i==1 || rem(i,20)==0
            progressplot(paramsDe,settings);
            if i==1 
                h = gcf;
                % 捕获图像
                frame = getframe(h); 
                im = frame2im(frame); 
                [imind,cm] = rgb2ind(im,256); 
                % 写入 GIF 文件
                if epoch == 0
                  imwrite(imind,cm,'AAEmnist.gif','gif', 'Loopcount',inf); 
                else 
                  imwrite(imind,cm,'AAEmnist.gif','gif','WriteMode','append'); 
                end 
            end
        end
        
    end

    elapsedTime = toc;
    disp("Epoch "+epoch+". Time taken for epoch = "+elapsedTime + "s")
    epoch = epoch+1;
    if epoch == settings.maxepochs
        out = true;
    end    
end

下面是完整的辅助函数

模型的梯度计算函数

function [GradEn,GradDe,GradDis]=modelGradients(x,paramsEn,paramsDe,paramsDis,settings)
dly = Encoder(x,paramsEn);
latent_fake = dly(1:settings.latent_dim,:)+...
    dly(settings.latent_dim+1:2*settings.latent_dim)*...
    randn(settings.latent_dim,settings.batch_size);
latent_real = gpdl(randn(settings.latent_dim,settings.batch_size),'CB');

%训练判别器
d_output_fake = Discriminator(latent_fake,paramsDis);
d_output_real = Discriminator(latent_real,paramsDis);
d_loss = -.5*mean(log(d_output_real+eps)+log(1-d_output_fake+eps));

%训练编码器和解码器
x_ = Decoder(latent_fake,paramsDe);
g_loss = .999*mean(mean(.5*(x_-x).^2,1))-.001*mean(log(d_output_fake+eps));

%对于每个网络,计算关于损失函数的梯度
[GradEn,GradDe] = dlgradient(g_loss,paramsEn,paramsDe,'RetainData',true);
GradDis = dlgradient(d_loss,paramsDis);
end

提取数据函数

function x = gatext(x)
x = gather(extractdata(x));
end

GPU深度学习数组wrapper函数

function dlx = gpdl(x,labels)
dlx = gpuArray(dlarray(x,labels));
end

权重初始化函数

function parameter = initializeGaussian(parameterSize,sigma)
if nargin < 2
    sigma = 0.05;
end
parameter = randn(parameterSize, 'single') .* sigma;
end

dropout函数

function dly = dropout(dlx,p)
if nargin < 2
    p = .3;
end
[n,d] = rat(p);
mask = randi([1,d],size(dlx));
mask(mask<=n)=0;
mask(mask>n)=1;
dly = dlx.*mask;
end

编码器函数

function dly = Encoder(dlx,params)
dly = fullyconnect(dlx,params.FCW1,params.FCb1);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW2,params.FCb2);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW3,params.FCb3);
dly = leakyrelu(dly,.2);
end

解码器函数

function dly = Decoder(dlx,params)
dly = fullyconnect(dlx,params.FCW1,params.FCb1);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW2,params.FCb2);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW3,params.FCb3);
dly = leakyrelu(dly,.2);
dly = tanh(dly);
end

判别器函数

function dly = Discriminator(dlx,params)
dly = fullyconnect(dlx,params.FCW1,params.FCb1);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW2,params.FCb2);
dly = leakyrelu(dly,.2);
dly = fullyconnect(dly,params.FCW3,params.FCb3);
dly = sigmoid(dly);
end

工学博士,担任《Mechanical System and Signal Processing》审稿专家,担任
《中国电机工程学报》优秀审稿专家,《控制与决策》,《系统工程与电子技术》,《电力系统保护与控制》,《宇航学报》等EI期刊审稿专家。

擅长领域:现代信号处理,机器学习,深度学习,数字孪生,时间序列分析,设备缺陷检测、设备异常检测、设备智能故障诊断与健康管理PHM等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1439002.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

寒假作业2024.2.6

1.现有无序序列数组为23,24,12,5,33,5347&#xff0c;请使用以下排序实现编程 函数1:请使用冒泡排序实现升序排序 函数2:请使用简单选择排序实现升序排序 函数3:请使用直接插入排序实现升序排序 函数4:请使用插入排序实现升序排序 #include <stdio.h> #include <stdl…

一个坐标系查询网站python获取所有坐标系

技术路线选择 我是使用的vue 3开发的网页界面&#xff0c;element-plus构建网页组件&#xff0c;openlayer展示地图&#xff0c;express提供后端API&#xff0c;vercel进行在线部署。 python获取所有坐标系 想要展示所有坐标系&#xff0c;那需要先获取坐标系&#xff0c;怎么…

【开源】基于JAVA+Vue+SpringBoot的贫困地区人口信息管理系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 人口信息管理模块2.2 精准扶贫管理模块2.3 特殊群体管理模块2.4 案件信息管理模块2.5 物资补助模块 三、系统设计3.1 用例设计3.2 数据库设计3.2.1 人口表3.2.2 扶贫表3.2.3 特殊群体表3.2.4 案件表3.2.5 物资补助表 四…

机器人学、机器视觉与控制 上机笔记(2.1章节)

机器人学、机器视觉与控制 上机笔记&#xff08;2.1章节&#xff09; 1、前言2、本篇内容3、代码记录3.1、新建se23.2、生成坐标系3.3、将T1表示的变换绘制3.4、完整绘制代码3.5、获取点*在坐标系1下的表示3.6、相对坐标获取完整代码 4、结语 1、前言 工作需要&#xff0c;想同…

HTTP协议笔记

HTTP协议笔记 参考&#xff1a; &#xff08;建议精读&#xff09;HTTP灵魂之问&#xff0c;巩固你的 HTTP 知识体系 《透视 HTTP 协议》——chrono 目录&#xff1a; 1、说说你对HTTP的了解吧。  1. HTTP状态码。  2. HTTP请求头和响应头&#xff0c;其中包括cookie、跨域响…

AcWing 1238 日志统计(双指针算法)

题目概述 小明维护着一个程序员论坛。现在他收集了一份”点赞”日志&#xff0c;日志共有 N 行。 其中每一行的格式是&#xff1a; ts id表示在 ts 时刻编号 id 的帖子收到一个”赞”。 现在小明想统计有哪些帖子曾经是”热帖”。 如果一个帖子曾在任意一个长度为 D 的时间段…

《MySQL 简易速速上手小册》第1章:MySQL 基础和安装(2024 最新版)

文章目录 1.1 MySQL 概览&#xff1a;版本、特性和生态系统1.1.1 基础知识1.1.2 重点案例1.1.3 拓展案例 1.2 安装和配置 MySQL1.2.1 基础知识1.2.2 安装步骤1.2.3 重点案例1.2.4 拓展案例 1.3 基础命令和操作1.3.1 基础知识1.3.2 重点案例1.3.3 拓展案例 1.1 MySQL 概览&#…

JUC ThreadLocal

文章目录 ThreadLocal ^1.2^ 的作用使用场景示例1ThreadLocal 变量初始化ThreadLocal 源码分析源码分析总结 内存泄漏问题示例说明new Thread 方式 执行结果pool 方式执行结果原因解析总结 ThreadLocal 1.2 的作用 ThreadLocal 为每个线程提供单独的变量副本。每个变量副本都是…

史上最全嵌入式(学习路线、应用开发、驱动开发、推荐书籍、软硬件基础)

废话不多说直接上思维导图&#xff01; 如果有觉得图片看不清楚的&#xff0c;有疑问的&#xff0c;可在评论区进行留言&#xff01; 群号&#xff1a; 228447240 嵌入式总括 嵌入式书籍推荐 嵌入式软件知识 嵌入式硬件知识 嵌入式应用开发 嵌入式驱动开发 嵌入式视频推荐: 韦…

WebSocket相关问题

1.WebSocket是什么&#xff1f;和HTTP的区别&#xff1f; WebSocket是一种基于TCP连接的全双工通信协议&#xff0c;客户端和服务器仅需要一次握手&#xff0c;两者之间就可以创建持久性的连接&#xff0c;并且支持双向数据的传输。WebSocket和HTTP都是基于TCP的应用层协议&am…

【PyTorch][chapter 15][李宏毅深度学习][Neighbor Embedding-LLE]

前言&#xff1a; 前面讲的都是线性降维&#xff0c;本篇主要讨论一下非线性降维. 流形学习&#xff08;mainfold learning&#xff09;是一类借鉴了拓扑流行概念的降维方法. 如上图,欧式距离上面 A 点跟C点更近&#xff0c;距离B 点较远 但是从图形拓扑结构来看&#xff0c; …

书生·浦语大模型全链路开源体系

1&#xff0c;简述大模型的定义与特点&#xff1a; 大模型是指参数数量大于10亿的模型&#xff0c;它的特点包括&#xff1a;模型规模大&#xff0c;数据规模大&#xff0c;计算规模大和任务数量 2. 分析大模型成为通用人工智能的重要途径的原因&#xff1a; 大模型能够从大…

2023年的技术变革,我不是破坏大环境的人

文章目录 前言2023年的技术变革人工智能的崛起元宇宙的跌落物联网的渗入 技术变革的背后技术变革的影响积极的影响负面的影响 技术变革带来的思考 前言 2023无疑是一个充满变革和创新的一年&#xff0c;这背后离不开技术的发展和进步。不论是人工智能的崛起&#xff0c;还是元…

[word] word表格内容自动编号 #经验分享#微信#其他

word表格内容自动编号 在表格中的内容怎么样自动编号&#xff1f;我们都知道Word表格和Excel表格有所不同&#xff0c;Excel表格可以轻松自动编号&#xff0c;那么在Word表格中如何自动编号呢&#xff1f; 1、选中内容后&#xff0c;点击段落-自动编号&#xff0c;选择其中一…

数据结构——C/栈和队列

&#x1f308;个人主页&#xff1a;慢了半拍 &#x1f525; 创作专栏&#xff1a;《史上最强算法分析》 | 《无味生》 |《史上最强C语言讲解》 | 《史上最强C练习解析》 &#x1f3c6;我的格言&#xff1a;一切只是时间问题。 ​ 1.栈 1.1栈的概念及结构 栈&#xff1a;一种特…

计算机缺失concrt140.dll怎么修复?分享5种有效的修复方法

在计算机系统运行过程中&#xff0c;如果发现无法找到“concrt140.dll”这个特定的动态链接库文件&#xff0c;可能会引发一系列问题和故障。首先&#xff0c;我们需要了解“concrt140.dll”是Microsoft Visual Studio中用于实现并行计算框架的重要组件&#xff0c;它的缺失会导…

HarmonyOS 鸿蒙应用开发(十、第三方开源js库移植适配指南)

在前端和nodejs的世界里&#xff0c;有很多开源的js库&#xff0c;通过npm(NodeJS包管理和分发工具)可以安装使用众多的开源软件包。但是由于OpenHarmony开发框架中的API不完全兼容V8运行时的Build-In API&#xff0c;因此三方js库大都需要适配下才能用。 移植前准备 建议在适…

RabbitMQ的延迟队列实现[死信队列](笔记二)

上一篇已经讲述了实现死信队列的rabbitMQ服务配置&#xff0c;可以点击: RabbitMQ的延迟队列实现(笔记一) 目录 搭建一个新的springboot项目模仿订单延迟支付过期操作启动项目进行测试 搭建一个新的springboot项目 1.相关核心依赖如下 <dependency><groupId>org.…

Linux | 进度条 | Linux简单小程序 | 超级简单 | 这一篇就够了

进度条—实例示范 在学习了基本的Linux指令&#xff0c;Linux上vim编译器等等之后&#xff0c;我们就来学习写代码喽~ 今天就给大家详细讲解一下进度条的编写&#xff0c;需要的效果如下图&#xff1a; 进度条—必备知识 回车和换行 在我们学习编程语言中&#xff0c;经常…

【力扣 - 回文链表】

题目描述 给你一个单链表的头节点 head &#xff0c;请你判断该链表是否为回文链表。如果是&#xff0c;返回 true &#xff1b;否则&#xff0c;返回 false 。 提示&#xff1a; 链表中节点数目在范围[1, 100000] 内 0 < Node.val < 9 方法一&#xff1a;将值复制到数…