【PyTorch][chapter 20][李宏毅深度学习]【无监督学习][ GAN]【实战】

news2024/9/19 9:19:15

前言

 本篇主要是结合手写数字例子,结合PyTorch 介绍一下Gan 实战

第一轮训练效果

第20轮训练效果,已经可以生成数字了

68 轮


目录: 

  1.   谷歌云服务器(Google Colab)
  2.   整体训练流程
  3.   Python 代码

一  谷歌云服务器(Google Colab)

     个人用的一直是联想小新笔记本,虽然非常稳定方便。但是现在跑深度学习,性能确实有点跟不上. 

   1.1    打开谷歌云服务器(Google Colab)

      https://colab.research.google.com/

    1. 2  新建笔记

                 

1

 1.4  选择T4GPU 

1.5  点击运行按钮

可以看到当前硬件的情况

     


二  整体训练流程


三    PyTorch 例子

# -*- coding: utf-8 -*-
"""
Created on Fri Mar  1 13:27:49 2024

@author: chengxf2
"""
import torch.optim as optim #优化器
import numpy as np 
import matplotlib.pyplot  as plt
import torchvision
from torchvision import transforms
import torch
import torch.nn as nn

#第一步加载手写数字集
def loadData():
  
    #同时归一化数据集(-1,1)
    style = transforms.Compose([
        transforms.ToTensor(),   #0-1 归一化0-1, channel,height,width
        transforms.Normalize(mean=0.5, std=0.5) #变成了-1,1 
        ]
        )
    trainData = torchvision.datasets.MNIST('data',
                                           train=True,
                                           transform=style,
                                           download=True)
    
    
    
    dataloader = torch.utils.data.DataLoader(trainData,
                                             batch_size= 16,
                                             shuffle=True)
    
    imgs,_ = next(iter(dataloader))
    #torch.Size([64, 1, 28, 28])
    print("\n imgs shape ",imgs.shape)
    
    return dataloader
    

class Generator(nn.Module):
     '''
      定义生成器
      输入:
          z 随机噪声[batch, input_size]
     输出:
         x: 图片 [batch, height, width, channel]
     '''
     def __init__(self,input_size):
          
          super(Generator,self).__init__()
          self.net = nn.Sequential(
              nn.Linear(in_features = input_size , out_features =256),
              nn.ReLU(),
              nn.Linear(in_features = 256 , out_features =512),
              nn.ReLU(),
              nn.Linear(in_features = 512 , out_features =28*28),
              nn.Tanh()
              )
          
     def forward(self, z):
          
          # z 随机输入[batch, dim]
          x = self.net(z)
          #[batch, height, width, channel]
          #print(x.shape)
          x = x.view(-1,28,28,1)
          return x
          
class Discriminator(nn.Module):
     '''
      定义鉴别器
      输入:
          x: 图片 [batch, height, width, channel]
     输出:
         y:  二分类图片的概率: BCELoss 计算交叉熵损失
     '''
     def __init__(self):
          
          super(Discriminator,self).__init__()
          #开始的维度和终止的维度,默认值分别是1和-1
          self.flatten = nn.Flatten()
          self.net = nn.Sequential(
              nn.Linear(in_features = 28*28 , out_features =512),
              nn.LeakyReLU(), #负值的时候保留梯度信息
              nn.Linear(in_features = 512 , out_features =256),
              nn.LeakyReLU(),
              nn.Linear(in_features = 256 , out_features =1),
              nn.Sigmoid()
              )
          
     def forward(self, x):
       
         x = self.flatten(x)
         #print(x.shape)
         out =self.net(x)
          
         return out
     
def gen_img_plot(model, epoch, test_input):
    
    out = model(test_input).detach().cpu()
    
    out = out.numpy()
    
    imgs = np.squeeze(out)
    
    fig = plt.figure(figsize=(4,4))
    
    for i in range(out.shape[0]):
        
        plt.subplot(4,4,i+1)
        img = (imgs[i]+1)/2.0#[-1,1]
        plt.imshow(img)
        plt.axis('off')
    plt.show()
    
     
def train():
    
    #1 初始化参数
    device ='cuda' if torch.cuda.is_available() else 'cpu'
    #2 加载训练数据
    dataloader = loadData()
    test_input  = torch.randn(16,100,device=device)
    
    #3 超参数
    maxIter = 20 #最大训练次数
    input_size = 100
    batchNum = 16
    input_size =100
    
    #4 初始化模型
    gen = Generator(100).to(device)
    dis = Discriminator().to(device)

    
    #5 优化器,损失函数
    d_optim = torch.optim.Adam(dis.parameters(), lr=1e-4)
    g_optim = torch.optim.Adam(gen.parameters(),lr=1e-4)
    loss_fn = torch.nn.BCELoss()
    
    #6 loss 变化列表
    D_loss =[]
    G_loss= []
    
    
   
    
    for epoch in range(0,maxIter):
        
        d_epoch_loss = 0.0
        g_epoch_loss  =0.0
        #count = len(dataloader)
        
        for step ,(realImgs, _) in enumerate(dataloader):
            
            realImgs = realImgs.to(device)
            random_noise = torch.randn(batchNum, input_size).to(device)
            
            
            
            #先训练判别器
            d_optim.zero_grad()
            real_output = dis(realImgs)
            d_real_loss = loss_fn(real_output, torch.ones_like(real_output))
            d_real_loss.backward()
            
            #不要训练生成器,所以要生成器detach
            fake_img = gen(random_noise)
            fake_output = dis(fake_img.detach())
            d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))
            d_fake_loss.backward()
            d_loss = d_real_loss+d_fake_loss
            d_optim.step()
            
            #优化生成器
            g_optim.zero_grad()
            fake_output = dis(fake_img.detach())
            g_loss = loss_fn(fake_output, torch.ones_like(fake_output))
            g_loss.backward()
            g_optim.step()
            
            with torch.no_grad():
                d_epoch_loss+= d_loss
                g_epoch_loss+= g_loss
        count = 16       
        with torch.no_grad():
                
                d_epoch_loss/=count
                g_epoch_loss/=count
                D_loss.append(d_epoch_loss)
                G_loss.append(g_epoch_loss)
                gen_img_plot(gen, epoch, test_input)
                print("Epoch: ",epoch)
    print("-----finised-----")
        
                
                
    
    
    
if __name__ == "__main__":
 
    
    train()
  
   
    

参考:

10.完整课程简介_哔哩哔哩_bilibili

理论【PyTorch][chapter 19][李宏毅深度学习]【无监督学习][ GAN]【理论】-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1482130.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

List<Object>集合对象属性拷贝工具类

目录 问题现象: 问题分析: 解决方法: 问题现象: 最近在项目中经常会使用到BeanUtils工具类来作对象的属性字段拷贝,但如果应用到List集合的话就需要遍历去操作了,如下: 打印结果: …

分类问题经典算法 | 二分类问题 | Logistic回归:公式推导

目录 一. Logistic回归的思想1. 分类任务思想2. Logistic回归思想 二. Logistic回归算法:线性可分推导 一. Logistic回归的思想 1. 分类任务思想 分类问题通常可以分为二分类,多分类任务;而对于不同的分类任务,训练的主要目标是…

小乌龟操作Git

1、选择小乌龟作为git客户端 最近使用idea来操作git的时候频频出现问题,要么是提交代码的时候少了某些文件,导致克隆下来无法运行,要么是提交速度太慢。 反正是在idea中操作git体验非常不好,所以决定来换一种方式来操作git。从网…

Java类加载器 和 双亲委派【详解】

一.类加载器: 由JDK提供的,用于加载一些资源文件到JVM内存里的一项技术。主要是加载class文件到内存,也可以加载一些资源文件。 2.JDK提供了三个类加载器: BootstrapClassLoader:引导类加载器, 是c语言编写…

[SUCTF 2019]EasyWeb --不会编程的崽

个人认为&#xff0c;这题还算有些东西。先来看源码 <?php function get_the_flag(){// webadmin will remove your upload file every 20 min!!!! $userdir "upload/tmp_".md5($_SERVER[REMOTE_ADDR]);if(!file_exists($userdir)){mkdir($userdir);}if(!empty…

Node服务器性能分析调优debug

node其实自带提供了性能分析工具&#xff1a;node --profile 但是它分析起来并不是很好用&#xff0c;于是chrome提供了另一种分析和debug工具&#xff1a;Chrome devtool 使用这个工具我们可以分析自己的Node项目调用堆栈里耗时较长的任务&#xff0c;对应去做缓存或者异步调用…

Day10:基础入门-HTTP数据包Postman构造请求方法请求头修改状态码判断

目录 数据-方法&头部&状态码 案例-文件探针 案例-登录爆破 工具-Postman自构造使用 思维导图 章节知识点&#xff1a; 应用架构&#xff1a;Web/APP/云应用/三方服务/负载均衡等 安全产品&#xff1a;CDN/WAF/IDS/IPS/蜜罐/防火墙/杀毒等 渗透命令&#xff1a;文件…

Facebook广告账户被封的可能原因及应如何避免?

Facebook作为全球最大的社交平台&#xff0c;是众多出海企业和广告主的推广渠道。但不少广告主在使用Facebook广告账户推广时&#xff0c;因各种原因导致账户被封&#xff0c;不仅影响了业务进程&#xff0c;更是给广告主带来了经济损失。 账户被封邮件如下&#xff1a; 以下是…

Dell R730 2U服务器实践2:VMWare ESXi安装

缘起 刚到手边的一台Dell R730是三块硬盘raid0 &#xff0c;把我惊出一身冷汗&#xff0c;准备把它们改组成raid1 或者raid5 。 但是舍不得里面的ESXi 8 &#xff0c;寻找能否把raid0改成raid1 还不掉WSXi的方法&#xff0c;很遗憾没有找到。那样只能重装ESXi了。 ESXi软件下…

Python 画 箱线图

Python 画 箱线图 flyfish 箱线图 其他名字 盒须图 / 箱形图 横向用正态分布看 垂直看 pandas画 import pandas as pdimport seaborn as sns import matplotlib.pyplot as plt import pandas as pddf pd.read_csv(sh300.csv) print("原始数据") print(df.he…

【JVM】聊聊常见的JVM排查工具

JDK工具包 jps 虚拟机进程状况工具 jps是虚拟机进程状况工具&#xff0c;列出正在运行的虚拟机进程&#xff0c;使用 Windows 的任务管理器或 UNIX 的 ps 命令也可以查询&#xff0c;但如果同时启动多个进程&#xff0c;必须依赖 jps。jps -l 显示类名 jps :列出Java程序进程…

循环简介和基本运算符

根据C Primer Plus第五章进行学习 文章目录 循环简介基本运算符 1.赋值运算符&#xff1a;2.加法运算符&#xff1a;3.减法运算符&#xff1a;-2.乘法运算符&#xff1a;*总结 1.循环简介 如下代码可以体现不使用循环的局限性&#xff1a; #include<stdio.h> #define AD…

云天励飞战略投资神州云海,布局机器人市场

日前,AI上市企业云天励飞(688343.SH)完成了对深圳市神州云海智能科技有限公司(以下简称“神州云海”)的B轮战略投资。 公开资料显示,自2015年于深圳创立以来,神州云海始终聚焦人工智能与服务机器人广阔的应用市场,依托自主的核心算法能力,深耕机器人硬件本体研发,整合上下游产…

【机器学习300问】25、常见的模型评估指标有哪些?

模型除了从数据划分的角度来评估&#xff0c;我上一篇文章介绍了数据集划分的角度&#xff1a; 【机器学习300问】24、模型评估的常见方法有哪些&#xff1f;http://t.csdnimg.cn/LRyEt 还可以从一些指标的角度来评估&#xff0c;这篇文章就带大家从两个最经典的任务场景介绍…

微信云开发-- Mac安装 wx-server-sdk依赖

第一次上传部署云函数时&#xff0c;会提示安装依赖wx-server-sdk 一. 判断是否安装wx-server-sdk依赖 先创建一个云函数&#xff0c;然后检查云函数目录。 如果云函数目录下只显示如下图所示三个文件&#xff0c;说明未安装依赖。 如果云函数目录下显示如下图所示四个文件&a…

数电实验之流水灯、序列发生器

最近又用到了数电实验设计的一些操作和设计思想&#xff0c;遂整理之。 广告流水灯 实验内容 用触发器、组合函数器件和门电路设计一个广告流水灯&#xff0c;该流水灯由 8 个 LED 组成&#xff0c;工作时始终为 1 暗 7 亮&#xff0c;且这一个暗灯循环右移。 1) 写出设计过…

吴恩达机器学习笔记:第5周-9 神经网络的学习(Neural Networks: Learning)

目录 9.1 代价函数 9.1 代价函数 首先引入一些便于稍后讨论的新标记方法&#xff1a; 假设神经网络的训练样本有&#x1d45a;个&#xff0c;每个包含一组输入&#x1d465;和一组输出信号&#x1d466;&#xff0c;&#x1d43f;表示神经网络层数&#xff0c;&#x1d446;&…

babylonjs入门-半球光

基于babylonjs封装的一些功能和插件 &#xff0c;希望有更多的小伙伴一起玩babylonjs&#xff1b; 欢迎加群&#xff08;点击群号传送&#xff09;&#xff1a;464146715 官方文档 中文文档 案例传送门 懒得打字 粘贴复制 一气呵成

kubectl 陈述式资源管理方法

目录 陈述式资源管理方法 项目的生命周期 1.创建kubectl create命令 2.发布kubectl expose命令 service的4的基本类型 查看pod网络状态详细信息和 Service暴露的端口 查看关联后端的节点 ​编辑 查看 service 的描述信息 ​编辑在 node01 节点上操作&#xff0c;查看…

3.1 IO进程线程

使用fwrite、fread将一张随意的bmp图片&#xff0c;修改成德国的国旗 #include <stdio.h> #include <string.h> #include <unistd.h> #include <stdlib.h> int main(int argc, const char *argv[]) {FILE* fp fopen("./2.bmp","r&quo…