[PyTorch][chapter 57][WGAN-GP 代码实现]

news2025/1/19 23:25:21

前言:

 下图为WGAN 的效果图:

  绿色为真实数据的分布: 8个高斯分布

  红色: 为随机产生的数据分布,跟真实分布基本一致

WGAN-GP:

1 判别器D: 最后一层去掉sigmoid
2 生成器G 和判别器D: loss不取log
3 损失函数 增加了penalty,使用Adam

 Wasserstein GAN
1 判别器D: 最后一层去掉sigmoid
2 生成器G 和判别器D: loss不取log
3 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c
4 不要用基于动量的优化算法(包括momentum和Adam),推荐RMSProp,SGD也行
 


一  简介

    1.1 模型结构

 1.2 伪代码

      


二  wgan.py

 主要变化:

    Generator 中 去掉了之前的logit 函数

# -*- coding: utf-8 -*-
"""
Created on Thu Sep 28 11:10:19 2023

@author: chengxf2
"""

import torch
from   torch import nn



#生成器模型
h_dim = 400
class Generator(nn.Module):
    
    def __init__(self):
        
        super(Generator,self).__init__()
        # z: [batch,input_features]
       
        self.net = nn.Sequential(
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear( h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 2)
            )
        
    def forward(self, z):
        
        output = self.net(z)
        return output
    
#鉴别器模型
class Discriminator(nn.Module):
    
    def __init__(self):
        
        super(Discriminator,self).__init__()
        
        hDim=400
        # x: [batch,input_features]
        self.net = nn.Sequential(
            nn.Linear(2, hDim),
            nn.ReLU(True),
            nn.Linear(hDim, hDim),
            nn.ReLU(True),
            nn.Linear(hDim, hDim),
            nn.ReLU(True),
            nn.Linear(hDim, 1),
            )
        
    def forward(self, x):
        
        #x:[batch,1]
        output = self.net(x)
        
        out = output.view(-1)
        return out
    




三 main.py

  主要变化:

    损失函数中增加了gradient_penalty

# -*- coding: utf-8 -*-
"""
Created on Thu Sep 28 11:28:32 2023

@author: chengxf2
"""


import visdom
from gan  import  Discriminator
from gan  import Generator
import numpy as np
import random
import torch
from   torch import nn, optim
from    matplotlib import pyplot as plt
from torch import autograd


h_dim =400
batchsz = 256
viz = visdom.Visdom()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")



def weights_init(net):
   if isinstance(net, nn.Linear):
         # net.weight.data.normal_(0.0, 0.02)
         nn.init.kaiming_normal_(net.weight)
         net.bias.data.fill_(0)

def data_generator():
    """
    8- gaussian destribution

    Returns
    -------
    None.

    """
    scale = 2
    a = np.sqrt(2.0)
    centers =[
         (1,0),
         (-1,0),
         (0,1),
         (0,-1),
         (1/a,1/a),
         (1/a,-1/a),
         (-1/a, 1/a),
         (-1/a,-1/a)
        ]
    
    centers = [(scale*x, scale*y) for x,y in centers]
    
    while True:
        
         dataset =[]
         
         for i in range(batchsz):
             
             point = np.random.randn(2)*0.02
             center = random.choice(centers)
             point[0] += center[0]
             point[1] += center[1]
             dataset.append(point)
         dataset = np.array(dataset).astype(np.float32)
         dataset /=a
         #生成器函数是一个特殊的函数,可以返回一个迭代器
         yield dataset


def generate_image(D, G, xr, epoch):      #xr表示真实的sample
    """
    Generates and saves a plot of the true distribution, the generator, and the
    critic.
    """
    N_POINTS = 128
    RANGE = 3
    plt.clf()

    points = np.zeros((N_POINTS, N_POINTS, 2), dtype='float32')
    points[:, :, 0] = np.linspace(-RANGE, RANGE, N_POINTS)[:, None]
    points[:, :, 1] = np.linspace(-RANGE, RANGE, N_POINTS)[None, :]
    points = points.reshape((-1, 2))             # (16384, 2)
    x = y = np.linspace(-RANGE, RANGE, N_POINTS)
    N = len(x)
    # draw contour
    with torch.no_grad():
        points = torch.Tensor(points)      # [16384, 2]
        disc_map = D(points).cpu().numpy() # [16384]
   
    plt.contour(x, y, disc_map.reshape((N, N)).transpose())
    #plt.clabel(cs, inline=1, fontsize=10)
    plt.colorbar()


    # draw samples
    with torch.no_grad():
        z = torch.randn(batchsz, 2)                 # [b, 2]
        samples = G(z).cpu().numpy()                # [b, 2]
    plt.scatter(xr[:, 0], xr[:, 1], c='green', marker='.')
    plt.scatter(samples[:, 0], samples[:, 1], c='red', marker='+')

    viz.matplot(plt, win='contour', opts=dict(title='p(x):%d'%epoch))
    

def gradient_penalty(D, xr,xf):

    #[b,1]
    t =  torch.rand(batchsz, 1).to(device)       
    #[b,1]=>[b,2]  保证每个sample t 相同
    t =  t.expand_as(xr)
    
    #sample penalty interpoation [b,2]
    mid = t*xr +(1-t)*xf
    mid.requires_grad_()
    
    pred = D(mid) #[256]
   
    '''
    grad_outputs:   如果outputs 是向量,则此参数必须写
    retain_graph:  True 则保留计算图, False则释放计算图
    create_graph: 若要计算高阶导数,则必须选为True
    allow_unused: 允许输入变量不进入计算
    '''
    grads = autograd.grad(outputs= pred, inputs = mid,
                      grad_outputs= torch.ones_like(pred),
                      create_graph=True,
                      retain_graph=True,
                      only_inputs=True)[0]
    
    gp = torch.pow(grads.norm(2, dim=1)-1,2).mean()
    
    return gp
    
    
    
    
    
    
         
def main():
  
    lambd = 0.2 #超参数
    maxIter = 1000
    torch.manual_seed(10)
    np.random.seed(10)
    data_iter  = data_generator()
    
   
    G = Generator().to(device)
    D = Discriminator().to(device)
    G.apply(weights_init)
    D.apply(weights_init)
    optim_G = optim.Adam(G.parameters(),lr =5e-4, betas=(0.5,0.9))
    optim_D = optim.Adam(D.parameters(),lr =5e-4, betas=(0.5,0.9))
    K = 5
 
    

    
   
    viz.line([[0,0]], [0], win='loss', opts=dict(title='loss', legend=['D', 'G']))

    for epoch in range(maxIter):
        
        #1: train Discrimator fistly
        for k in range(K):
            
            #1.1: train on real data
            xr = next(data_iter)
            xr = torch.from_numpy(xr).to(device)
            predr = D(xr)
            
       
            #max(predr) == min(-predr)
            lossr = -predr.mean()
            
            
            #1.2: train on fake data
            z = torch.randn(batchsz,2).to(device) #[b,2] 随机产生的噪声
            xf = G(z).detach() #固定G,不更新G参数 tf.stop_gradient()
            predf =D(xf)
            lossf = predf.mean()
            
            #1.3 gradient_penalty
            gp = gradient_penalty(D, xr,xf.detach())
            
            #aggregate all
            loss_D = lossr + lossf +lambd*gp
            
            optim_D.zero_grad()
            loss_D.backward()
            optim_D.step()
            #print("\n Discriminator 训练结束 ",loss_D.item())
        
        # 2 train  Generator
        
        #2.1 train on fake data
        z = torch.randn(batchsz, 2).to(device)
        xf = G(z)
        predf =D(xf) #期望最大
        loss_G= -predf.mean()
        
        #optimize
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()
        
        if epoch %100 ==0:
            viz.line([[loss_D.item(), loss_G.item()]], [epoch], win='loss', update='append')
            generate_image(D, G, xr, epoch)
            print("\n epoch: %d"%epoch,"\t lossD: %7.4f"%loss_D.item(),"\t lossG: %7.4f"%loss_G.item())
         
        
 

    
    
    

if __name__ == "__main__":
    
    main()
         
    

参考:

课时130 WGAN-GP实战_哔哩哔哩_bilibili

WGAN基本原理及Pytorch实现WGAN-CSDN博客

CSDN

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1069817.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Spring AOP 详解及@Trasactional

Spring AOP 详解 AOP基础 AOP: Aspect Oriented Program, 面向切面编程。解耦(组织结构调整)、增强(扩展)。 AOP术语 术语 说明 Aspect(切面) 横切于系统的连接点实现特定功能的类 JoinPoint&#xf…

Axios、SASS学习笔记

目录 前言 一、Axios基础认识 1、简介 2、相关文档 3、基本配置 4、基础快捷使用 二、Axios封装 1、公共配置文件 2、细化每个接口的配置 3、使用并发送请求 三、SASS 1、简介 2、相关文档 3、使用前奏 4、使用变量 5、嵌套规则 6、父级选择器标识 & 前言…

小谈设计模式(10)—原型模式

小谈设计模式(10)—原型模式 专栏介绍专栏地址专栏介绍 原型模式角色分类抽象原型(Prototype)具体原型(Concrete Prototype)客户端(Client)原型管理器(Prototype Manager…

《C++ Primer》第4章 表达式(二)

参考资料: 《C Primer》第5版《C Primer 习题集》第5版 4.6 成员访问运算符(P133) 点运算符和箭头运算符都可用于访问成员,ptr->mem 等价于 (*ptr).mem 。箭头作用于指针类型对象,结果为左值;点运算符…

【Mysql】 blob 转text

有个数据表字段存储的字段类型是blob,想查看字段内容。 blob是二进制的无法直接查看怎么办? 写sql,blob 转text SELECT CONVERT(content USING utf8) FROM article_content ; 我想把原来content字段完全转成text 新建 text 类型字段conten…

uniapp 在uni.scss 根据@mixin定义方法 、通过@include全局使用

在官方文档中提及到uni.scss中变量的使用,而我想定义方法,这样写css样式更方便 一、官方文档的介绍 根据官方文档我知道,在这面定义的变量全局都可使用。接下来我要在这里定义方法。 二、在uni.scss文件中定义方法 我在uni.scss文件中定义了…

三、浏览器缓存动如何使用(Expires、 cache-control、Etag、last-modified)----哪些文件需要强缓存,哪些文件需要协商缓存

参考链接1:彻底弄懂强缓存与协商缓存 参考链接2:浏览器缓存 参考链接3:扼杀 304,Cache-Control: immutable 如何搭建 express,或者node服务 ### 如何搭建 express,npm install express --save### expre…

[C++基础]-多态

前言 作者:小蜗牛向前冲 名言:我可以接受失败,但我不能接受放弃 如果觉的博主的文章还不错的话,还请点赞,收藏,关注👀支持博主。如果发现有问题的地方欢迎❀大家在评论区指正。 本期学习目标&am…

智能AI系统源码ChatGPT系统源码+详细搭建部署教程+AI绘画系统+已支持OpenAI GPT全模型+国内AI全模型

一、AI创作系统 SparkAi创作系统是基于OpenAI很火的ChatGPT进行开发的Ai智能问答系统,支持OpenAI GPT全模型国内AI全模型。本期针对源码系统整体测试下来非常完美,可以说SparkAi是目前国内一款的ChatGPT对接OpenAI软件系统。那么如何搭建部署AI创作Chat…

NFT Insider#110:The Sandbox与TB Media Global合作,YGG Web3游戏峰会阵容揭晓

引言:NFT Insider由NFT收藏组织WHALE Members、BeepCrypto出品,浓缩每周NFT新闻,为大家带来关于NFT最全面、最新鲜、最有价值的讯息。每期周报将从NFT市场数据,艺术新闻类,游戏新闻类,虚拟世界类&#xff0…

双字单字拆分/合并操作(博途SCL源代码)

博途PLC的位、字节拆分和合并操作还可以参考下面的文章链接: 博途PLC 位/字/字节 Bit/ Word/Byte拆分与合并_博途的bit-CSDN博客有时候我们需要将分散分布的开关量信号组合为一个整体比如一个字节再完成发送,或者一些报警联锁控制,组合为一个字方便触摸屏报警记录等,下面我…

day11_oop_面向对象基础

零、今日内容 一、作业 二、面向对象 一、作业 题目26 设计方法,在一个数组中,返回所有的 指定数据的下标 例如, 这个数组[1,2,8,4,5,7,8,7,8,9],找到其中元素8的下标[2,6,8]public static void main(String[] args) {int[] arr {1,2,8,4,5,7,8,7,8,9};int[] ind…

云计算安全的新挑战:零信任架构的应用

文章目录 云计算的安全挑战什么是零信任架构?零信任架构的应用1. 多因素身份验证(MFA)2. 访问控制和策略3. 安全信息和事件管理(SIEM)4. 安全的应用程序开发 零信任架构的未来 🎉欢迎来到云计算技术应用专栏…

vue-3

一、文章内容概括 1.生命周期 生命周期介绍生命周期的四个阶段生命周期钩子声明周期案例 2.工程化开发入门 工程化开发和脚手架项目运行流程组件化组件注册 二、Vue生命周期 思考:什么时候可以发送初始化渲染请求?(越早越好&#xff09…

二叉搜索树的初步认识

二叉搜索树与常见的查找算法有什么区别? 首先,如果有同学不知道有哪些查找算法可以查看:常见查找算法_加瓦不加班的博客-CSDN博客 如果还有一些不了解的,请查看加瓦不加班_数据结构,链表,递归-CSDN博客 接下来,我们来…

玄子Share- IDEA 2023 SpringBoot 热部署

玄子Share- IDEA 2023 SpringBoot 热部署 修改 IDEA 部署设置 IDEA 勾选如下选项 新建 SpringBoot 项目 项目构建慢的将 Spring Initializr 服务器 URL 改为阿里云:https://start.aliyun.com/ 在这里直接勾选Spring Boot Devtools插件即可 测试 切出 IDEA 项目文…

「专题速递」AR协作、智能NPC、数字人的应用与未来

元宇宙是一个融合了虚拟现实、增强现实、人工智能和云计算等技术的综合概念。它旨在创造一个高度沉浸式的虚拟环境,允许用户在其中交互、创造和共享内容。在元宇宙中,人们可以建立虚拟身份、参与虚拟社交,并享受无限的虚拟体验。 作为互联网大…

2023年H1汽车社媒营销趋势报告

2023上半年车市“内卷”,汽车营销玩法越来越多样,品牌的矩阵式营销也成为头部车企重点营销战略。 据乘联会预测,2023下半年车市将受到二季度价格战的透支,需要一定修复期。在整体形势较不利的情况下,车企如何破局实现销…

通过IP地址如何计算相关地址

以IP地址为1921681005 子网掩码是2552552550为例。算出网络地址、广播地址、地址范围、主机数。 分步骤计算 1) 将IP地址和子网掩码换算为二进制,子网掩码连续全1的是网络地址,后面的是主机地址。 虚线前为网络地址,虚线后为主机…

基于Java的个性化旅游攻略系统设计与实现(源码+lw+ppt+部署文档+视频讲解等)

文章目录 前言具体实现截图论文参考详细视频演示为什么选择我自己的网站自己的小程序(小蔡coding)有保障的售后福利 代码参考源码获取 前言 💗博主介绍:✌全网粉丝10W,CSDN特邀作者、博客专家、CSDN新星计划导师、全栈领域优质创作…