Pytorch Advanced(一) Generative Adversarial Networks

news2024/10/5 22:22:04

生成对抗神经网络GAN,发挥神经网络的想象力,可以说是十分厉害了

参考

1、AI作家
2、将模糊图变清晰(去雨,去雾,去抖动,去马赛克等),这需要AI具有“想象力”,能脑补情节;
3、进行数据增强,根据已有数据生成更多新数据供以feed,可以减缓模型过拟合现象。

那到底是怎么实现的呢?


GAN中有两大组成部分G和D

G是generator,生成器: 负责凭空捏造数据出来

D是discriminator,判别器: 负责判断数据是不是真数据

示例图如下:

给一个随机噪声z,通过G生成一张假图,然后用D去分辨是真图还是假图。假设G生成了一张图,在D那里的得分很高,那么G就很成功的骗过了D,如果D很轻松的分辨出了假图,那么G的效果不好,那么就需要调整参数了。


G和D是两个单独的网络,那么他们的参数都是训练好的吗?并不是,两个网络的参数是需要在博弈的过程中分别优化的。

下面就是一个训练的过程:

GAN在一轮反向传播中分为两步,先训练D在训练G。

训练D时,上一轮G产生的图片,和真实图片一起作为x进行输入,假图为0,真图标签为1,通过x生成一个score,通过score和标签y计算损失,就可以进行反向传播了。

训练G时,G和D是一个整体,取名为D_on_G。输入随机噪声,G产生一个假图,D去分辨,score = 1就是需要我们需要优化的目标,意思就是我们要让生成的图片变成真的。这里的D是不需要参与梯度计算的,我们通过反向传播来优化G,让他生成更加真实的图片。这就好比:如果你参加考试,你别指望能改变老师的评分标准


GAN无监督学习,(cGAN是有监督的),以后会学习的。怎么理解无监督学习呢?这里给的真图是没有经过人工标注的,只知道这是真的,D是不知道这是什么的,只需要分辨真假。G也不知道生成了什么,只需要学真图去骗D。


具体如何实施呢?

import os
import torch
import torchvision
import torch.nn as nn 
from torchvision import transforms
from torchvision.utils import save_image

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = 'samples'

注意这里有个归一化的过程,MNIST是单通道,但是如果mean=(0.5,0.5,0.5)会报错,因为是对3通道操作 。

if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)

transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5,),   # 3 for RGB channels
                                     std=(0.5,))])

# MNIST dataset
mnist = torchvision.datasets.MNIST(root='./data/',train=True,transform=transform,download=True)
# Data loader
data_loader = torch.utils.data.DataLoader(dataset=mnist,batch_size=batch_size, shuffle=True)

定义生成器和判别器:

生成器:可以看到输入的维度为64,是一组噪声图像,通过生成器将特征扩大到了MNIST图像大小784。

判别器:输入维度为图像大小,最后输出特征个数为1,采用sigmoid激活(不用softmax的)

# Discriminator
D = nn.Sequential(
    nn.Linear(image_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1),
    nn.Sigmoid())


# Generator 
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh())
# Device setting
D = D.to(device)
G = G.to(device)

# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)


def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1)

def reset_grad():
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()

 重点看训练部分,我们到底是如何来训练GAN的。

判别器部分:判别器的损失值分为两部分,(一)将mini_batch定义为正样本,告诉他我是正品,所以设置标签为1。优化判别器判断正品的能力;(二)生成一幅赝品,再给判别器判别,这时候赝品的标签为0,优化判断赝品的能力。所以总损失为这两部分之和,计算梯度,优化判别器参数。

G_on_D:输入一个噪声,让生成器生成一幅图像,然后让D去判别,计算和正品之间的距离,即损失。反向传播,优化G的参数。

# Start training
total_step = len(data_loader)
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(data_loader):
        images = images.reshape(batch_size, -1).to(device)
        
        # Create the labels which are later used as input for the BCE loss
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # ================================================================== #
        #                      Train the discriminator                       #
        # ================================================================== #

        # Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))
        # Second term of the loss is always zero since real_labels == 1
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs
        
        # Compute BCELoss using fake images
        # First term of the loss is always zero since fake_labels == 0
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs
        
        # Backprop and optimize
        d_loss = d_loss_real + d_loss_fake
        reset_grad()
        d_loss.backward()
        d_optimizer.step()
        
        # ================================================================== #
        #                        Train the generator                         #
        # ================================================================== #

        # Compute loss with fake images
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        
        # We train G to maximize log(D(G(z)) instead of minimizing log(1-D(G(z)))
        # For the reason, see the last paragraph of section 3. https://arxiv.org/pdf/1406.2661.pdf
        g_loss = criterion(outputs, real_labels)
        
        # Backprop and optimize
        reset_grad()
        g_loss.backward()
        g_optimizer.step()
        
        if (i+1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' 
                  .format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(), 
                          real_score.mean().item(), fake_score.mean().item()))
    
    # Save real images
    if (epoch+1) == 1:
        images = images.reshape(images.size(0), 1, 28, 28)
        save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))
    
    # Save sampled images
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))

训练完了怎么用?

只要用我们的生成器就可以随意生成了。

import matplotlib.pyplot as plt
z = torch.randn(1,latent_size).to(device)
output = G(z)
plt.imshow(output.cpu().data.numpy().reshape(28,28),cmap='gray') 
plt.show()

 下面就是随机生成的图像了!

  

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1002305.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JavaScript Promise 的真正工作原理

Promise 是处理异步代码的一种技术,也称为脱离回调地狱的头等舱门票。 3 承诺状态 待定状态 已解决状态 拒绝状态 理解 JavaScript Promis 什么是承诺? 通常,承诺被定义为最终可用的值的代理。 Promise 多年来一直是 JavaScript 的一部分(在 ES2015 中标准化并引入)。最…

【数据结构】前言概况 - 树

🚩纸上得来终觉浅, 绝知此事要躬行。 🌟主页:June-Frost 🚀专栏:数据结构 🔥该文章针对树形结构作出前言,以保证可以对树初步认知。 目录: 🌍前言:&#x1f3…

Python语义分割与街景识别(4):程序运行

前言 本文主要用于记录我在使用python做图像识别语义分割训练集的过程,由于在这一过程中踩坑排除BUG过多,因此也希望想做这部分内容的同学们可以少走些弯路。 本文是python语义分割与街景识别第四篇,关于程序的内容,也是差不多最…

【Unity编辑器扩展】| GameView面板扩展

前言【Unity编辑器扩展】| GameView面板扩展未运行时在Game视图进行绘制总结前言 前面我们介绍了Unity中编辑器扩展的一些基本概念及基础知识,还有编辑器扩展中用到的相关特性Attribute介绍。后面就来针对Uniity编辑器扩展中比较常用的模块进行学习介绍。本文就来详细介绍一下…

JAVA版的数据结构——链表

目录 1.单向不带头链表 1.1 链表的概念及结构 1.2 代码部分 1.3 完整的全部代码 2. 双向不带头链表 2.1 代码部分 2.2 完整的代码 3. MySingleList与MyLinkedList代码上的区别 4. LinkedList的使用 4.1 什么是LinkedList 4.2 LinkedList的使用 4.2.1 LinkedList的构…

【数据结构】堆的向上调整和向下调整以及相关方法

💐 🌸 🌷 🍀 🌹 🌻 🌺 🍁 🍃 🍂 🌿 🍄🍝 🍛 🍤 📃 文章目录 一、堆的概念二、堆的性质…

github上创建分支并合并到master

github上创建分支并合并到master 目录概述需求: 设计思路实现思路分析1.创建分支2.commit changes3.create pull request按钮4.网页解析器5.数据处理器 参考资料和推荐阅读 Survive by day and develop by night. talk for import biz , show your perfect code,ful…

[deeplearning]深度学习框架torch的概念以及数学内容

(提前声明:这边的操作系统为ubuntn22.04,至于window上如何进行安装和导入按这边不是很理解) (另外代码样例基本不使用notebook,paddle等等在线工具,而是使用本机安装好的python环境,和pytorch框…

IDEA中maven的设置以及相关功能

Maven 项目介绍 学习前提 相对于传统的项目,Maven 下管理和构建的项目真的非常好用和简单,所以这里也强调下,尽量使用此类工具进行项目构建。 ## Maven 常用设置介绍 如上图标注 1 所示,我们可以指定我们本地 Maven 的安装目录…

模块化开发_groupby查询think PHP5.1

要求按照分类的区别打印出不同类别的数据计数 如张三,做了6件事情 这里使用原生查询先测试 SELECT cate_id, COUNT(*) AS order_count FROM tp_article GROUP BY cate_id;成功 然后项目中实现 public function ss(){$sql "SELECT cate_id, COUNT(*) AS orde…

RCNA 锐捷培训

第一章 网络基础入门 1.1 OSI参考模型及TCP/IP协议栈 数据是如何传输的? 数据在计算机网络中传输通常依赖于TCP/IP协议模型。 什么是网络? 网络是一种连接多个计算机、设备或系统的通信基础设施,其目的是实现资源共享、信息传递、接收和共享…

14.Xaml ProgressBar控件 进度条控件

1.运行效果 2.运行源码 a.Xaml源码 <Grid Name="Grid1"><!--Orientation="Horizontal" 进度条的方向 水平的还是垂直的Value="40" 进度的数值Minimum="0" 最小值Maximum

17. 线性代数 - 矩阵的逆

文章目录 矩阵的转置矩阵的逆Hi, 您好。我是茶桁。 我们已经学习过很多关于矩阵的知识点,今天依然还是矩阵的相关知识。我们来学一个相关操作「矩阵的转置」,更重要的是我们需要认识「矩阵的逆」 矩阵的转置 关于矩阵的转置,咱们导论课里有提到过。转置实际上还是蛮简单…

淘宝京东扣库存怎么实现的

1. 使用kv存储实时的库存&#xff0c;直接在kv里扣减&#xff0c;避免用分布式锁 2. 不要先查再扣&#xff0c;直接扣扣扣&#xff0c;扣到负数&#xff0c;&#xff08;增改就直接在kv里做&#xff09;&#xff0c;就说明超卖了&#xff0c;回滚刚才的扣减 3. 同时写MQ&…

小白也可以玩转CMake之常用必备

目录 1.设置编译器flags2.设置源文件属性3.链接器标志4.Debug与Release包 今天&#xff0c;分享一篇工作中经常用到的一些CMake命令&#xff0c;看完就学会了哦&#xff0c;更多CMake与C内容也期待加入星球与我一起学习呀~ 1.设置编译器flags 例如&#xff1a;设置C标准&#x…

论文笔记《3D Gaussian Splatting for Real-Time Radiance Field Rendering》

项目地址 原论文 Abstract 最近辐射场方法彻底改变了多图/视频场景捕获的新视角合成。然而取得高视觉质量仍需神经网络花费大量时间训练和渲染&#xff0c;同时最近较快的方法都无可避免地以质量为代价。对于无边界的完整场景&#xff08;而不是孤立的对象&#xff09;和 10…

C高级day4循环语句

1&#xff0c;思维导图 运行结果为&#xff1a; 运行结果为&#xff1a;

【基础计算机网络1】认识计算机网络体系结构,了解计算机网络的大致模型(下)

前言 在上一篇我们主要介绍了有关计算机网络概述的内容&#xff0c;下面这一篇我们将来介绍有关计算机网络体系结构与参考模型的内容。这一篇博客紧紧联系上一篇博客。 这一篇博客主要内容是&#xff1a;计算机网络体系结构与参考模型&#xff0c;主要是计算机网络分层结构、协…

search_engine:搜索引擎实现

目录 一.项目背景及原理 1.背景 2.原理 二.技术栈及项目环境 1.技术栈 2.项目环境 3.环境准备 三.模块划分 四. 遇到的问题及其解决方法 1.搜索结果出现重复文档的问题 2.实现httplib功能的问题 五. 项目特点 1.文档记录 2.竞价排名 3.去掉暂停词 4.模拟实现http…

云优先已死——云智能正在发生

混合云&#xff0c;即一些本地云和一些异地云&#xff0c;已经成为 IT 的默认架构&#xff0c;并且已经存在了一段时间了。然而&#xff0c;到目前为止&#xff0c;混合动力一直被视为通向完全公有云的过程中的过渡状态&#xff0c;许多人可能会居高临下地称之为“云成熟度”。…