【人工智能基础】GAN与WGAN实验

news2024/12/23 22:16:12

一、GAN网络概述

GAN:生成对抗网络。GAN网络中存在两个网络:G(Generator,生成网络)和D(Discriminator,判别网络)。

Generator接收一个随机的噪声z,通过这个噪声生成图片,记做G(z)

Discriminator功能是判别一张图片的真实。它的输入是一张图片x,输出D(x)代表x为真实图片的概率,如果为1就代表图片真实,而输出为0,就代表图片不真实。

在GAN网络的训练中,Generator的目标就是尽量生成真实的图片去欺骗Discriminator

Discriminator的目标就是尽量把Generator生成的图片和真实的图片分别开来

二、GAN实验环境准备

除了之前使用过的pytorch-nplnumpy以外,我们还需要安装visdom

pip install visdom

启动visdom

python -m visdom.server

visdom启动成功如下图,会占用8097端口,我们可以通过8097端口访问visdom

visdom启动.png

三、GAN网络实验

环境参数配置

import torch
from torch import nn,optim,autograd
import numpy as np
import visdom
import random

h_dim = 400
batchsz = 512
viz = visdom.Visdom()

生成网络定义

class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__()
        self.net = nn.Sequential(
            # input[b, 2]
            nn.Linear(2,h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 2)
            # output[b,2]
        )

    def forward(self, z):
        output = self.net(z)
        return output

判别网络定义

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        output = self.net(x)
        return output.view(-1)

数据集生成函数

def data_generator():
    # 生成中心点
    scale = 2
    centers = [
        (1, 0),
        (-1, 0),
        (0, 1),
        (0, -1),
        (1. / np.sqrt(2), 1. / np.sqrt(2)),
        (1. / np.sqrt(2), -1. / np.sqrt(2)),
        (-1. / np.sqrt(2), 1. / np.sqrt(2)),
        (-1. / np.sqrt(2), -1. / np.sqrt(2))
    ]
    centers = [(scale * x, scale * y) for x,y in centers] 
    while True:
        dataset = []
        for i in range(batchsz):
            point = np.random.randn(2) * 0.02
            # 随机选取一个中心点
            center = random.choice(centers)
            # 把刚刚随机到的高斯分布点根据center进行移动
            point[0] += center[0]
            point[1] += center[1]
            dataset.append(point)
        dataset = np.array(dataset).astype(np.float32)
        dataset /= 1.414
        yield dataset

可视化函数

将图片生成到visdom

import matplotlib.pyplot as plt
def generate_image(D, G, xr, epoch):
    N_POINTS = 128
    RANGE = 3
    plt.clf()

    points = np.zeros((N_POINTS, N_POINTS, 2), dtype='float32')
    points[:,:,0] = np.linspace(-RANGE, RANGE, N_POINTS)[:, None]
    points[:,:,1] = np.linspace(-RANGE, RANGE, N_POINTS)[None, :]
    points = points.reshape((-1,2))

    with torch.no_grad():
        points = torch.Tensor(points).cpu()
        disc_map = D(points).cpu().numpy()
    x = y = np.linspace(-RANGE,RANGE,N_POINTS)
    cs = plt.contour(x,y,disc_map.reshape((len(x), len(y))).transpose())
    plt.clabel(cs, inline=1,fontsize=10)

    with torch.no_grad():
        z = torch.randn(batchsz, 2).cpu()
        samples = G(z).cpu().numpy()
    plt.scatter(xr[:,0],xr[:,1],c='orange',marker='.')
    plt.scatter(samples[:,0], samples[:,1], c='green',marker='+')

    viz.matplot(plt, win='contour',opts=dict(title='p(x):%d'%epoch))

运行函数

def run():
    torch.manual_seed(23)
    np.random.seed(23)

    data_iter = data_generator()
    x = next(data_iter)
    # print(x.shape)

    # G = Generator().cuda()
    # D = Discriminator().cuda()
    # 无显卡环境
    device = torch.device("cpu")
    G = Generator().cpu()
    print(G)
    D = Discriminator().cpu()
    print(D)

    optim_G = optim.Adam(G.parameters(), lr = 5e-4, betas=(0.5,0.9))
    optim_D = optim.Adam(D.parameters(), lr = 5e-4, betas=(0.5,0.9))

    viz.line([[0,0]],[0],win='loss', opts=dict(title='loss',legend=['D','G']))

    """
    gan核心部分
    """
    for epoch in range(50000):
        # 训练判别网络
        for _ in range(5):
            # 真实数据训练
            xr = next(data_iter)
            xr = torch.from_numpy(xr).cpu()
            predr = D(xr)
            # 放大真实数据
            lossr = -predr.mean()

            # 虚假数据训练
            z = torch.randn(batchsz,2).cpu()
            xf = G(z).detach()
            predf = D(xf)
            # 缩小虚假数据
            lossf = predf.mean()

            loss_D = lossr + lossf

            # 梯度清零
            optim_D.zero_grad()
            # 向后传播
            loss_D.backward()
            optim_D.step()


        # 训练生成网络
        z = torch.randn(batchsz,2).cpu()
        xf = G(z)
        predf = D(xf)
        loss_G = -predf.mean()
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()

        if epoch % 100 == 0:
            viz.line([[loss_D.item(),loss_G.item()]], [epoch],win='loss', update='append')
            print(loss_D.item(), loss_G.item())
            generate_image(D, G, xr, epoch)

执行(GAN的不稳定性)

run()

从结果中可以看到,判别网络的loss一直为0,而生成网络一直得不到更新,生成的数据点远离我们创建的中心点

gan运行.png

四、wgan实验

WGAN主要从损失函数的角度对GAN做了改进,对更新后的权重强制截断到一定范围内

增加一个梯度惩罚函数

def gradient_penalty(D,xr,xf):
    # [b,1]
    t = torch.rand(batchsz, 1).cpu()
    # 扩展为[b, 2]
    t = t.expand_as(xr)
    # 插值
    mid = t * xr + (1 - t) * xf
    # 设置需要的倒数信息
    mid.requires_grad_()

    pred = D(mid)
    grads = autograd.grad(outputs=pred, 
                          inputs=mid,
                          grad_outputs=torch.ones_like(pred),
                          create_graph=True,
                          retain_graph=True,
                          only_inputs=True)[0]
    gp = torch.pow(grads.norm(2, dim=1) - 1, 2).mean()
    return gp

修改运行函数

def run():
    torch.manual_seed(23)
    np.random.seed(23)

    data_iter = data_generator()
    x = next(data_iter)
    # print(x.shape)

    # G = Generator().cuda()
    # D = Discriminator().cuda()
    # 无显卡环境
    device = torch.device("cpu")
    G = Generator().cpu()
    print(G)
    D = Discriminator().cpu()
    print(D)

    optim_G = optim.Adam(G.parameters(), lr = 5e-4, betas=(0.5,0.9))
    optim_D = optim.Adam(D.parameters(), lr = 5e-4, betas=(0.5,0.9))

    viz.line([[0,0]],[0],win='loss', opts=dict(title='loss',legend=['D','G']))

    """
    gan核心部分
    """
    for epoch in range(50000):
        # 训练判别网络
        for _ in range(5):
            # 真实数据训练
            xr = next(data_iter)
            xr = torch.from_numpy(xr).cpu()
            predr = D(xr)
            # 放大真实数据
            lossr = -predr.mean()

            # 虚假数据训练
            z = torch.randn(batchsz,2).cpu()
            xf = G(z).detach()
            predf = D(xf)
            # 缩小虚假数据
            lossf = predf.mean()

            # 梯度惩罚值
            gp = gradient_penalty(D,xr,xf.detach())
            loss_D = lossr + lossf + 0.2 * gp
            # 梯度清零
            optim_D.zero_grad()
            # 向后传播
            loss_D.backward()
            optim_D.step()


        # 训练生成网络
        z = torch.randn(batchsz,2).cpu()
        xf = G(z)
        predf = D(xf)
        loss_G = -predf.mean()
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()

        if epoch % 100 == 0:
            viz.line([[loss_D.item(),loss_G.item()]], [epoch],win='loss', update='append')
            print(loss_D.item(), loss_G.item())
            generate_image(D, G, xr, epoch)

执行

run()

可以看到在wgan中,生成网络开始学习,生成的数据点也能基本根据高斯分布落在中心点附近

wgan运行.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1646902.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

视频改字祝福 豪车装X系统源码uniapp前端小程序源码

视频改字祝福 豪车装X系统源码uniapp前端小程序源码,创意无限!AI视频改字祝福,豪车装X系统源码开源,打造个性化祝 福视频不再难! 想要为你的朋友或家人送上一份特别的祝福,让他们感受到你的真诚与关怀吗&am…

若依前后端分离部署nginx

1、v.sj 2、生产环境修改 3、退出登录修改 4、路由改为hash模式 5、nginx配置 location /gldhtml/ {alias D:/java/tool/nginx-1.19.6/project/jxal/html/; } location /jxal/ {proxy_pass http://localhost:8081/; }

Rust Course学习(编写测试)

如果友友你的计算机上没有安装Rust,可以直接安装:Rust 程序设计语言 (rust-lang.org)https://www.rust-lang.org/zh-CN/ Introduce 介绍 Testing in Rust involves writing code specifically designed to verify that other code works as expected. It…

leetcode-岛屿数量-99

题目要求 思路 1.使用广度优先遍历,将数组中所有为1的元素遍历一遍,遍历过程中使用递归,讲该元素的上下左右四个方向的元素值也置为0 2.统计一共执行过多少次,次数就是岛屿数量 代码实现 class Solution { public:int solve(vec…

AWS宣布推出Amazon Q :针对商业数据和软件开发的生成性AI助手

亚马逊网络服务(AWS)近日宣布推出了一项名为“Amazon Q”的新服务,旨在帮助企业利用生成性人工智能(AI)技术,优化工作流程和提升业务效率。这一创新平台的推出,标志着企业工作方式的又一次重大变…

AIGC-音频生产十大主流模型技术原理及优缺点

音频生成(Audio Generation)指的是利用机器学习和人工智能技术,从文本、语音或其他源自动生成音频的过程。 音频生成行业是AIGC技术主要渗透的领域之一。AI音频生成行业是指利用人工智能技术和算法来生成音频内容的领域。按照输入数据类型不同可以分为:根…

multipass launch失败:launch failed: Remote ““ is unknown or unreachable.

具体问题情况如下: C:\WINDOWS\system32>multipass launch --name my-vm 20.04launch failed: Remote "" is unknown or unreachable.​C:\WINDOWS\system32>multipass lsNo instances found.​C:\WINDOWS\system32>multipass startlaunch fail…

[信息收集]-端口扫描--Nmap

端口号 端口号的概念属于计算机网络的传输层,标识这些不同的应用程序和服务而存在的。通过使用不同的端口号,传输层可以将接收到的数据包准确地传递给目标应用程序。 80:HTTP(超文本传输协议)用于Web浏览器访问网页 …

【论文泛读】如何进行动力学重构? 神经网络自动编码器结合SINDy发现数据背后蕴含的方程

这一篇文章叫做 数据驱动的坐标发现与方程发现算法。 想回答的问题很简单,“如何根据数据写方程”。 想想牛顿的处境,如何根据各种不同物体下落的数据,写出万有引力的数学公式的。这篇文章就是来做这件事的。当然,这篇论文并没有…

一文带你了解多数企业系统都在用的 RBAC 权限管理策略

前言 哈喽你好呀,我是 嘟老板,今天我们来聊聊几乎所有企业系统都离不开的 权限管理,大家平时在做项目开发的时候,有没有留意过权限这块的设计呢?都是怎样实现的呢?如果现在脑子里对于这块儿不够清晰&#…

作为全栈工程师,如何知道package.json中需要的依赖分别需要什么版本去哪里查询?

作为前端工程师,当你需要确定package.json中依赖的具体版本时,可以通过以下方法来查询: NPM 官网查询: 访问 npm 官网,在搜索框中输入你想查询的包名。在包的页面上,你可以看到所有发布过的版本号&#xff…

[leetcode] 63. 不同路径 II

文章目录 题目描述解题方法动态规划java代码复杂度分析 相似题目 题目描述 一个机器人位于一个 m x n 网格的左上角 (起始点在下图中标记为 “Start” )。 机器人每次只能向下或者向右移动一步。机器人试图达到网格的右下角(在下图中标记为…

PHP ASCII码的字符串用mb_convert_encoding 转utf-8之后不生效

检测数据类型是ascii,转码之后再检测还是utf-8没生效 private function toUTF8($str){$encode mb_detect_encoding($str, array("ASCII",UTF-8,"GB2312","GBK",BIG5,LATIN1));if ($encode ! UTF-8) {$str1 mb_convert_encoding($str, UTF-8, …

原生轮播图(下一页切换,附带指示器)

下面是目录结构&#xff1a; index.html <!DOCTYPE html> <html lang"zh"><head><meta charset"UTF-8" /><meta http-equiv"X-UA-Compatible" content"IEedge" /><meta name"viewport" c…

迅雷永久破解

链接&#xff1a;https://pan.baidu.com/s/1ZGb1ljTPPG3NFsI8ghhWbA?pwdok7s 下载后解压 以管理员身份运行绿化.bat&#xff0c;会自动生成快捷方式&#xff0c;如果没有可以在program中运行Thunder.exe

车牌检测识别功能实现(pyqt)

在本专题前面相关博客中已经讲述了 pyqt + yolo + lprnet 实现的车牌检测识别功能。带qt界面的。 本博文将结合前面训练好的模型来实现车牌的检测与识别。并用pyqt实现界面。最终通过检测车牌检测识别功能。 1)、通过pyqt5设计界面 ui文件如下: <?xml version="1…

基于树的时间序列预测(LGBM)

在大多数时间序列预测中&#xff0c;尽管有Prophet和NeuralProphet等方便的工具&#xff0c;但是了解基于树的模型仍然具有很高的价值。尤其是在监督学习模型中&#xff0c;仅仅使用单变量时间序列似乎信息有限&#xff0c;预测也比较困难。因此&#xff0c;为了生成足够的特征…

Docker容器:Docker-Consul 的容器服务更新与发现

目录 前言 一、什么是服务注册与发现 二、 Docker-Consul 概述 1、Consul 概念 2、Consul 提供的一些关键特性 3、Consul 的优缺点 4、传统模式与自动发现注册模式的区别 4.1 传统模式 4.2 自动发现注册模式 5、Consul 核心组件 5.1 Consul-Template组件 5.2 Consu…

ICML 2024有何亮点?9473篇论文投稿,突破历史记录

会议之眼 快讯 2024年5月1日&#xff0c;第42届国际机器学习大会ICML 2024放榜啦&#xff01;录用率27.5%&#xff01;ICML 2024的录用结果受到了广泛的关注&#xff0c;本届会议的投稿量达到了9473篇&#xff0c;创下了历史新高&#xff0c;比去年的6538篇增加了近3000篇&…

C/C++开发环境配置

配置C/C开发环境 1.下载和配置MinGW-w64 编译器套件 下载地址&#xff1a;https://sourceforge.net/projects/mingw-w64/files/mingw-w64/mingw-w64-release/ 下载后解压并放至你容易管理的路径下&#xff08;我是将其放在了D盘的一个software的文件中管理&#xff09; 2.…