原始GAN-pytorch-生成MNIST数据集(代码)

news2024/9/29 18:07:04

文章目录

  • 原始GAN生成MNIST数据集
    • 1. Data loading and preparing
    • 2. Dataset and Model parameter
    • 3. Result save path
    • 4. Model define
    • 6. Training
    • 7. predict

原始GAN生成MNIST数据集

原理很简单,可以参考原理部分原始GAN-pytorch-生成MNIST数据集(原理)

import os
import time
import torch
from tqdm import tqdm
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
import sys 
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

1. Data loading and preparing

测试使用loadlocal_mnist加载数据

from mlxtend.data import loadlocal_mnist
train_data_path = "../data/MNIST/train-images.idx3-ubyte"
train_label_path = "../data/MNIST/train-labels.idx1-ubyte"
test_data_path = "../data/MNIST/t10k-images.idx3-ubyte"
test_label_path = "../data/MNIST/t10k-labels.idx1-ubyte"

train_data,train_label = loadlocal_mnist(
    images_path = train_data_path,
    labels_path = train_label_path
)
train_data.shape,train_label.shape
((60000, 784), (60000,))
import matplotlib.pyplot as plt

img,ax = plt.subplots(3,3,figsize=(9,9))
plt.subplots_adjust(hspace=0.4,wspace=0.4)
for i in range(3):
    for j in range(3):
        num = np.random.randint(0,train_label.shape[0])
        ax[i][j].imshow(train_data[num].reshape((28,28)),cmap="gray")
        ax[i][j].set_title(train_label[num],fontdict={"fontsize":20})
plt.show()

在这里插入图片描述

2. Dataset and Model parameter

构造pytorch数据集datasets和数据加载器dataloader

input_size = [1, 28, 28]
batch_size = 128
Epoch = 1000
GenEpoch = 1
in_channel = 64
from torch.utils.data import Dataset,DataLoader
import numpy as np 
from mlxtend.data import loadlocal_mnist
import torchvision.transforms as transforms

class MNIST_Dataset(Dataset):
    def __init__(self,train_data_path,train_label_path,transform=None):
        train_data,train_label = loadlocal_mnist(
            images_path = train_data_path,
            labels_path = train_label_path
            )
        self.train_data = train_data
        self.train_label = train_label.reshape(-1)
        self.transform=transform
        
    def __len__(self):
        return self.train_label.shape[0] 
    
    def __getitem__(self,index):
        if torch.is_tensor(index):
            index = index.tolist()
        images = self.train_data[index,:].reshape((28,28))
        labels = self.train_label[index]
        if self.transform:
            images = self.transform(images)
        return images,labels

transform_dataset =transforms.Compose([
    transforms.ToTensor()]
)
MNIST_dataset = MNIST_Dataset(train_data_path=train_data_path,
                               train_label_path=train_label_path,
                               transform=transform_dataset)  
MNIST_dataloader = DataLoader(dataset=MNIST_dataset,
                              batch_size=batch_size,
                              shuffle=True,drop_last=False)
img,ax = plt.subplots(3,3,figsize=(9,9))
plt.subplots_adjust(hspace=0.4,wspace=0.4)
for i in range(3):
    for j in range(3):
        num = np.random.randint(0,train_label.shape[0])
        ax[i][j].imshow(MNIST_dataset[num][0].reshape((28,28)),cmap="gray")
        ax[i][j].set_title(MNIST_dataset[num][1],fontdict={"fontsize":20})
plt.show()

在这里插入图片描述

3. Result save path

time_now = time.strftime('%Y-%m-%d-%H_%M_%S', time.localtime(time.time()))
log_path = f'./log/{time_now}'
os.makedirs(log_path)
os.makedirs(f'{log_path}/image')
os.makedirs(f'{log_path}/image/image_all')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'using device: {device}')
using device: cuda

4. Model define

import torch
from torch import nn 

class Discriminator(nn.Module):
    def __init__(self,input_size,inplace=True):
        super(Discriminator,self).__init__()
        c,h,w = input_size
        self.dis = nn.Sequential(
            nn.Linear(c*h*w,512),  # 输入特征数为784,输出为512
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),  # 进行非线性映射
            
            nn.Linear(512, 256),  # 进行一个线性映射
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),
            
            nn.Linear(256, 1),
            nn.Sigmoid()  # 也是一个激活函数,二分类问题中,
            # sigmoid可以班实数映射到【0,1】,作为概率值,
            # 多分类用softmax函数
        )
        
        
    def forward(self,x):
        b,c,h,w = x.size()
        x = x.view(b,-1)
        x = self.dis(x)
        x = x.view(-1)
        return x 

class Generator(nn.Module):
    def __init__(self,in_channel):
        super(Generator,self).__init__() # 调用父类的构造方法
        self.gen = nn.Sequential(
            nn.Linear(in_channel, 128),
            nn.LeakyReLU(0.2),
            
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),
            
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),
            
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            
            nn.Linear(1024, 784),
            nn.Tanh()
        )

    def forward(self,x):
        res = self.gen(x)
        return res.view(x.size()[0],1,28,28)

D = Discriminator(input_size=input_size)
G = Generator(in_channel=in_channel)
D.to(device)
G.to(device)
D,G
(Discriminator(
   (dis): Sequential(
     (0): Linear(in_features=784, out_features=512, bias=True)
     (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (2): LeakyReLU(negative_slope=0.2)
     (3): Linear(in_features=512, out_features=256, bias=True)
     (4): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (5): LeakyReLU(negative_slope=0.2)
     (6): Linear(in_features=256, out_features=1, bias=True)
     (7): Sigmoid()
   )
 ),
 Generator(
   (gen): Sequential(
     (0): Linear(in_features=64, out_features=128, bias=True)
     (1): LeakyReLU(negative_slope=0.2)
     (2): Linear(in_features=128, out_features=256, bias=True)
     (3): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (4): LeakyReLU(negative_slope=0.2)
     (5): Linear(in_features=256, out_features=512, bias=True)
     (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (7): LeakyReLU(negative_slope=0.2)
     (8): Linear(in_features=512, out_features=1024, bias=True)
     (9): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (10): LeakyReLU(negative_slope=0.2)
     (11): Linear(in_features=1024, out_features=784, bias=True)
     (12): Tanh()
   )
 ))

6. Training

criterion = nn.BCELoss()
D_optimizer = torch.optim.Adam(D.parameters(),lr=0.0003)
G_optimizer = torch.optim.Adam(G.parameters(),lr=0.0003)
D.train()
G.train()
gen_loss_list = []
dis_loss_list = []

for epoch in range(Epoch):
    with tqdm(total=MNIST_dataloader.__len__(),desc=f'Epoch {epoch+1}/{Epoch}')as pbar:
        gen_loss_avg = []
        dis_loss_avg = []
        index = 0
        for batch_idx,(img,_) in enumerate(MNIST_dataloader):
            img = img.to(device)
            # the output label
            valid = torch.ones(img.size()[0]).to(device)
            fake = torch.zeros(img.size()[0]).to(device)
            # Generator input
            G_img = torch.randn([img.size()[0],in_channel],requires_grad=True).to(device)
            # ------------------Update Discriminator------------------
            # forward
            G_pred_gen = G(G_img)
            G_pred_dis = D(G_pred_gen.detach())
            R_pred_dis = D(img)
            # the misfit
            G_loss = criterion(G_pred_dis,fake)
            R_loss = criterion(R_pred_dis,valid)
            dis_loss = (G_loss+R_loss)/2
            dis_loss_avg.append(dis_loss.item())
            # backward
            D_optimizer.zero_grad()
            dis_loss.backward()
            D_optimizer.step()
            # ------------------Update Optimizer------------------
            # forward
            G_pred_gen = G(G_img)
            G_pred_dis = D(G_pred_gen)
            # the misfit
            gen_loss = criterion(G_pred_dis,valid)
            gen_loss_avg.append(gen_loss.item())
            # backward
            G_optimizer.zero_grad()
            gen_loss.backward()
            G_optimizer.step()
            # save figure
            if index % 200 == 0 or index + 1 == MNIST_dataset.__len__():
                save_image(G_pred_gen, f'{log_path}/image/image_all/epoch-{epoch}-index-{index}.png')
            index += 1
            # ------------------进度条更新------------------
            pbar.set_postfix(**{
                'gen-loss': sum(gen_loss_avg) / len(gen_loss_avg),
                'dis-loss': sum(dis_loss_avg) / len(dis_loss_avg)
            })
            pbar.update(1)
        save_image(G_pred_gen, f'{log_path}/image/epoch-{epoch}.png')
    filename = 'epoch%d-genLoss%.2f-disLoss%.2f' % (epoch, sum(gen_loss_avg) / len(gen_loss_avg), sum(dis_loss_avg) / len(dis_loss_avg))
    torch.save(G.state_dict(), f'{log_path}/{filename}-gen.pth')
    torch.save(D.state_dict(), f'{log_path}/{filename}-dis.pth')
    # 记录损失
    gen_loss_list.append(sum(gen_loss_avg) / len(gen_loss_avg))
    dis_loss_list.append(sum(dis_loss_avg) / len(dis_loss_avg))
    # 绘制损失图像并保存
    plt.figure(0)
    plt.plot(range(epoch + 1), gen_loss_list, 'r--', label='gen loss')
    plt.plot(range(epoch + 1), dis_loss_list, 'r--', label='dis loss')
    plt.legend()
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.savefig(f'{log_path}/loss.png', dpi=300)
    plt.close(0)
Epoch 1/1000: 100%|██████████| 469/469 [00:11<00:00, 41.56it/s, dis-loss=0.456, gen-loss=1.17] 
Epoch 2/1000: 100%|██████████| 469/469 [00:11<00:00, 42.34it/s, dis-loss=0.17, gen-loss=2.29] 
Epoch 3/1000: 100%|██████████| 469/469 [00:10<00:00, 43.29it/s, dis-loss=0.0804, gen-loss=3.11]
Epoch 4/1000: 100%|██████████| 469/469 [00:11<00:00, 40.74it/s, dis-loss=0.0751, gen-loss=3.55]
Epoch 5/1000: 100%|██████████| 469/469 [00:12<00:00, 39.01it/s, dis-loss=0.105, gen-loss=3.4]  
Epoch 6/1000: 100%|██████████| 469/469 [00:11<00:00, 39.95it/s, dis-loss=0.112, gen-loss=3.38]
Epoch 7/1000: 100%|██████████| 469/469 [00:11<00:00, 40.16it/s, dis-loss=0.116, gen-loss=3.42]
Epoch 8/1000: 100%|██████████| 469/469 [00:11<00:00, 42.51it/s, dis-loss=0.124, gen-loss=3.41]
Epoch 9/1000: 100%|██████████| 469/469 [00:11<00:00, 40.95it/s, dis-loss=0.136, gen-loss=3.41]
Epoch 10/1000: 100%|██████████| 469/469 [00:11<00:00, 39.59it/s, dis-loss=0.165, gen-loss=3.13]
Epoch 11/1000: 100%|██████████| 469/469 [00:11<00:00, 40.28it/s, dis-loss=0.176, gen-loss=3.01]
Epoch 12/1000: 100%|██████████| 469/469 [00:12<00:00, 37.60it/s, dis-loss=0.19, gen-loss=2.94] 
Epoch 13/1000: 100%|██████████| 469/469 [00:11<00:00, 39.17it/s, dis-loss=0.183, gen-loss=2.95]
Epoch 14/1000: 100%|██████████| 469/469 [00:12<00:00, 38.51it/s, dis-loss=0.182, gen-loss=3.01]
Epoch 15/1000: 100%|██████████| 469/469 [00:10<00:00, 44.58it/s, dis-loss=0.186, gen-loss=2.95]
Epoch 16/1000: 100%|██████████| 469/469 [00:10<00:00, 44.08it/s, dis-loss=0.198, gen-loss=2.89]
Epoch 17/1000: 100%|██████████| 469/469 [00:10<00:00, 45.11it/s, dis-loss=0.187, gen-loss=2.99]
Epoch 18/1000: 100%|██████████| 469/469 [00:10<00:00, 44.98it/s, dis-loss=0.183, gen-loss=3.03]
Epoch 19/1000: 100%|██████████| 469/469 [00:10<00:00, 46.68it/s, dis-loss=0.187, gen-loss=2.98]
Epoch 20/1000: 100%|██████████| 469/469 [00:10<00:00, 46.12it/s, dis-loss=0.192, gen-loss=3]   
Epoch 21/1000: 100%|██████████| 469/469 [00:10<00:00, 46.80it/s, dis-loss=0.193, gen-loss=3.01]
Epoch 22/1000: 100%|██████████| 469/469 [00:10<00:00, 45.86it/s, dis-loss=0.186, gen-loss=3.04]
Epoch 23/1000: 100%|██████████| 469/469 [00:10<00:00, 46.00it/s, dis-loss=0.17, gen-loss=3.2]  
Epoch 24/1000: 100%|██████████| 469/469 [00:10<00:00, 46.41it/s, dis-loss=0.173, gen-loss=3.19]
Epoch 25/1000: 100%|██████████| 469/469 [00:10<00:00, 45.15it/s, dis-loss=0.19, gen-loss=3.1]  
Epoch 26/1000: 100%|██████████| 469/469 [00:10<00:00, 44.26it/s, dis-loss=0.178, gen-loss=3.16]
Epoch 27/1000: 100%|██████████| 469/469 [00:10<00:00, 45.14it/s, dis-loss=0.187, gen-loss=3.17]
Epoch 28/1000:   1%|▏         | 6/469 [00:00<00:12, 38.20it/s, dis-loss=0.184, gen-loss=3.04]



---------------------------------------------------------------------------

7. predict

input_size = [3, 32, 32]
in_channel = 64
gen_para_path = './log/2023-02-11-17_52_12/epoch999-genLoss1.21-disLoss0.40-gen.pth'
dis_para_path = './log/2023-02-11-17_52_12/epoch999-genLoss1.21-disLoss0.40-dis.pth'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator_Transpose(in_channel=in_channel).to(device)
dis = DiscriminatorLinear(input_size=input_size).to(device)
gen.load_state_dict(torch.load(gen_para_path, map_location=device))
gen.eval()
# 随机生成一组数据
G_img = torch.randn([1, in_channel, 1, 1], requires_grad=False).to(device)
# 放入网路
G_pred = gen(G_img)
G_dis = dis(G_pred)
print('generator-dis:', G_dis)
# 图像显示
G_pred = G_pred[0, ...]
G_pred = G_pred.detach().cpu().numpy()
G_pred = np.array(G_pred * 255)
G_pred = np.transpose(G_pred, [1, 2, 0])
G_pred = Image.fromarray(np.uint8(G_pred))
G_pred.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/375132.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LightningChart .NET 10.4.1 NEW Crack

实时监控&#xff0c;无闪烁或延迟 完整的数据准确性&#xff0c;无需减少数据点 屏幕上的更多数据 更好的图形质量 响应式用户界面。鼠标或触摸屏操作将立即更新图表&#xff0c;并为其他 UI 控件释放处理器时间以继续操作 Visual Studio Marketplace 中最受欢迎的 .NET 图表控…

全新后门文件Nev-3.exe分析

一、 样本发现&#xff1a; 蜜罐 二、 内容简介&#xff1a; 通过公司的蜜罐告警发现一个Nev-3.exe可执行文件文件&#xff0c;对该样本文件进行分析发现&#xff0c;该可执行程序执行后会从远程服务器http://194.146.84.2:4395/下载一个名为“3”的压缩包&#xff0c;解压后…

数据结构与算法——3.时间复杂度分析1(概述)

前面我们已经介绍了&#xff0c;研究算法的最终目的是如何花费更少的时间&#xff0c;如何占用更少的内存去完成相同的需求&#xff0c;并且也通过案例演示了不同算法之间时间耗费和空间耗费上的差异&#xff0c;但我们并不能将时间占用和空间占用量化。因此&#xff0c;接下来…

【经验总结】10年的嵌入式开发老手,到底是如何快速学习和使用RT-Thread的?

【经验总结】一位近10年的嵌入式开发老手&#xff0c;到底是如何快速学习和使用RT-Thread的&#xff1f; RT-Thread绝对可以称得上国内优秀且排名靠前的操作系统&#xff0c;在嵌入式IoT领域一直享有盛名。近些年&#xff0c;物联网产业的大热&#xff0c;更是直接将RT-Thread这…

Redis | 安装Redis和启动Redis服务

目录 一、Redis简介 1.1 简介 二、Redis安装 2.1 Windows安装Redis 2.2 Linux安装Redis 三、Redis服务启动和停止 3.1 Windows启动Redis服务 3.2 Linux启动Redis服务 四、Redis设置密码远程连接 4.1 为Redis登陆设置密码 4.2 设置Redis允许远程连接 五、Redis常…

STM32CubeMX按键模块化 点灯

本文代码使用 HAL 库。 文章目录前言一、按键原理图二、CubeMX 创建工程三、代码讲解&#xff1a;1. GPIO的输入HAL库函数&#xff1a;2. 消抖&#xff1a;3. 详细代码四&#xff0c;实验现象&#xff1a;总结前言 我们继续讲解 stm32 f103&#xff0c;这篇文章将详细 为大家讲…

哪个品牌蓝牙耳机性价比高?性价比高的平价蓝牙耳机推荐

现如今&#xff0c;随着蓝牙技术的进步&#xff0c;蓝牙耳机在人们日常生活中的便捷性更胜从前。越来越多的蓝牙耳机品牌被大众看见、认可。那么&#xff0c;哪个品牌的蓝牙耳机性价比高&#xff1f;接下来&#xff0c;我给大家推荐几款性价比高的平价蓝牙耳机&#xff0c;一起…

Idea启动遇到 Web server failed to start. Port 8080 was already in use. 报错

Idea启动遇到问题-记录 报错英文提示&#xff1a; APPLICATION FAILED TO START Description: Web server failed to start. Port 8080 was already in use. Action: Identify and stop the process that’s listening on port 8080 or configure this application to liste…

《C++模板进阶》

致前行的人&#xff1a; 要努力&#xff0c;但不要着急&#xff0c;繁花锦簇&#xff0c;硕果累累都需要过程&#xff01; 目录 前言&#xff1a; 1.非类型模板参数 1.1.概念&#xff1a; 1.2.使用注意事项 2.模板特化 2.1函数模板特化 2.2类模板特化 3.模板的分离编译 3.1什么…

【手撕面试题】JavaScript(高频知识点二)

目录 面试官&#xff1a;请你谈谈JS的this指向问题 面试官&#xff1a;说一说call apply bind的作用和区别&#xff1f; 面试官&#xff1a;请你谈谈对事件委托的理解 面试官&#xff1a;说一说promise是什么与使用方法&#xff1f; 面试官&#xff1a;说一说跨域是什么&a…

Python 之 Pandas 文件操作和读取 CSV 参数详解

文章目录一、Pandas 读取文件二、CSV 文件读取1. 基本参数2. 通用解析参数3. 空值处理相关参数4. 时间处理相关参数5. 分块读入相关参数一、Pandas 读取文件 当使用 Pandas 做数据分析的时&#xff0c;需要读取事先准备好的数据集&#xff0c;这是做数据分析的第一步。Panda 提…

Cocoa-presentViewController

presentViewController:animator: 将一个viewController以动画方式显示出来 当VCA模态的弹出了VCB&#xff0c;那么VCA就是presenting view controller&#xff0c;VCB就是presented view controller presentViewController 相较于addSubView 直接作为subView就是不会出现一…

VUE的安装和创建

安装node.js 进入node官网进行下载&#xff0c;然后一直下一步。 测试是否安装成功&#xff1a; 命令提示窗下执行&#xff1a;npm -v 若出现版本号&#xff0c;则安装成功。 安装npm源&#xff1a; npm config set registry http://registry.npm.taobao.org 查看&#xff1a;…

C/C++网络编程笔记

https://www.bilibili.com/video/BV11Z4y157RY/?vd_sourced0030c72c95e04a14c5614c1c0e6159b这个视频里面通过简单的例子&#xff0c;讲了socket&#xff0c;对于小白来说还比较友好&#xff0c;我这里做个笔记。让网络通信跑起来我只有本科时候学过一点点C基础&#xff0c;但…

taobao.top.secret.bill.detail( 服务商的商家解密账单详情查询 )

&#xffe5;免费必须用户授权 服务商的商家解密账单详情查询&#xff0c;仅对90天内的账单提供SLA保障。 公共参数 请求地址: HTTP地址 http://gw.api.taobao.com/router/rest 公共请求参数: 公共响应参数: 请求参数 响应参数 点击获取key和secret 请求示例 TaobaoClient…

【LVGL】学习笔记--(3)界面切换以及显示优化

一 界面切换利用lvgl框架绘制GUI免不了需要实现多个页面的切换&#xff0c;毕竟把所有功能和接口都放在一页上有些不太优雅&#xff0c;而且对于嵌入式硬件的小屏幕也有些过于困难。因此这里就需要实现多个页面&#xff08;或者说lvgl里的screen&#xff09;及其互相切换。实现…

初识机器学习

监督学习与无监督学习supervised learning&#xff1a;监督学习&#xff0c;给出的训练集中有输入也有输出&#xff08;标签&#xff09;&#xff08;也可以说既有特征又有目标&#xff09;&#xff0c;在此基础上让计算机进行学习。学习后通过测试集测试给相应的事物打上标签。…

聚观早报|知名品牌3月暂别中国市场;金山办公22年营收38.85亿元

今日要闻&#xff1a;知名品牌3月31日起暂别中国市场&#xff1b;英特尔中国开源技术委员会宣布成立&#xff1b;金山办公2022年营收38.85亿元&#xff1b;美国推特公司进行第八轮裁员&#xff1b;Meta 官宣深入 AI 大战&#xff01; 知名品牌3月31日起暂别中国市场 近日&#…

中级嵌入式系统设计师2015下半年上午试题及答案解析

中级嵌入式系统设计师2015下半年上午试题 单项选择题 1、CPU是在______结束时响应DMA请求的。 A.一条指令执行 B.一段程序 C.一个时钟周期 D.一个总线周期 2、虚拟存储体系由______两级存储器构成。 A.主存-辅存 B.寄存器-Cache C.寄存器-主存

CHAPTER 2 CentOS的日志系统(日志工具)

日志工具2.1 rsyslogd(syslogd)2.1.1 介绍2.1.2 语法2.1.3 配置文件syslog.conf2.1.4 syslog.conf的配置规则2.1.5 示例2.2 logrotate2.2.1 介绍2.2.2 配置文件2.2.3 示例一2.2.4 示例二2.3 dmesg2.3.1 命令简介2.3.2 使用示例2.4 关于重启/死机的日志2.4.1 last2.4.2 日志查看…