生成式AI系列 —— DCGAN生成手写数字

news2024/9/23 14:38:47

1、模型构建

1.1 构建生成器

# 导入软件包
import torch
import torch.nn as nn

class Generator(nn.Module):

    def __init__(self, z_dim=20, image_size=256):
        super(Generator, self).__init__()

        self.layer1 = nn.Sequential(
            nn.ConvTranspose2d(z_dim, image_size * 32,
                               kernel_size=4, stride=1),
            nn.BatchNorm2d(image_size * 32),
            nn.ReLU(inplace=True))

        self.layer2 = nn.Sequential(
            nn.ConvTranspose2d(image_size * 32, image_size * 16,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(image_size * 16),
            nn.ReLU(inplace=True))

        self.layer3 = nn.Sequential(
            nn.ConvTranspose2d(image_size * 16, image_size * 8,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(image_size * 8),
            nn.ReLU(inplace=True))

        self.layer4 = nn.Sequential(
            nn.ConvTranspose2d(image_size * 8, image_size *4,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(image_size * 4),
            nn.ReLU(inplace=True))
        self.layer5 = nn.Sequential(
            nn.ConvTranspose2d(image_size * 4, image_size * 2,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(image_size * 2),
            nn.ReLU(inplace=True))
        self.layer6 = nn.Sequential(
            nn.ConvTranspose2d(image_size * 2, image_size,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(image_size),
            nn.ReLU(inplace=True))
        self.last = nn.Sequential(
            nn.ConvTranspose2d(image_size, 3, kernel_size=4,
                               stride=2, padding=1),
            nn.Tanh())
        # 注意:因为是黑白图像,所以只有一个输出通道

    def forward(self, z):
        out = self.layer1(z)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        out = self.layer6(out)
        out = self.last(out)

        return out
    
if __name__ == "__main__":
    import matplotlib.pyplot as plt

    G = Generator(z_dim=20, image_size=256)

    # 输入的随机数
    input_z = torch.randn(1, 20)

    # 将张量尺寸变形为(1,20,1,1)
    input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)

    #输出假图像
    fake_images = G(input_z)
    print(fake_images.shape)
    img_transformed = fake_images[0].detach().numpy().transpose(1, 2, 0)
    plt.imshow(img_transformed)
    plt.show()

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-M0oWDbXr-1692468683782)(E:\学习笔记\深度学习笔记\生成模型\GAN\DCGAN.assets\Figure_1.png)]

1.1 构建判别器

class Discriminator(nn.Module):

    def __init__(self, z_dim=20, image_size=256):
        super(Discriminator, self).__init__()

        self.layer1 = nn.Sequential(
            nn.Conv2d(3, image_size, kernel_size=4,
                      stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True))
       #注意:因为是黑白图像,所以输入通道只有一个

        self.layer2 = nn.Sequential(
            nn.Conv2d(image_size, image_size*2, kernel_size=4,
                      stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True))

        self.layer3 = nn.Sequential(
            nn.Conv2d(image_size*2, image_size*4, kernel_size=4,
                      stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True))

        self.layer4 = nn.Sequential(
            nn.Conv2d(image_size*4, image_size*8, kernel_size=4,
                      stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True))
        
        self.layer5 = nn.Sequential(
            nn.Conv2d(image_size*8, image_size*16, kernel_size=4,
                      stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True))
        
        self.layer6 = nn.Sequential(
            nn.Conv2d(image_size*16, image_size*32, kernel_size=4,
                      stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True))
        
        self.last = nn.Conv2d(image_size*32, 1, kernel_size=4, stride=1)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        out = self.layer6(out)
        out = self.last(out)

        return out

    
if __name__ == "__main__":
    #确认程序执行
    D = Discriminator(z_dim=20, image_size=64)

    #生成伪造图像
    input_z = torch.randn(1, 20)
    input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)
    fake_images = G(input_z)

    #将伪造的图像输入判别器D中
    d_out = D(fake_images)

    #将输出值d_out乘以Sigmoid函数,将其转换成0~1的值
    print(torch.sigmoid(d_out))

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-SEGvF8xh-1692468683784)(E:\学习笔记\深度学习笔记\生成模型\GAN\DCGAN.assets\image-20230817224333376.png)]

2、数据集构建

import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), ".")))
import time
from PIL import Image
import torch
import torch.utils.data as data
import torch.nn as nn

from torchvision import transforms
from model.DCGAN import Generator, Discriminator
from matplotlib import pyplot as plt


def make_datapath_list(root):
    """创建用于学习和验证的图像数据及标注数据的文件路径列表。 """

    train_img_list = list() #保存图像文件的路径

    for img_idx in range(200):
        img_path = f"{root}/img_7_{str(img_idx)}.jpg"
        train_img_list.append(img_path)

        img_path = f"{root}/img_8_{str(img_idx)}.jpg"
        train_img_list.append(img_path)

    return train_img_list


class ImageTransform:
    """图像的预处理类"""

    def __init__(self, mean, std):
        self.data_transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize(mean, std)]
        )

    def __call__(self, img):
        return self.data_transform(img)


class GAN_Img_Dataset(data.Dataset):
    """图像的 Dataset 类,继承自 PyTorchd 的 Dataset 类"""

    def __init__(self, file_list, transform):
        self.file_list = file_list
        self.transform = transform

    def __len__(self):
        '''返回图像的张数'''
        return len(self.file_list)

    def __getitem__(self, index):
        '''获取经过预处理后的图像的张量格式的数据'''

        img_path = self.file_list[index]
        img = Image.open(img_path)  # [ 高度 ][ 宽度 ] 黑白

        # 图像的预处理
        img_transformed = self.transform(img)

        return img_transformed


# 创建DataLoader并确认执行结果

# 创建文件列表
root = "./img_78"
train_img_list = make_datapath_list(root)

# 创建Dataset
mean = (0.5)
std = (0.5)
train_dataset = GAN_Img_Dataset(
    file_list=train_img_list, transform=ImageTransform(mean, std)
)

# 创建DataLoader
batch_size = 2
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)

# 确认执行结果
batch_iterator = iter(train_dataloader)  # 转换为迭代器
imges = next(batch_iterator)  # 取出位于第一位的元素
print(imges.size())  # torch.Size([64, 1, 64, 64])

数据请在访问链接获取:

3、train接口实现


def train_model(G, D, dataloader, num_epochs):
    # 确认是否能够使用GPU加速
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("使用设备:", device)

    # 设置最优化算法
    g_lr, d_lr = 0.0001, 0.0004
    beta1, beta2 = 0.0, 0.9
    g_optimizer = torch.optim.Adam(G.parameters(), g_lr, [beta1, beta2])
    d_optimizer = torch.optim.Adam(D.parameters(), d_lr, [beta1, beta2])

    # 定义误差函数
    criterion = nn.BCEWithLogitsLoss(reduction='mean')

    # 使用硬编码的参数
    z_dim = 20
    mini_batch_size = 8

    # 将网络载入GPU中
    G.to(device)
    D.to(device)

    G.train()  # 将模式设置为训练模式
    D.train()  # 将模式设置为训练模式

    # 如果网络相对固定,则开启加速
    torch.backends.cudnn.benchmark = True

    # 图像张数
    num_train_imgs = len(dataloader.dataset)
    batch_size = dataloader.batch_size

    # 设置迭代计数器
    iteration = 1
    logs = []

    # epoch循环
    for epoch in range(num_epochs):
        # 保存开始时间
        t_epoch_start = time.time()
        epoch_g_loss = 0.0  # epoch的损失总和
        epoch_d_loss = 0.0  # epoch的损失总和

        print('-------------')
        print('Epoch {}/{}'.format(epoch, num_epochs))
        print('-------------')
        print('(train)')

        # 以minibatch为单位从数据加载器中读取数据的循环
        for imges in dataloader:
            # --------------------
            # 1.判别器D的学习
            # --------------------
            # 如果小批次的尺寸设置为1,会导致批次归一化处理产生错误,因此需要避免
            if imges.size()[0] == 1:
                continue

            # 如果能使用GPU,则将数据送入GPU中
            imges = imges.to(device)

            # 创建正确答案标签和伪造数据标签
            # 在epoch最后的迭代中,小批次的数量会减少
            mini_batch_size = imges.size()[0]
            label_real = torch.full((mini_batch_size,), 1).to(device)
            label_fake = torch.full((mini_batch_size,), 0).to(device)

            # 对真正的图像进行判定
            d_out_real = D(imges)

            # 生成伪造图像并进行判定
            input_z = torch.randn(mini_batch_size, z_dim).to(device)
            input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)
            fake_images = G(input_z)
            d_out_fake = D(fake_images)

            # 计算误差
            d_loss_real = criterion(d_out_real.view(-1), label_real.to(torch.float))
            d_loss_fake = criterion(d_out_fake.view(-1), label_fake.to(torch.float))
            d_loss = d_loss_real + d_loss_fake

            # 反向传播处理
            g_optimizer.zero_grad()
            d_optimizer.zero_grad()

            d_loss.backward()
            d_optimizer.step()

            # --------------------
            # 2.生成器G的学习
            # --------------------
            # 生成伪造图像并进行判定
            input_z = torch.randn(mini_batch_size, z_dim).to(device)
            input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)
            fake_images = G(input_z)
            d_out_fake = D(fake_images)

            # 计算误差
            g_loss = criterion(d_out_fake.view(-1), label_real.to(torch.float))

            # 反向传播处理
            g_optimizer.zero_grad()
            d_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()

            # --------------------
            # 3.记录结果
            # --------------------
            epoch_d_loss += d_loss.item()
            epoch_g_loss += g_loss.item()
            iteration += 1

        # epoch的每个phase的loss和准确率
        t_epoch_finish = time.time()
        print('-------------')
        print(
            'epoch {} || Epoch_D_Loss:{:.4f} ||Epoch_G_Loss:{:.4f}'.format(
                epoch, epoch_d_loss / batch_size, epoch_g_loss / batch_size
            )
        )
        print('timer:  {:.4f} sec.'.format(t_epoch_finish - t_epoch_start))
        t_epoch_start = time.time()

    return G, D

4、训练


G = Generator(z_dim=20, image_size=64)
D = Discriminator(z_dim=20, image_size=64)
# 定义误差函数
criterion = nn.BCEWithLogitsLoss(reduction='mean')
num_epochs = 200
G_update, D_update = train_model(
    G, D, dataloader=train_dataloader, num_epochs=num_epochs
)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-EvqY6h3G-1692468683786)(E:\学习笔记\深度学习笔记\生成模型\GAN\DCGAN.assets\image-20230820020234209.png)]

5、测试

# 将生成的图像和训练数据可视化
# 反复执行本单元中的代码,直到生成感觉良好的图像为止

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 生成用于输入的随机数
batch_size = 8
z_dim = 20
fixed_z = torch.randn(batch_size, z_dim)
fixed_z = fixed_z.view(fixed_z.size(0), fixed_z.size(1), 1, 1)

# 生成图像
fake_images = G_update(fixed_z.to(device))

# 训练数据
imges = next(iter(train_dataloader))  # 取出位于第一位的元素

# 输出结果
fig = plt.figure(figsize=(15, 6))
for i in range(0, 5):
    # 将训练数据放入上层
    plt.subplot(2, 5, i + 1)
    plt.imshow(imges[i][0].cpu().detach().numpy(), 'gray')

    # 将生成数据放入下层
    plt.subplot(2, 5, 5 + i + 1)
    plt.imshow(fake_images[i][0].cpu().detach().numpy(), 'gray')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/901194.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于web的停车场收费管理系统/基于springboot的停车场管理系统

摘 要 随着汽车工业的迅猛发展,我国汽车拥有量急剧增加。停车场作为交通设施的组成部分,随着交通运输的繁忙和不断发展,人们对其管理的要求也不断提高,都希望管理能够达到方便、快捷以及安全的效果。停车场的规模各不相同,对其进行管理的模…

深入理解ASP.NET Core中的Program类和Startup类

一、背景介绍 本文以ASP.NET Core 6以前版本API程序来说明。 在我们新建ASP.NET Core项目时,项目根目录下会自动建立Program.cs和Startup.cs两个类文件。 Program.cs 作为 Web 应用程序的默认入口,不做任何修改的情况下,会调用同目录下 Star…

Dubbo 融合 Nacos 成为注册中心

快速上手 Dubbo 融合 Nacos 成为注册中心的操作步骤非常简单,大致步骤可分为“增加 Maven 依赖”以及“配置注册中心“。 增加 Maven 依赖 只需要依赖Dubbo客户端即可,关于推荐的使用版本,请参考Dubbo官方文档或者咨询Dubbo开发人员&#…

Vue 2 组件基础

一个简单的组件示例&#xff1a; <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Document</titl…

暴力模拟入门+简单:零件组装、塔子的签到题、塔子哥考试、平均像素值、换座位

暴力模拟入门 P1038 小红书-2022.9.23-零件组装 #include <bits/stdc.h> #include <cstdint> using namespace std;typedef long long LL; const int N 100001; int num[4]; LL d; vector<vector<LL>> v(4, vector<LL>(N));int main() {for(in…

python中的__name__是个啥?

pycharm中随便打开一个文件&#xff0c;在special variables中能看到一个__name__的变量 在很多python脚本中&#xff0c;也经常能看到if name "main"这样一行 所以_name_到底是个啥&#xff1f; 首先&#xff0c;我们可以确定这是一个str字符变量 “在 Python 中&…

06_布隆过滤器BloomFilter

06——布隆过滤器BloomFilter 一、是什么 由一个初始值都为零的bit数组和多个哈希函数构成&#xff0c;用来快速判断集合中是否存在某个元素 设计思想&#xff1a; 1. 目的&#xff1a;减少内存占用 1. 方式&#xff1a;不保存数据信息&#xff0c;只是在内存中做一个是否存…

【框架类】—MVVM框架

一、MVVM框架有哪些 Vue.jsReact.jsAngular.js 二、对MVVM的认识 1. MVC是什么 全称 Model View Controller, 它采用模型(Model)-视图(View)-控制器(controller)的方法把业务逻辑、数据与界面显示分离 2. MVVM的定义 MVVM是一种软件架构模式&#xff0c;它代表了模型 --视…

智慧工地监管一体化云平台源码 PC端、 手机端、 现场端

智慧工地管理平台是以物联网、移动互联网技术为基础&#xff0c;充分应用大数据、人工智能、移动通讯、云计算等信息技术&#xff0c;利用前端信息采通过人机交互、感知、决策、执行和反馈等&#xff0c;实现对工程项目內人员、车辆、安全、设备、材料等的智能化管理&#xff0…

Python 潮流周刊#16:优雅重要么?如何写出 Pythonic 的代码?

你好&#xff0c;我是猫哥。这里每周分享优质的 Python、AI 及通用技术内容&#xff0c;大部分为英文。标题取自其中两则分享&#xff0c;不代表全部内容都是该主题&#xff0c;特此声明。 本周刊由 Python猫 出品&#xff0c;精心筛选国内外的 250 信息源&#xff0c;为你挑选…

Linux(入门篇)

Linux&#xff08;入门篇&#xff09; Linux概述Linux是什么Linux的诞生Linux和Unix的渊源GNU/LinuxLinux的发行版Linux VS Windows Linux概述 Linux是什么 Linux是一个操作系统(OS) Linux的诞生 作者&#xff1a;李纳斯托瓦兹&#xff08;git也是他开发的&#x1f602;&am…

11. 实现业务功能--获取用户信息

目录 1. 实现 Controller 2. 单体测试 3. 修复返回值存在的缺陷 3.1 用户的隐私数据&#xff1a;密码的密文和盐不能显示 3.2 将值为 null 的字段可以进行过滤 3.3 时间的格式需要进行处理&#xff0c;如 yyyy-mmmm-ddd HH:mm:ss 3.4 data 属性没有返回 4. 实现前端页…

低代码平台全套源码,支持二次开发

低代码开发平台&#xff1a;只需要编写简单的配置文件即可构建企业级应用程序。 一、低代码PaaS平台可以在云端开发、部署、运行低代码应用程序。使用独立数据库模型&#xff0c;基于Kubernetes云原生技术&#xff0c;每个租户均可拥有一套独立的存储、数据库、代码和命名空间&…

光栅化之扫描填充三角形

重心坐标计算 重心坐标比较简单&#xff0c;取最大包围合再计算点是否在三角形内就行&#xff0c;再根据重心坐标返回的alpha,beta,gamma三个权重值计算 uv映射和depth深度缓冲值&#xff0c;因为是求的重心坐标&#xff0c;感觉效果比插值的要好一点。 求重心坐标 barycentr…

Qt 编译使用Bit7z库接口调用7z.dll、7-Zip.dll解压压缩常用Zip、ISO9660、Wim、Esd、7z等格式文件(一)

bit7z一个c静态库&#xff0c;为7-zip共享库提供了一个干净简单的接口 使用CMAKE重新编译github上的bit7z库&#xff0c;用来解压/预览iso9660&#xff0c;WIm&#xff0c;Zip,Rar等常用的压缩文件格式。z-zip库支持大多数压缩文件格式 导读 编译bit7z(C版本)使用mscv 2017编译…

C# 把dll打包到exe文件,真的可以 。文件批量转了ANSI编码

在 C# 中&#xff0c;将 DLL 文件打包到 EXE 文件中可以使用 ILRepack 工具。ILRepack 是一个开源的工具&#xff0c;可以合并多个 DLL 文件并将它们嵌入到一个 EXE 文件中&#xff0c;从而实现将 DLL 打包到 EXE 的功能。 以下是使用 ILRepack 工具打包 DLL 到 EXE 的步骤&…

CSDN今日热榜词云图

文章目录 C云原生人工智能和Python前沿技术软件工程后端javajavascriptphp区块链大数据移动开发嵌入式开发工具数据结构与算法微软技术测试游戏网络运维 C C果然还是应试语言&#xff0c;真题的占比竟然这么大。C之所以没出现&#xff0c;很有可能是在做词云的时候把加号当作非…

多人联机对战游戏赛道,你准备好了吗?

用户日益增长的精神需求和社交娱乐需要&#xff0c;让联机对战的需求与日剧增。 硬件和网络技术的高速发展&#xff0c;也使得联机游戏的体验越来越好。 可以看到&#xff0c;越来越多的联机对战游戏登上游戏榜单。 联机对战已逐渐成为主流&#xff0c;无论在哪个游戏榜单&…

二,MySQL数据库主从复制的介绍及搭建(收藏)

一&#xff0c;介绍概述 主从复制是指将主数据库的 DDL 和 DML 操作通过二进制日志传到从库服务器中&#xff0c;然后在从库上对这些日志重新执行&#xff08;也叫重做&#xff09;&#xff0c;从而使得从库和主库的数据保持同步。 DDL&#xff1a;数据定义语言&#xff0c;用…