MATLAB代码解析:利用DCGAN实现图像数据的生成

news2025/2/25 11:00:19

摘要

经典代码:利用DCGAN生成花朵

MATLAB官方其实给出了DCGAN生成花朵的示范代码,原文地址:训练生成对抗网络 (GAN) - MATLAB & Simulink - MathWorks 中国

先看看训练效果

训练1周期

训练11周期

训练56个周期

脚本文件 

为了能让各位更好的复现,该代码已打包,下载后解压运行用MATLAB运行"gan.mlx"即可
链接: https://pan.baidu.com/s/1hNYLw1xku2AdKf5CanoFzA?pwd=fb7n 提取码: fb7n 
 

代码详解:

首先是脚本gan:

数据获取
clear all
clc
imageFolder = fullfile("flower_photos");
imds = imageDatastore(imageFolder,IncludeSubfolders=true);
augmenter = imageDataAugmenter(RandXReflection=true);
augimds = augmentedImageDatastore([64 64],imds,DataAugmentation=augmenter);
生成器
filterSize = 5;
numFilters = 64;
numLatentInputs = 100;

projectionSize = [4 4 512];%

layersGenerator = [
    featureInputLayer(numLatentInputs)
    projectAndReshapeLayer(projectionSize)
    transposedConv2dLayer(filterSize,4*numFilters)
    batchNormalizationLayer
    reluLayer
    transposedConv2dLayer(filterSize,2*numFilters,Stride=2,Cropping="same")
    batchNormalizationLayer
    reluLayer
    transposedConv2dLayer(filterSize,numFilters,Stride=2,Cropping="same")
    batchNormalizationLayer
    reluLayer
    transposedConv2dLayer(filterSize,3,Stride=2,Cropping="same")
    tanhLayer];
netG = dlnetwork(layersGenerator);
判别器
dropoutProb = 0.5;
numFilters = 64;
scale = 0.2;

inputSize = [64 64 3];
filterSize = 5;

layersDiscriminator = [
    imageInputLayer(inputSize,Normalization="none")
    dropoutLayer(dropoutProb)
    convolution2dLayer(filterSize,numFilters,Stride=2,Padding="same")
    leakyReluLayer(scale)
    convolution2dLayer(filterSize,2*numFilters,Stride=2,Padding="same")
    batchNormalizationLayer
    leakyReluLayer(scale)
    convolution2dLayer(filterSize,4*numFilters,Stride=2,Padding="same")
    batchNormalizationLayer
    leakyReluLayer(scale)
    convolution2dLayer(filterSize,8*numFilters,Stride=2,Padding="same")
    batchNormalizationLayer
    leakyReluLayer(scale)
    convolution2dLayer(4,1)
    sigmoidLayer];
netD = dlnetwork(layersDiscriminator);
指定训练选项
numEpochs = 500;
miniBatchSize = 128;
learnRate = 0.00008;
gradientDecayFactor = 0.5;
squaredGradientDecayFactor = 0.999;
flipProb = 0.35;
validationFrequency = 100;
训练模型
augimds.MiniBatchSize = miniBatchSize;

mbq = minibatchqueue(augimds, ...
    MiniBatchSize=miniBatchSize, ...
    PartialMiniBatch="discard", ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat="SSCB");
trailingAvgG = [];
trailingAvgSqG = [];
trailingAvg = [];
trailingAvgSqD = [];
numValidationImages = 25;
ZValidation = randn(numLatentInputs,numValidationImages,"single");
ZValidation = dlarray(ZValidation,"CB");
if canUseGPU
    ZValidation = gpuArray(ZValidation);
end

f = figure;
f.Position(3) = 2*f.Position(3);

imageAxes = subplot(1,2,1);
scoreAxes = subplot(1,2,2);

C = colororder;
lineScoreG = animatedline(scoreAxes,Color=C(1,:));
lineScoreD = animatedline(scoreAxes,Color=C(2,:));
legend("Generator","Discriminator");
ylim([0 1])
xlabel("Iteration")
ylabel("Score")
grid on

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs

    % Reset and shuffle datastore.
    shuffle(mbq);

    % Loop over mini-batches.
    while hasdata(mbq)
        iteration = iteration + 1;

        % Read mini-batch of data.
        X = next(mbq);

        % Generate latent inputs for the generator network. Convert to
        % dlarray and specify the format "CB" (channel, batch). If a GPU is
        % available, then convert latent inputs to gpuArray.
        Z = randn(numLatentInputs,miniBatchSize,"single");
        Z = dlarray(Z,"CB");

        if canUseGPU
            Z = gpuArray(Z);
        end

        % Evaluate the gradients of the loss with respect to the learnable
        % parameters, the generator state, and the network scores using
        % dlfeval and the modelLoss function.
        [L,~,gradientsG,gradientsD,stateG,scoreG,scoreD] = ...
            dlfeval(@modelLoss,netG,netD,X,Z,flipProb);
        netG.State = stateG;

        %%show data
        %"epoch"
        %epoch
        %"scoreG-D"
        %[scoreG,scoreD]

        % Update the discriminator network parameters.
        [netD,trailingAvg,trailingAvgSqD] = adamupdate(netD, gradientsD, ...
            trailingAvg, trailingAvgSqD, iteration, ...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);

        % Update the generator network parameters.
        [netG,trailingAvgG,trailingAvgSqG] = adamupdate(netG, gradientsG, ...
            trailingAvgG, trailingAvgSqG, iteration, ...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);
        % Every validationFrequency iterations, display batch of generated
        % images using the held-out generator input.
        if mod(iteration,validationFrequency) == 0 || iteration == 1
            % Generate images using the held-out generator input.
            XGeneratedValidation = predict(netG,ZValidation);

            % Tile and rescale the images in the range [0 1].
            I = imtile(extractdata(XGeneratedValidation));
            I = rescale(I);

            % Display the images.
            subplot(1,2,1);
            image(imageAxes,I)
            xticklabels([]);
            yticklabels([]);
            title("Generated Images");
        end

        % Update the scores plot.
        subplot(1,2,2)
        scoreG = double(extractdata(scoreG));
        addpoints(lineScoreG,iteration,scoreG);

        scoreD = double(extractdata(scoreD));
        addpoints(lineScoreD,iteration,scoreD);

        % Update the title with training progress information.
        D = duration(0,0,toc(start),Format="hh:mm:ss");
        title(...
            "Epoch: " + epoch + ", " + ...
            "Iteration: " + iteration + ", " + ...
            "Elapsed: " + string(D))

        drawnow
    end
end
生成新图像  
numObservations = 4;
ZNew = randn(numLatentInputs,numObservations,"single");
ZNew = dlarray(ZNew,"CB");
if canUseGPU
    ZNew = gpuArray(ZNew);
end

XGeneratedNew = predict(netG,ZNew);

I = imtile(extractdata(XGeneratedNew));
I = rescale(I);
figure
image(I)
axis off
title("Generated Images")

生成器与判别器的设计

脚本gan中以及包含了生成器Generator和判别器Discriminator的结构设计,生成器利用装置卷积对特征进行上采样,最终生成了64*64*3的图像,而判别器则用卷积进行下采样,将输入提取至1*1的格式大小,利用sigmoid作为激活函数,判断输入图像的真假

如何自定义生成对抗网络?很简单,把握上采样和下采样的规模就行,利用MATLAB的DLtool(deep network designer)可以很好的观察到这一点,以刚刚的生成器为例,我们可以观察到,转置卷积后(步幅为2),输出的空间(S)长宽都翻倍,深度对应我们给定的filters数量,因此,我们想要生成特定大小的数据时,修改转置卷积的步幅、卷积核数量以及转置卷积层的数量就行,同时记得在添加的转置卷积层后连接新的BN层和ReLU激活函数。

比如我想生成128*128*3的图片,我只需要将刚刚示例中的其中一个转置卷积核的大小提高至7*7,同时步幅修改成4。或者,我直接添加一层步幅为2的转置卷积层。对于一些数据尺寸为非2倍数问题,如311*171*3,我们可以先生成312*172*3再resize一下,或者你提前将数据预处理成312*172.

注意:定义完网络结构后,要用dlnetwork()函数将layer参数转变成可训练的dlnetwork。

最近比较忙,先在这里停笔了,后面再慢慢补充-24-10-14

数据预处理

自定义模型训练

损失函数与梯度下降

优化器与参数更新

总结

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2214688.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据库初体验

这两天我学习了数据库的一点知识,我觉得最大的不同就是数据库的代码只能一行一行的运行。 接下来记录我学的东西吧。 第一步 肯定是一些定义知识啦,就不记录了 有一些写一下,数据库的分类为关系型数据库和非关系型数据库 关系型数据库是把复…

Firefox火狐浏览器打开B站视频时默认静音

文章目录 环境问题解决办法 环境 Windows 11家庭版Firefox浏览器 131.0.2 (64 位) 问题 用Firefox浏览器打开B站的视频时,默认是静音播放的: 而其它浏览器,比如Chrome和Edge,默认是带声音播放的。 虽然不是什么大问题&#xf…

工具篇:(一)MacOS 下使用 Navicat 管理 MySQL 数据库:详细图文教程与常见问题解决

MacOS 下使用 Navicat 管理 MySQL 数据库:详细图文教程与常见问题解决 在这篇文章中,我将分享如何在 macOS 上使用 Navicat 来管理 MySQL 数据库。这是一份详细的教程,包括 Navicat 的下载、安装、配置以及使用步骤,并附上亲测的…

优化UVM环境(二)-将error/fatal红色字体打印,pass绿色字体打印

书接上回: 优化UVM环境(一)-环境结束靠的是timeout,而不是正常的objection结束 将error/fatal红色字体打印,pass绿色字体打印 红色字体的error: 31表示字体颜色是红色 1m表示加粗 绿色字体的pass&#…

高可用之限流-05-slide window 滑动窗口

限流系列 开源组件 rate-limit: 限流 高可用之限流-01-入门介绍 高可用之限流-02-如何设计限流框架 高可用之限流-03-Semaphore 信号量做限流 高可用之限流-04-fixed window 固定窗口 高可用之限流-05-slide window 滑动窗口 高可用之限流-06-slide window 滑动窗口 sen…

ReferenceError: MutationEvent is not defined

解决:关闭tampermonkey(篡改猴)插件后也不可以,移除tampermonkey(篡改猴)插件仔刷新就可以了

Linux:Ubuntu系统开启SSH服务

在Ubuntu上开启SSH服务,可以按照以下步骤进行: 1.安装OpenSSH服务 如果你还没有安装OpenSSH服务,可以使用以下命令安装: sudo apt update sudo apt install openssh-server2. 启动SSH服务 安装完成后,启动SSH服务&a…

Leetcode 721 账户合并

Leetcode 721 账户合并 给定一个列表 accounts,每个元素 accounts[i] 是一个字符串列表,其中第一个元素 accounts[i][0] 是 名称 (name),其余元素是 *emails * 表示该账户的邮箱地址。 现在,我们想合并这些账户。如果两个账户都…

jmeter在beanshell中使用props.put()方法的注意事项

在jmeter中,通常使用beanshell去处理一些属性的设置和获取的操作,而这些操作也是有一定的规则的。 1. 设置属性时,在属性名上要加双引号,这代表它不是一个需要用var去声明的变量 这种设置属性的方式才是有效可行的,在…

[权威出刊|稳定检索]2024年云计算、大数据与计算机应用技术国际会议(CCBDCAT 2024)

2024年云计算、大数据与计算机应用技术国际会议 2024 International Conference on Cloud Computing, Big Data, and Computer Application Technology 【1】大会信息 会议名称:2024年云计算、大数据与计算机应用技术国际会议 会议简称:CCBDCAT 2024 大…

【算法设计与分析】第2关:背包问题

任务描述 设有编号为0、1、2、…、n-1的n个物品,它们的重量分别为w0、w1、…、wn-1,价值分别为p0、p1、…、pn-1,其中wi、pi(0≤i≤n-1)均为正数。  有一个背包可以携带的最大重量不超过W。求解目标:在不…

C++类和对象——第三关

在阅读此篇文章之前,请先阅读博主之前的文章: C类和对象第一关-CSDN博客 C类和对象——第二关-CSDN博客 以便更好的理解本文章。 目录 运算符重载 (一)运算符重载 (二)赋值类运算符函数的重载&#x…

基于EBAZ4205矿板的图像处理:16基于小波变换的图像分解及其重建

基于EBAZ4205矿板的图像处理:17基于小波变换的图像分解及其重建 特别说明 这个项目的代码不会开源,因为这个项目的一大部分是在实习的公司做的,所以仅提供思路和展示,展示一下我的能力。 先看效果 这次让小牛和小绿做模特 经过…

C++模板初阶,只需稍微学习;直接起飞;泛型编程

🤓泛型编程 假设像以前交换两个函数需要,函数写很多个或者要重载很多个;那么有什么办法实现一个通用的函数呢? void Swap(int& x, int& y) {int tmp x;x y;y tmp; } void Swap(double& x, double& y) {doubl…

胤娲科技:AI短视频——创意无界,即梦启航

在这个快节奏的时代,你是否曾梦想过用几秒钟的短视频,捕捉生活中的每一个精彩瞬间?是否曾幻想过,即使没有专业的摄影和剪辑技能,也能创作出令人惊艳的作品? 现在,这一切都不再是遥不可及的梦想。…

微前端学习以及分享

微前端学习以及分享 注:本次分享demo的源码github地址:https://github.com/rondout/micro-frontend 什么是微前端 微前端的概念是由ThoughtWorks在2016年提出的,它借鉴了微服务的架构理念,核心在于将一个庞大的前端应用拆分成多…

从MySQL到OceanBase离线数据迁移的实践

本文作者:玉璁,OceanBase 生态产品技术专家。工作十余年,一直在基础架构与中间件领域从事研发工作。现负责OceanBase离线导数产品工具的研发工作,致力于为 OceanBase 建设一套完善的生态工具体系。 背景介绍 在互联网与云数据库技…

LEAP 瞬移工具场景试点游戏关卡

你是否厌倦了在Unity编辑器中浪费时间浏览大型游戏关卡?不要看得比Leap更远!这个功能强大的编辑器脚本允许您只需单击一下即可即时传输到场景中的任何位置。告别繁琐的手动导航,迎接闪电般快速的关卡设计。有了Leap,你就可以专注于…

Gin框架官方文档详解04:HTTP/2 推送,JSON相关

官方文档:https://gin-gonic.com/zh-cn/docs/ 注:强烈建议没用过Gin的读者先阅读第一节:第一个Gin应用。 目录 一、HTTP/2 推送二、JSONP三、PureJSON四、SecureJSON五、总结 一、HTTP/2 推送 首先,以“04HTTP2server推送”为根目…

linux 时区问题

一、修改系统时间和时区 查看当前下系统时间和时区 timedatectl设置系统时区 ​sudo timedatectl set-timezone <时区>​&#xff0c;例如&#xff1a; sudo timedatectl set-timezone Asia/Shanghai执行成功则没有输出。 更推荐使用 tzselect​ ​命令&#xff0c;…