Paddle 实现DCGAN

news2025/1/11 23:39:39

传统GAN

传统的GAN可以看我的这篇文章:Paddle 基于ANN(全连接神经网络)的GAN(生成对抗网络)实现-CSDN博客

DCGAN

DCGAN是适用于图像生成的GAN,它的特点是:

  • 只采用卷积层和转置卷积层,而不采用全连接层
  • 在每个卷积层或转置卷积层之间,插入一个批归一化层和ReLU激活函数

转置卷积层

转置卷积层执行的是转置卷积或反卷积的操作,即它是常规卷积层的反向操作。它接收一个低分辨率的输入,然后将其通过转置滤波器升采样到更高的分辨率。

对于一个卷积层,它的输出大小公式是:

o = \frac{i + 2p - k}{s} + 1

其中,o表示输出大小,i表示输入大小,p表示填充(padding),k表示卷积核大小(kernel_size),s表示步长(stride)。也就是说:输出大小 = (输入大小 - 卷积核大小 + 2 × 填充数) ÷ 步长 + 1

而对于一个转置卷积层,它的输出大小公式是:

o = s(i-1)-2p+k+u

 其中,o表示输出大小,i表示输入大小,p表示填充(padding),k表示反卷积核大小(kernel_size),s表示步长(stride),u表示输出填充(output padding)。也就是说:输出大小 = (输入大小 - 1) * 步长 - 2*填充 + 反卷积大小 + 输出填充

在paddle中,转置卷积层可以这么定义:

paddle.nn.Conv2DTranspose(in_channels, out_channels, kernel_size, stride, padding)

像卷积层一样,反卷积层的in_channels表示输入通道数(如形如(3, 32, 32)的图片张量的通道数就是3),out_channels表示输出通道数(如把(64, 32, 32)变成3通道的彩色图像(3, 32, 32))。 

代码实现

这里我们采用NWPU-RESISC45数据集,从中选择“freeway”(高速公路)作为训练数据,让机器生成高速公路的图片。这个训练数据内有700张256x256的图片,但由于我的电脑显存不足,因此将图片大小设置为64x64.

先写dataset.py:

import paddle
import numpy as np
from PIL import Image
import os


def getAllPath(path):
    return [os.path.join(path, f) for f in os.listdir(path)]


class FreewayDataset(paddle.io.Dataset):

    def __init__(self, transform=None):
        super().__init__()
        self.data = []
        for path in getAllPath('./freeway'):
            img = Image.open(path)
            img = img.resize((64, 64))
            img = np.array(img, dtype=np.float32).transpose((2, 1, 0))
            if transform is not None:
                img = transform(img)
            self.data.append(img)
        self.data = np.array(self.data, dtype=np.float32)

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

然后写训练脚本:

from dataset import FreewayDataset
import paddle
from models import Generator, Discriminator
import numpy as np

dataset = FreewayDataset()
dataloader = paddle.io.DataLoader(dataset, batch_size=32, shuffle=True)

netG = Generator()
netD = Discriminator()

if 1:
    try:
        mydict = paddle.load('generator.params')
        netG.set_dict(mydict)
        mydict = paddle.load('discriminator.params')
        netD.set_dict(mydict)
    except:
        print('fail to load model')

loss = paddle.nn.BCELoss()

optimizerD = paddle.optimizer.Adam(parameters=netD.parameters(), learning_rate=0.0002, beta1=0.5, beta2=0.999)
optimizerG = paddle.optimizer.Adam(parameters=netG.parameters(), learning_rate=0.0002, beta1=0.5, beta2=0.999)

# 最大迭代epoch
max_epoch = 1000

for epoch in range(max_epoch):
    now_step = 0
    for step, data in enumerate(dataloader):
        ############################
        # (1) 更新鉴别器
        ###########################

        # 清除D的梯度
        optimizerD.clear_grad()

        # 传入正样本,并更新梯度
        pos_img = data
        label = paddle.full([pos_img.shape[0], 1, 1, 1], 1, dtype='float32')
        pre = netD(pos_img)
        loss_D_1 = loss(pre, label)
        loss_D_1.backward()

        # 通过randn构造随机数,制造负样本,并传入D,更新梯度
        noise = paddle.randn([pos_img.shape[0], 100, 1, 1], 'float32')
        neg_img = netG(noise)
        label = paddle.full([pos_img.shape[0], 1, 1, 1], 0, dtype='float32')
        pre = netD(neg_img.detach())  # 通过detach阻断网络梯度传播,不影响G的梯度计算
        loss_D_2 = loss(pre, label)
        loss_D_2.backward()

        # 更新D网络参数
        optimizerD.step()
        optimizerD.clear_grad()

        loss_D = loss_D_1 + loss_D_2

        ############################
        # (2) 更新生成器
        ###########################

        # 清除D的梯度
        optimizerG.clear_grad()

        noise = paddle.randn([pos_img.shape[0], 100, 1, 1], 'float32')
        fake = netG(noise)
        label = paddle.full((pos_img.shape[0], 1, 1, 1), 1, dtype=np.float32, )
        output = netD(fake)
        # 这个写法没有问题,因为这个loss既会影响到netG(output=netD(netG(noise)))的梯度,也会影响到netD的梯度,但是之后的代码并没有更新netD的参数,而循环开头就清除了netD的梯度
        loss_G = loss(output, label)
        loss_G.backward()

        # 更新G网络参数
        optimizerG.step()
        optimizerG.clear_grad()

        now_step += 1

        ###########################
        # 输出日志
        ###########################
        if now_step % 10 == 0:
            print(f'Epoch ID={epoch} Batch ID={now_step} \n\n D-Loss={float(loss_D)} G-Loss={float(loss_G)}')

paddle.save(netG.state_dict(), "generator.params")
paddle.save(netD.state_dict(), "discriminator.params")

 最后编写图片生成脚本:

import paddle
from models import Generator
import matplotlib.pyplot as plt

# 加载模型
netG = Generator()
mydict = paddle.load('generator.params')
netG.set_dict(mydict)

# 设置matplotlib的显示环境
fig, axs = plt.subplots(nrows=2, ncols=5, figsize=(15, 6))  # 创建一个2x5的子图网格

# 生成10个噪声向量
for i, ax in enumerate(axs.flatten()):
    noise = paddle.randn([1, 100, 1, 1], 'float32')
    img = netG(noise)
    img = img.numpy()[0].transpose((2, 1, 0))  # img.numpy():张量转np数组
    img[img < 0] = 0  # 将img中所有小于0的元素赋值为0

    # 显示图片
    ax.imshow(img)
    ax.axis('off')  # 不显示坐标轴

# 显示图像
plt.show()

经过数次训练,最终的效果如下:

这样看来,至少有点高速公路的感觉了。 

参考

通过DCGAN实现人脸图像生成-使用文档-PaddlePaddle深度学习平台

卷积层和反卷积层输出特征图大小计算_输出特征图大小的计算方法-CSDN博客 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1661310.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何编译不同目录下的两个文件

1.直接编译 2.打包成动静态库进行链接

【Shell脚本】Shell编程之循环语句

目录 一.循环语句 1.for语句的结构 1.1.格式 1.2.实操案例 案例1. 案例2. 案例3. 案例4. 2.while语句的结构 2.1.格式 2.2.实操案例 案例1. 案例2. 案例3. 案例4. 3.until循环命令 3.1.格式 3.2.实操案例 案例1. 二.补充 1.常用转义符 一.循环语句 1.for…

鸿蒙内核源码分析(任务切换篇) | 看汇编如何切换任务

在鸿蒙的内核线程就是任务&#xff0c;系列篇中说的任务和线程当一个东西去理解. 一般二种场景下需要切换任务上下文: 在线程环境下&#xff0c;从当前线程切换到目标线程&#xff0c;这种方式也称为软切换&#xff0c;能由软件控制的自主式切换.哪些情况下会出现软切换呢? 运…

Leetcode—138. 随机链表的复制【中等】(cend函数)

2024每日刷题&#xff08;129&#xff09; Leetcode—138. 随机链表的复制 实现代码 /* // Definition for a Node. class Node { public:int val;Node* next;Node* random;Node(int _val) {val _val;next NULL;random NULL;} }; */class Solution { public:Node* copyRan…

【强训笔记】day18

NO.1 思路&#xff1a;双指针模拟。to_string将数字转化为字符。 代码实现&#xff1a; class Solution { public:string compressString(string param) {int left0,right0,nparam.size();string ret;while(right<n){while(right1<n&&param[right]param[right…

我在洛杉矶采访到了亚马逊云全球首席信息官CISO(L11)!

在本次洛杉矶举办的亚马逊云Re:Inforce全球安全大会中&#xff0c;小李哥作为亚马逊大中华区开发者社区和自媒体代表&#xff0c;跟着亚马逊云安全产品团队采访了亚马逊云首席信息安全官(CISO)CJ Moses、亚马逊副总裁Eric Brandwine和亚马逊云首席高级安全工程师Becky Weiss。 …

iOS--runloop的初步认识

runloop的初步认识 简单认识runloopEvent looprunloop其实就是个对象NSRunloop和CFRunLoopRef的依赖关系runloop与线程runloop moderunloop sourceCFRunLoopSourceCFRunLoopObserverCFRunLoopTimer runloop的实现runloop的获取添加ModeCFRunLoopAddCommonMode 添加Run Loop Sou…

找不到msvcp140.dll无法执行代码的原因分析及修复方法

当用户在尝试运行某些应用程序或游戏时&#xff0c;可能会遇到系统弹出错误提示&#xff0c;显示“找不到msvcp140.dll无法执行代码”这一错误信息&#xff0c;它会导致程序无法正常启动。为了解决这个问题&#xff0c;我经过多次尝试和总结&#xff0c;找到了以下五种解决方法…

【Linux】为什么有僵尸状态,什么是僵尸进程,造成危害以及如何避免“内存泄漏”问题详解

&#x1f490; &#x1f338; &#x1f337; &#x1f340; &#x1f339; &#x1f33b; &#x1f33a; &#x1f341; &#x1f343; &#x1f342; &#x1f33f; &#x1f344;&#x1f35d; &#x1f35b; &#x1f364; &#x1f4c3;个人主页 &#xff1a;阿然成长日记 …

C语言/数据结构——(相交链表)

一.前言 今天在力扣上刷到了一道题&#xff0c;想着和大家一起分享一下这道题——相交链表https://leetcode.cn/problems/intersection-of-two-linked-lists废话不多说&#xff0c;让我们开始今天的分享吧。 二.正文 1.1题目描述 是不是感觉好长&#xff0c;我也这么觉得。哈…

Ubuntu/Linux 安装Docker + PyTorch

文章目录 1. 提前准备2. 安装Docker2.1. 卸载冲突软件&#xff08;非必要&#xff09;2.2. 在Ubuntu系统上添加Docker的官方GPG密钥2.3. 将Docker的仓库添加到Ubuntu系统的APT源列表中2.4. 安装最新Docker2.5. 检查 3. 安装Nvidia Container Toolkit3.1. 在Ubuntu系统上添加官方…

WebRtc 视频通话,语音通话实现方案

先了解一下流程 和 流程图(chatGpt的回答) 实现 (底层代码实现, 可作为demo熟悉) 小demo <template><div><video ref"localVideo" autoplay muted></video> <!-- 本地视频元素&#xff0c;用于显示本地视频 --><video ref"r…

vivado 配置存储器支持-Artix-7 配置存储器器件

配置存储器支持 本章主要讲解 Vivado 软件支持的各种非易失性器件存储器。请使用本章作为指南 &#xff0c; 按赛灵思系列、接口、制造商、 密度和数据宽度来为您的应用选择适用的配置存储器器件。 Artix-7 配置存储器器件 下表所示闪存器件支持通过 Vivado 软件对 A…

布局全球内容生态,酷开科技Coolita AIOS以硬核品质亮相

当前&#xff0c;全球产业链供应链格局持续重构&#xff0c;成为影响中国对外经济发展的重要因素。2024年4月15至5月5日&#xff0c;历史久、规模大、层次高&#xff0c;作为中国外贸风向标的第135届中国进出口商品交易会&#xff08;即广交会&#xff09;在美丽的广州隆重举行…

mysql基础概念

文章目录 登录mysqlmysql和mysqld数据库操作主流数据库MYSQL架构SQL分类 登录mysql 登录mysql连接服务器&#xff0c;mysql连接时可以指明主机用-h选项&#xff0c;然后就可以指定主机Ip地址&#xff0c;-P可以指定端口号 -u指定登录用户 -P指定登录密码 查看系统中有无mysql&…

linux上Redis安装使用

环境centOS8 redis是缓存数据库&#xff0c;主要是用于在内存中存储数据&#xff0c;内存的读写很快&#xff0c;加快系统读写数据库的速度 一、Linux 安装 Redis 1. 下载Redis 官网下载Downloads - Redis 历史版本Index of /releases/ 本文中安装的版本为&#xff1a;h…

Oracle体系结构初探:闪回技术

在Oracle体系结构初探这个专栏中&#xff0c;已经写过了REDO、UNDO等内容。觉得可以开始写下有关备份恢复的内容。闪回技术 — Oracle数据库备份恢复机制的一种。它可以在一定条件下&#xff0c;高效快速的恢复因为逻辑错误&#xff08;误删误更新等&#xff09;导致的数据丢失…

动手学深度学习——多层感知机

1. 感知机 感知机本质上是一个二分类问题。给定输入x、权重w、偏置b&#xff0c;感知机输出&#xff1a; 以猫和狗的分类问题为例&#xff0c;它本质上就是找到下面这条黑色的分割线&#xff0c;使得所有的猫和狗都能被正确的分类。 与线性回归和softmax的不同点&#xff1…

服务丢在tomcat中启动war包,需要在tomcat中配置Java环境吗?

一般来说&#xff0c;部署在 Tomcat 上的 WAR 包启动时不需要在 Tomcat 中单独配置 Java 环境&#xff0c;因为 Tomcat 启动本身就需要依赖 Java 环境。以下是确保 Tomcat 正常运行与部署 WAR 包的基本步骤&#xff1a; 安装 Java 环境&#xff1a; 首先&#xff0c;确保你的系…