在kaggle中用GPU使用CGAN生成指定mnist手写数字

news2024/11/25 21:51:10

文章目录

  • 1项目介绍
  • 2参考文章
  • 3代码的实现过程及对代码的详细解析
    • 独热编码
    • 定义生成器
    • 定义判别器
    • 打印我们的引导信息
    • 模型训练
    • 迭代过程中生成的图片
    • 损失函数的变化
  • 4总结
  • 5 模型相关的文件

1项目介绍

在GAN的基础上进行有条件的引导生成图片cgan

2参考文章

GAN实战之Pytorch 使用CGAN生成指定MNIST手写数字
GANs系列:CGAN(条件GAN)原理简介以及项目代码实现

3代码的实现过程及对代码的详细解析


import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
from torch.utils import data
import os
import glob
from PIL import Image

独热编码


# 输入x代表默认的torchvision返回的类比值,class_count类别值为10
def one_hot(x, class_count=10):
    return torch.eye(class_count)[x, :]  # 切片选取,第一维选取第x个,第二维全要

torch.eye(10)函数的作用是生成一个10*10的对角矩阵
该函数的作用是得到第x个位置为1的独热编码,如果传入为列表,则得到一个矩阵
在这里插入图片描述

 
transform =transforms.Compose([transforms.ToTensor(),
                               transforms.Normalize(0.5, 0.5)])
 #minist数据集中的图片数据的维度是[batch_size, 1, 28, 28],其中batch_size是每个批次的图像数量。这个数据集中的每个图像都是28x28像素的灰度图像,因此它们只有一个通道
dataset = torchvision.datasets.MNIST('data',
                                     train=True,
                                     transform=transform,
                                     target_transform=one_hot,
                                     download=True)
#这里target_transform参数的作用是对标签进行转换。在这个例子中,它的作用是将标签转换为one-hot编码。
dataloader = data.DataLoader(dataset, batch_size=64, shuffle=True)
 

定义生成器


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        #因此,这个函数的输入张量维度为[batch_size, 10]和[batch_size, 100],输出张量维度为[batch_size, 1, 1, 1]。
        self.linear1 = nn.Linear(10, 128 * 7 * 7)
        self.bn1 = nn.BatchNorm1d(128 * 7 * 7)
        self.linear2 = nn.Linear(100, 128 * 7 * 7)
        self.bn2 = nn.BatchNorm1d(128 * 7 * 7)
        #这个函数的作用是将一个输入张量进行反卷积操作,得到一个输出张量。
        #nn.ConvTranspose2d函数的作用是将一个256通道的输入张量转换为一个128通道的输出张量,使用3x3的卷积核进行卷积操作,并在卷积操作后进行1像素的padding
        self.deconv1 = nn.ConvTranspose2d(256, 128,
                                          kernel_size=(3, 3),
                                          padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.deconv2 = nn.ConvTranspose2d(128, 64,
                                          kernel_size=(4, 4),
                                          stride=2,
                                          padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.deconv3 = nn.ConvTranspose2d(64, 1,
                                          kernel_size=(4, 4),
                                          stride=2,
                                          padding=1)
 
    def forward(self, x1, x2):
        x1 = F.relu(self.linear1(x1))
        x1 = self.bn1(x1)
        x1 = x1.view(-1, 128, 7, 7)
        x2 = F.relu(self.linear2(x2))
        x2 = self.bn2(x2)
        x2 = x2.view(-1, 128, 7, 7)
        #将两个处理后的结果拼接在一起,得到形状为[64, 256, 7, 7]的张量
        x = torch.cat([x1, x2], axis=1)
        x = F.relu(self.deconv1(x))
        #形状变为为[64, 128, 7, 7]的张量
        x = self.bn3(x)
        x = F.relu(self.deconv2(x))
        #形状变为为[64, 64, 14, 14]的张量
        x = self.bn4(x)
         # 形状变为为[64, 1, 28, 28]的张量
        x = torch.tanh(self.deconv3(x))
      
        return x

生成器对数据的处理过程:
这个函数对于输入张量[64, 1, 28, 28]的维度变化过程如下:
输入张量维度为[64, 1, 28, 28]
经过线性变换和ReLU激活函数处理后,得到两个形状为[64, 128 * 7 * 7]的张量
将两个张量分别通过BatchNorm1d进行归一化处理
将两个处理后的结果reshape成形状为[64, 128, 7, 7]的张量
将两个处理后的结果拼接在一起,得到形状为[64, 256, 7, 7]的张量
经过反卷积操作得到输出张量,维度为[64, 1, 28, 28]

定义判别器


# input:1,28,28的图片以及长度为10的condition
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.linear = nn.Linear(10, 1*28*28)
        self.conv1 = nn.Conv2d(2, 64, kernel_size=3, stride=2)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2)
        self.bn = nn.BatchNorm2d(128)
        self.fc = nn.Linear(128*6*6, 1) # 输出一个概率值
 
    def forward(self, x1, x2):
    #leak_relu激活函数:它在输入小于0时返回一个小的斜率,而在输入大于等于0时返回输入本身
        x1 =F.leaky_relu(self.linear(x1))
        x1 = x1.view(-1, 1, 28, 28)
        #torch.cat([x1, x2], axis=1)函数将张量x1和张量x2沿着第二个维度(即列)拼接起来
        x = torch.cat([x1, x2], axis=1)
        #处理过后变为(64,2,28,28)
        x = F.dropout2d(F.leaky_relu(self.conv1(x)))
        #维度变为(64,64,13,13)
        x = F.dropout2d(F.leaky_relu(self.conv2(x)))
        #维度变为(64,128,6,6)
        x = self.bn(x)
        x = x.view(-1, 128*6*6)
        #最后键位了64*1(同时把值映射到0~1之间)
        x = torch.sigmoid(self.fc(x))
        return x


# 初始化模型
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device)
dis = Discriminator().to(device)
 
# 损失计算函数
loss_function = torch.nn.BCELoss()
 
# 定义优化器
d_optim = torch.optim.Adam(dis.parameters(), lr=1e-5)
g_optim = torch.optim.Adam(gen.parameters(), lr=1e-4)
 
# 定义可视化函数
def generate_and_save_images(model, epoch, label_input, noise_input):
	#生成器生成取片,label_input为输入的引导信息,noise_input为随机的噪声点
    predictions = np.squeeze(model(label_input, noise_input).cpu().numpy())
    #numpy.squeeze()函数的作用是去掉矩阵里维度为1的维度。
    fig = plt.figure(figsize=(4, 4))
    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i + 1)
        plt.imshow((predictions[i] + 1) / 2, cmap='gray')
        plt.axis("off")
    from IPython.display import FileLink
    plt.savefig('data/img/image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()


import os 
os.makedirs("data/img")

打印我们的引导信息

noise_seed = torch.randn(16, 100, device=device)
 
label_seed = torch.randint(0, 10, size=(16,))
label_seed_onehot = one_hot(label_seed).to(device)

print(label_seed)
tensor([1, 3, 5, 4, 9, 3, 0, 0, 1, 3, 4, 5, 9, 2, 3, 7])

模型训练

D_loss = []
G_loss = []


# 训练循环
for epoch in range(150):
    d_epoch_loss = 0
    g_epoch_loss = 0
    count = len(dataloader.dataset)
    # 对全部的数据集做一次迭代
    #dataloader中的图像是四维的。在for循环中,每次迭代会返回一个batch_size大小的数据
    #其中每个数据都是一个四维张量,形状为[batch_size, channels, height, width]
    for step, (img, label) in enumerate(dataloader):
        img = img.to(device)
        label = label.to(device)
        size = img.shape[0]
        random_noise = torch.randn(size, 100, device=device)
 
        d_optim.zero_grad()
 
        real_output = dis(label, img)
        d_real_loss = loss_function(real_output,
                                    torch.ones_like(real_output, device=device)
                                    )
        #torch.ones_like(real_output, device=device)函数的作用是生成一个与real_output形状相同的张量,其中所有元素都为1。                         
        d_real_loss.backward() #求解梯度
 
        # 得到判别器在生成图像上的损失
        gen_img = gen(label,random_noise)
        fake_output = dis(label, gen_img.detach())  # 判别器输入生成的图片,f_o是对生成图片的预测结果
        d_fake_loss = loss_function(fake_output,
                                    torch.zeros_like(fake_output, device=device))
        d_fake_loss.backward()
 
        d_loss = d_real_loss + d_fake_loss
        d_optim.step()  # 优化
 
        # 得到生成器的损失
        g_optim.zero_grad()
        fake_output = dis(label, gen_img)
        g_loss = loss_function(fake_output,
                               torch.ones_like(fake_output, device=device))
        g_loss.backward()
        g_optim.step()
 
        with torch.no_grad():
            d_epoch_loss += d_loss.item()
            g_epoch_loss += g_loss.item()
    with torch.no_grad():
        d_epoch_loss /= count
        g_epoch_loss /= count
        D_loss.append(d_epoch_loss)
        G_loss.append(g_epoch_loss)
        if epoch % 10 == 0:
            print('Epoch:', epoch)
            generate_and_save_images(gen, epoch, label_seed_onehot, noise_seed)
    print("epoch:{}/150".format(epoch))
 
plt.plot(D_loss, label='D_loss')
plt.plot(G_loss, label='G_loss')
plt.legend()
plt.show()

迭代过程中生成的图片

迭代1次
在这里插入图片描述
迭代10次
在这里插入图片描述
迭代20次
在这里插入图片描述

迭代30次
在这里插入图片描述

迭代40次
在这里插入图片描述

迭代150次
在这里插入图片描述

损失函数的变化

在这里插入图片描述

4总结

cGAN相比于GAN而言,将label的信息通过一系列的卷积操作和图像的信息融合在一起,然后放进模型进行训练,让我们的模型能和label相匹配的图像,从而在我们给出制定的数字label时能够生成对应的数字图片,实现了引导的过程。

5 模型相关的文件

模型的相关文件:提取码(ujki)

本模型是放在kaggle中运行的,kaggle的部署流程请参考:在kaggle中用GPU训练模型

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/941212.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

android framework之Applicataion启动流程分析

Application启动流程分析 启动方式一:通过Launcher启动app 启动方式二:在某一个app里启动第二个app的Activity. 以上两种方式均可触发app进程的启动。但无论哪种方式,最终通过通过调用AMS的startActivity()来启动application的。 根据上图…

家政服务行业搭建小程序的实用技巧分享

随着移动互联网的发展,小程序成为了各行各业的新宠。对于家政服务行业来说,搭建一个小程序商城可以极大地提升服务的便捷性和用户体验,同时也能提高企业的竞争力。本文将分享家政服务行业搭建小程序的实用技巧,帮助您顺利创建属于…

利用深度蛋白质序列嵌入方法通过 Siamese neural network 对 virus-host PPIs 进行精准预测【Patterns,2022】

研究背景: 病毒感染可以导致多种组织特异性损伤,所以 virus-host PPIs 的预测有助于新的治疗方法的研究;目前已有的一些 virus-host PPIs 鉴定或预测方法效果有限(传统实验方法费时费力、计算方法要么基于蛋白结构或基因&#xff…

SAP-FI-会计凭字段替代OBBH

会计凭证替代OBBH 业务:文本必须等于某个字段的值,例如凭证日期 关闭确认功能,输入OBBH 双击“替代”进入功能配置,或者用GGB1,用GGB1的功能更多。 点击行项目,点击“新建替换”保存 点击新建YXL7331,点击…

删除命名空间一直处于Terminating

删除命名空间一直处于Terminating 通常删除命名空间或者其他资源一直处于Terminating状态,是由于资源调度到的节点处于NotReady状态,需要将节点重新加入到集群使其状态变为Ready状态才能解决问题,当node重新加入处于Ready状态后,…

系统报错msvcr120.dll丢失一键修复教程,快速修复dll报错问题

今天,我将和大家探讨一个常见的问题:系统报错msvcr120.dll丢失。这个问题相信很多网友都遇到过,尤其是在使用一些较老的软件或者游戏时,很容易出现这个错误。那么,如何解决这个问题呢?下面,我将…

Matlab(结构化程式和自定义函数)

目录 1.脚本编辑器 2.脚本流 2.1 控制流 2.2 关系(逻辑)操作符 3.脚本与函数 1.脚本编辑器 Matlab的命名规则: 常用功能: 智能缩进: 在写代码的时候,有的时候代码看起来并不是那么美观(可读性…

在线查询让家长迅速获得录取通知书

发布录取通知书是一项看似简单却非常耗时费力的工作。负责录取工作的老师通常会采取以下常见的发放方式: 1. 面试告知:某些学校会在面试结束后立即告知学生是否被录取。这种方式通常适用于面试人数较少的学校或特定专业。 2. 电子邮件:学校通…

pytorch中torch.gather()简单理解

1.作用 从输入张量中按照指定维度进行索引采集操作,返回值是一个新的张量,形状与 index 张量相同,根据指定的索引从输入张量中采集对应的元素。 2.问题 该函数的主要问题主要在dim维度上,dim0 表示沿着第一个维度(行…

P21~22 第六章 储能元件——电容存储电场能,电感存储磁场能

1、电容元件 a定义 b线性时不变电容元件 c电容的电压与电流关系 i有限则u有限 注意理解面积 d电容的功率和储能 e例一 跃变就是指物体的物理量从有限值变为无限值的过程。 分析上图例题:对于电源波形要吃负无穷到正无穷去刻画。即时间轴要铺满。 有有图控制电…

Mysql001:Mysql概述以及安装

前言:本课程将从头学习Mysql,以我的工作经验来说,sql语句真的太重要的,现在互联网所有的一切都是建立在数据上,因为互联网的兴起,现在的数据日月增多,每年都以翻倍的形式增长,对于数…

服务器数据库中了locked勒索病毒怎么办,locked勒索病毒恢复工具

最近一段时间网络上的locked勒索病毒非常嚣张,自从6月份以来,很多企业的计算机服务器数据库遭到了locked勒索病毒的攻击,起初locked勒索病毒攻击用友畅捷通T用户,后来七月份开始攻击金蝶云星空客户,导致企业的财务系统…

【数学建模】清风数模正课4 拟合算法

拟合算法 在插值算法中,我们得到的曲线一定是要经过所有的函数点的;而用拟合所得到的曲线则不一样,拟合问题中,不需要得到的曲线一定经过给定的点。 拟合的目的是寻求一个函数曲线,使得该曲线在某种准则下与所有的数…

AcWing 898. 数字三角形 (每日一题)

大家好 我是寸铁 希望这篇题解对你有用,麻烦动动手指点个赞或关注,感谢您的关注 注意 像数组下标出现i-1的,在循环的时候从i1开始。 关于0x3f3f3f3f和Integer.MAX_VALUE 0x3f3f3f3f:1061109567 Integer.MAX_VALUE:2147483647 在选用Integ…

云计算在大数据分析中的应用与优势

文章目录 云计算在大数据分析中的应用云计算在大数据分析中的优势云计算在大数据分析中的示例未来发展和拓展结论 🎉欢迎来到AIGC人工智能专栏~云计算在大数据分析中的应用与优势 ☆* o(≧▽≦)o *☆嗨~我是IT陈寒🍹✨博客主页:IT陈寒的博客&…

Bootstrap 源代码目录结构一览

目录 前言 Bootstrap 目录结构 Bootstrap 内容简介 Bootstrap 编译文件 CSS文件 | CSS 文件功能对比与清单 JS文件 | JS 文件功能对比与清单 Bootstrap 源码码目录 | 资源清单 前言 Bootstrap是Twitter推出的一个用于前端开发的开源工具包。它由Twitter的设计师Mark Ot…

续1-续3《你的医书是假的!批评付施威的《DDD诊所——聚合过大综合症》

DDD领域驱动设计批评文集 “软件方法建模师”不再考查基础题 《软件方法》各章合集 我写了一篇文章,批评付施威的《DDD诊所——聚合过大综合症》(以下简称《DDD诊所》),文章是《你的医书是假的!批评付施威的《DDD诊…

【JAVA】什么是异常

⭐ 作者:小胡_不糊涂 🌱 作者主页:小胡_不糊涂的个人主页 📀 收录专栏:浅谈Java 💖 持续更文,关注博主少走弯路,谢谢大家支持 💖 异常 1. 什么是异常1.1 概念1.2 异常的体…

第五章 树与二叉树 三、二叉树的先、中、后序遍历

一、定义 树的遍历是按照一定的顺序访问树的所有节点,常用的遍历方式有三种:先序遍历、中序遍历和后序遍历。 先序遍历:从根节点开始,按照根节点-左子树-右子树的顺序遍历整棵树,即先访问根节点,然后遍历左…