[PyTorch][chapter 56][GAN 代码实现]

news2024/11/17 23:48:43

前言:

      整个工程分为两个文件:

      gan.py:  网络模型搭建

      main.py:  数据集生成,模型训练


目录:

  1.      GAN 网络结构
  2.      gan.py
  3.       main.py

一  GAN 网络结构

       

      1.1 训练D

            V(G,D)= E_{x \sim p_r}[log D(x)]+E_{x \sim p_g}[log (1-D(x)]

            最大化V

     1.2  训练G

            固定G, 最小化

             E_{x \sim p_z}[log(1-D(x)]


二 gan.py

   功能:

        实现 鉴别器D

        实现 生成器G

# -*- coding: utf-8 -*-
"""
Created on Thu Sep 28 11:10:19 2023

@author: chengxf2
"""

import torch
from   torch import nn



#生成器模型
class Generator(nn.Module):
    
    def __init__(self):
        
        super(Generator,self).__init__()
        # z: [batch,input_features]
        h_dim = 400
        self.net = nn.Sequential(
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear( h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 2)
            )
        
    def forward(self, z):
        
        output = self.net(z)
        return output
    
#鉴别器模型
class Discriminator(nn.Module):
    
    def __init__(self):
        
        super(Discriminator,self).__init__()
        
        hDim=400
        # x: [batch,input_features]
        self.net = nn.Sequential(
            nn.Linear(2, hDim),
            nn.ReLU(True),
            nn.Linear(hDim, hDim),
            nn.ReLU(True),
            nn.Linear(hDim, hDim),
            nn.ReLU(True),
            nn.Linear(hDim, 1),
            nn.Sigmoid()
            )
        
    def forward(self, x):
        
        #x:[batch,1]
        output = self.net(x)
        
        out = output.view(-1)
        return out
    




三  main.py

 功能:

      生成数据

       训练网络

# -*- coding: utf-8 -*-
"""
Created on Thu Sep 28 11:28:32 2023

@author: chengxf2
"""


import visdom
from gan  import  Discriminator
from gan  import Generator
import numpy as np
import random
import torch
from   torch import nn, optim
from    matplotlib import pyplot as plt


h_dim =400
batchSize = 256
viz = visdom.Visdom()

#viz = visdom.Visdom()


def weights_init(net):
   if isinstance(net, nn.Linear):
         # net.weight.data.normal_(0.0, 0.02)
         nn.init.kaiming_normal_(net.weight)
         net.bias.data.fill_(0)

def data_generator():
    """
    8- gaussian destribution

    Returns
    -------
    None.

    """
    scale = 2
    a = np.sqrt(2.0)
    centers =[
         (1,0),
         (-1,0),
         (0,1),
         (0,-1),
         (1/a,1/a),
         (1/a,-1/a),
         (-1/a, 1/a),
         (-1/a,-1/a)
        ]
    
    centers = [(scale*x, scale*y) for x,y in centers]
    
    while True:
        
         dataset =[]
         
         for i in range(batchSize):
             
             point = np.random.randn(2)*0.02
             center = random.choice(centers)
             point[0] += center[0]
             point[1] += center[1]
             dataset.append(point)
         dataset = np.array(dataset).astype(np.float32)
         dataset /=a
         #生成器函数是一个特殊的函数,可以返回一个迭代器
         yield dataset


def generate_image(D, G, xr, epoch):      #xr表示真实的sample
    """
    Generates and saves a plot of the true distribution, the generator, and the
    critic.
    """
    N_POINTS = 128
    RANGE = 3
    plt.clf()

    points = np.zeros((N_POINTS, N_POINTS, 2), dtype='float32')
    points[:, :, 0] = np.linspace(-RANGE, RANGE, N_POINTS)[:, None]
    points[:, :, 1] = np.linspace(-RANGE, RANGE, N_POINTS)[None, :]
    points = points.reshape((-1, 2))             # (16384, 2)
    x = y = np.linspace(-RANGE, RANGE, N_POINTS)
    N = len(x)
    # draw contour
    with torch.no_grad():
        points = torch.Tensor(points)      # [16384, 2]
        disc_map = D(points).cpu().numpy() # [16384]
   
    plt.contour(x, y, disc_map.reshape((N, N)).transpose())
    #plt.clabel(cs, inline=1, fontsize=10)
    plt.colorbar()


    # draw samples
    with torch.no_grad():
        z = torch.randn(batchSize, 2)                 # [b, 2]
        samples = G(z).cpu().numpy()                # [b, 2]
    plt.scatter(xr[:, 0], xr[:, 1], c='green', marker='.')
    plt.scatter(samples[:, 0], samples[:, 1], c='red', marker='+')

    viz.matplot(plt, win='contour', opts=dict(title='p(x):%d'%epoch))
    
         
         
def main():
  
    maxIter = 1000
    torch.manual_seed(10)
    np.random.seed(10)
    data_iter  = data_generator()
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    G = Generator().to(device)
    D = Discriminator().to(device)
    G.apply(weights_init)
    D.apply(weights_init)
    optim_G = optim.Adam(G.parameters(),lr =5e-4, betas=(0.5,0.9))
    optim_D = optim.Adam(D.parameters(),lr =5e-4, betas=(0.5,0.9))
    K = 5
 
    

    
   
    viz.line([[0,0]], [0], win='loss', opts=dict(title='loss', legend=['D', 'G']))

    for epoch in range(maxIter):
        
        #1: train Discrimator fistly
        for k in range(K):
            
            #1.1: train on real data
            xr = next(data_iter)
            xr = torch.from_numpy(xr).to(device)
            predr = D(xr)
            
            vr = torch.log(predr)
            #max(predr) == min(-predr)
            lossr = vr.mean()
            
            #1.2: train on fake data
            z = torch.randn(batchSize,2).to(device) #[b,2] 随机产生的噪声
            xf = G(z).detach() #固定G,不更新G参数 tf.stop_gradient()
            predf =D(xf) #min predf
            
            vf = torch.log(1e-4+1.0-predf)
            lossf = vf.mean()
            loss_D =-(lossr+lossf)
            
            optim_D.zero_grad()
            loss_D.backward()
            optim_D.step()
            #print("\n Discriminator 训练结束 ",loss_D.item())
        
        # 2 train  Generator,max V(G,D)
        
        #2.1 train on fake data
        z = torch.randn(batchSize, 2)
        xf = G(z)
        predf =D(xf) #max predf
        
        vf = torch.log(1e-4+1.0-predf)
        loss_G= predf.mean()
        
        #optimize
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()
        
        if epoch %100 ==0:
            viz.line([[loss_D.item(), loss_G.item()]], [epoch], win='loss', update='append')
            generate_image(D, G, xr, epoch)
            print("\n epoch: %d"%epoch,"\t lossD: %7.4f"%loss_D.item(),"\t lossG: %7.4f"%loss_G.item())
         
        
 

    
    
    

if __name__ == "__main__":
    
    main()
         
    


三 训练效果

   里面的损失函数按照最早的论文里面的,跟其它版本有所区别

 效果:

      生成器G 训练的loss 最后稳定在一个固定值,无法更新生成器

      鉴别器: 因为生成器很弱,很容易鉴别出真实数据 和 fake 数据,导致loss 也迅速降低为0

      实际生成效果:

                 生成器生成出来的数据红色部分,和真实的数据分布绿色 有较大差距。

    生成器很弱。

 

   

参考:

课时127 GAN实战-GD实现_哔哩哔哩_bilibili

https://www.cnblogs.com/cxq1126/p/13538409.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1066606.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vue3+elementPlus el-input的type=“number“时去除右边的上下箭头

改成 代码如下 <script lang"ts" setup> import {ref} from vue const inputBtn ref() </script> <template><el-input type"number" v-model"inputBtn" style"width: 80px;" class"no_number">…

超长表单分页校验,下一页和上一页功能

父组件(最外层) <template><xx-layout title"练习"><divslot"content"class"hierarchy-tag-main"><el-steps:space"200":active"currentComponentIndex 1"align-centerstyle"margin-bottom: 30…

Flutter横屏实践

1、Flutter设置横屏 // 强制横屏 SystemChrome.setPreferredOrientations([DeviceOrientation.landscapeLeft,DeviceOrientation.landscapeRight ]); // 强制竖屏 SystemChrome.setPreferredOrientations([DeviceOrientation.portraitUp, DeviceOrientation.portraitDown]);另…

【虹科分享】什么是Redis数据集成(RDI)?

大量的应用程序、日益增长的用户规模、不断扩展的技术需求&#xff0c;以及对即时响应的持续追求。想想这些是否正是你在经历的。也许你尝试过自己构建工具来应对这些需求&#xff0c;但是大量的编码和集成工作使你焦头烂额。那你是否知道&#xff0c;有这样一个工具可以帮助你…

数据结构-图-最小生成树问题

最小生成树 并查集定义举例说明查找某个元素属于哪个集合代码实现路径压缩 Kruskal算法原理代码实现 Prim算法原理代码实现 并查集 定义 &#x1f680;在一些应用问题中&#xff0c;需要将n个不同的元素分成一些不相交的集合。开始时&#xff0c;每个元素自成一个单元素集合&…

小商品公众号微信店铺搭建的作用是什么

小商品顾名思义就是价格低、需求广且数量多的日用产品&#xff0c;覆盖人群非常广&#xff0c;无论线上还是线下总能找到目标客户&#xff0c;铅笔、削皮刀、晾衣架等各式产品琳琅满目&#xff0c;不少商家也是热衷于小商品的售卖。 从整体来看&#xff0c;小商品商家也可线上…

tomcat整体架构

Tomcat介绍 Tomcat是Apache Software Foundation&#xff08;Apache软件基金会&#xff09;开发的一款开源的Java Servlet 容器。它是一种Web服务器&#xff0c;用于在服务器端运行Java Servlet和JavaServer Pages (JSP)技术。它可 以为Java Web应用程序提供运行环境&#x…

刚入职字节外包一个月,我却离职了...

有一种打工人的羡慕&#xff0c;叫做“大厂”。 真是年少不知大厂香&#xff0c;错把青春插稻秧。 但是&#xff0c;在深圳有一群比大厂员工更庞大的群体&#xff0c;他们顶着大厂的“名”&#xff0c;做着大厂的工作&#xff0c;还可以享受大厂的伙食&#xff0c;却没有大厂…

新手选MT4还是MT5,anzo capital昂首资本建议选择MT4,一个原因

在交易中就订单执行策略而言&#xff0c;MT4和MT5哪个更好&#xff0c;相信很多交易者和&#xff0c;anzo capital昂首资本一样很难做出判断。在MT5中&#xff0c;虽然开发人员对发送订单的流程进行了额外的复杂化&#xff0c;同时MT5在订单执行政策方面的优势在于其能够调整全…

告警繁杂迷人眼,多源分析见月明

随着数字化浪潮的蓬勃兴起&#xff0c;网络安全问题日趋凸显&#xff0c;面对指数级增长的威胁和告警&#xff0c;传统的安全防御往往力不从心。网内业务逻辑不规范、安全设备技术不成熟都会导致安全设备触发告警。如何在海量众多安全告警中识别出真正的网络安全攻击事件成为安…

Vue3项目使用Stimulsoft.Reports.js【项目实战】

Vue3项目使用Stimulsoft.Reports.js【项目实战】 相关阅读&#xff1a;vue-cli使用stimulsoft.reports.js&#xff08;保姆级教程&#xff09;_stimulsoft vue-CSDN博客 前言 在BS的项目中我们时常会用到报表打印、标签打印、单据打印&#xff0c;可是BS的通用打印解决方案又…

【JavaEE初阶】 多线程(初阶)——壹

文章目录 &#x1f332;线程的概念&#x1f6a9;线程是什么&#x1f6a9;为啥要有线程&#x1f6a9;进程和线程的区别&#x1f6a9;Java 的线程 和 操作系统线程 的关系 &#x1f60e;第一个多线程程序&#x1f6a9;使用 jconsole 命令观察线程 &#x1f38d;创建线程&#x1f…

字段位置顺序对值的影响

Unity中验证AB加载场景时报错&#xff1a; Cannot load scene: Invalid scene name (empty string) and invalid build index -1 报错原因是因为把字段放在了Start函数后面(图一)改成(图二)就好了。图一中协程使用的sceneBName字段值为null。 图一&#xff1a; 图二&#xff1a…

【C++】List -- 详解

一、list的介绍及使用 https://cplusplus.com/reference/list/list/?kwlist list 是可以在常数范围内在任意位置进行插入和删除的序列式容器&#xff0c;并且该容器可以前后双向迭代。 list 的底层是双向链表结构&#xff0c;双向链表中每个元素存储在互不相关的独立节点中&…

Windows10点击开始菜单没反应的四种解决方法

在Windows10电脑中&#xff0c;用户点击开始菜单却出现了没反映的情况&#xff0c;这样用户就无法通过开始菜单来展开操作哦&#xff0c;会给用户的正常操作带来一定程序的影响&#xff0c;下面小编给大家带来四种简单的解决方法&#xff0c;帮助大家轻松恢复Windows10电脑开始…

Selenium 高级定位 CSS

一、CSS选择器概念 CSS拥有自己的语法规则和表达式 CSS通常分为相对定位和绝对定位 CSS常和XPATH一起用于UI自动化测试 二、CSS相对定位使用场景 支持web场景支持app端的webview 三、CSS语法实战 3.1、CSS相对定位的优点 可维护性强语法简洁可以解决各种复杂的定位场景 # …

ARMv7-A 那些事 - 6.常用汇编指令

By: Ailson Jack Date: 2023.10.07 个人博客&#xff1a;http://www.only2fire.com/ 本文在我博客的地址是&#xff1a;http://www.only2fire.com/archives/158.html&#xff0c;排版更好&#xff0c;便于学习&#xff0c;也可以去我博客逛逛&#xff0c;兴许有你想要的内容呢。…

【二叉树练习题】

欢迎来到我的&#xff1a;世界 希望作者的文章对你有所帮助&#xff0c;有不足的地方还请指正&#xff0c;大家一起学习交流 ! 目录 前言初阶题二叉树的节点个数二叉树的叶子节点个数二叉树第k层节点个数二叉树查找值为x的节点 进阶题完全二叉树的节点个数翻转二叉树检验两个树…

修炼k8s+flink+hdfs+dlink(一:安装dlink)

一&#xff1a;mysql初始化。 mysql -uroot -p123456 create database dinky; grant all privileges on dinky.* to dinky% identified by dinky with grant option; flush privileges;二&#xff1a;上传dinky。 上传至目录/opt/app/dlink tar -zxvf dlink-release-0.7.4.t…

医学访问学者面试技巧

医学访问学者面试是一个非常重要的环节&#xff0c;它决定了你是否能够获得这个宝贵的机会去国外的大学或研究机构学习和研究。在这篇文章中&#xff0c;知识人网小编将分享一些关于医学访问学者面试的技巧&#xff0c;帮助你在面试中表现出色。 1. 准备充分 在参加医学访问学…