GAN原理 代码解读

news2024/11/17 12:39:34

模型架构

在这里插入图片描述

代码

数据准备

import os
import time
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
import torch.nn as nn
import torch

# 创建文件夹存放图片
os.makedirs("data", exist_ok=True)
transform = transforms.Compose([
    transforms.ToTensor(), #它会进行0-1归一化,h方向/h,w方向/w。 然后将图片格式转换为 (channel,h,w)
    transforms.Normalize(0.5,0.5),#把数据归一化为均值为0.5,方差为0.5,图像的数值范围变成-1到1
])
# 下载训练数据后对图片进行transform里的toTensor和用均值方差归一化
train_dataset = datasets.MNIST('data',
                               train=True,
                               transform=transform,
                               download=True)
dataloader = torch.utils.data.DataLoader(train_dataset,batch_size=64,shuffle=True)

定义生成器

'''
    输入:正态分布随机数噪声(长度为100)
    输出:生成的图片,(1,28,28)
    中间过程:
        linear1: 100 -> 256
        linear2: 256 -> 512
        linear3: 512 -> 28*28
        reshape: 28x28 -> (1,28,28)
'''
class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__() # super().__init__() 是调用父类的__init__函数
        self.model = nn.Sequential(nn.Linear(100,256),nn.ReLU(),
                                   nn.Linear(256,512),nn.ReLU(),
                                    # 最后一层用tanh激活,将数据压缩到-1到1
                                   nn.Linear(512,28*28),nn.Tanh())
    def forward(self,x):
        img = self.model(x)
        img = img.view(-1,28,28,1) # 得到的是28*28=784,把它reshape为 (批量,h,w,channel)
        return img

定义判别器

'''
    判别器
    输入:(1,28,28)的图片
    输出:二分类的概率值 用sigmoid压缩到0-1之间
    内容:
    判别器 推荐使用LeakyRelu,因为生成器难以训练,Relu的负值直接变成0没有梯度了
'''
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        self.model = nn.Sequential(
            nn.Linear(28*28,512),nn.LeakyReLU(),
            nn.Linear(512,256),nn.LeakyReLU(),
            nn.Linear(256,1),nn.Sigmoid(),
        )
    def forward(self,x):
        x = x.view(-1,28*28)
        x = self.model(x)
        return x

初始化模型,优化器及损失计算函数

device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device) # 初始化并放到了相应的设备上
dis = Discriminator().to(device)
dis_optim = torch.optim.Adam(dis.parameters(),lr=0.0001)
gen_optim = torch.optim.Adam(gen.parameters(),lr=0.0001)
bce_loss = torch.nn.BCELoss()

画生成器生成的图的绘图函数

def gen_img_plot(model,epoch,test_input):
    prediction = model(test_input).detach().cpu().numpy() # 放在内存上 并转换为Numpy
    prediction = np.squeeze(prediction) # np.squeeze是一个numpy函数,删除数组中形状为1的维度
    fig = plt.figure(figsize=(4,4))
    for i in range(16): # 迭代这n张图片
        plt.subplot(4,4,i+1)
        plt.imshow((prediction[i] + 1) / 2) # 生成器生成的图片是-1到1之间的,无法绘图。通过 (原+1)/2把[-1,1]压缩到[0,1]
        plt.axis('off')
    plt.show()

显示图片的函数

def img_plot(img):
    img = np.squeeze(img) # np.squeeze是一个numpy函数,删除数组中形状为1的维度
    fig = plt.figure(figsize=(4,4))
    for i in range(16): # 迭代这n张图片
        plt.subplot(4,4,i+1)
        plt.imshow((img[i] + 1) / 2) # 生成器生成的图片是-1到1之间的,无法绘图。通过 (原+1)/2把[-1,1]压缩到[0,1]
        plt.axis('off')
    plt.show()

定义训练函数


def train(num_epoch,test_input):
    D_loss = []
    G_loss = []
    # 训练循环
    for epoch in range(num_epoch):
        d_epoch_loss = 0
        g_epoch_loss = 0
        count = len(dataloader) # 返回批次数
        for step,(img,_) in enumerate(dataloader): # _是标签数据,img是(批次,h,w),每次取的img形状为(64,1,28,28)
            # print(f'step={step},img.shape={img.shape}')
            # img_plot(img)
            img = img.to(device)
            size = img.size(0) # 得到一个批次的图片
            random_noise = torch.randn(size,100,device=device) # 生成器的输入

            '''一. 训练判别器'''
            '''用真实图片训练判别器'''
            dis_optim.zero_grad()
            real_output = dis(img) # 对判别取输入真实的图片,输出对真实图片的预测结果
            # 判别器在真实图像上的损失
            d_real_loss = bce_loss(real_output,
                                   # torch.ones_like(real_output) 创建一个根real_loss一样形状的全1数组,作为标签。
                                   torch.ones_like(real_output))
            d_real_loss.backward()

            '''用生成的图片训练判别器'''
            gen_img = gen(random_noise)
            # 因为此时是为了训练判别器,所以不能让生成器的梯度参与进来。所以用detach()取出无梯度的tensor
            fake_output = dis(gen_img.detach())
            d_fake_loss = bce_loss(fake_output,
                                   torch.zeros_like(fake_output))
            d_fake_loss.backward()
            d_loss = d_real_loss+d_fake_loss
            dis_optim.step() # 对参数进行优化

            '''二.训练生成器'''
            gen_optim.zero_grad()
            # 刚才是去掉生成器生成的图片的梯度,来训练判别器。此处不需要去掉梯度。让判别器进行判别
            fake_output = dis(gen_img)
            # 思想:目的是生成越来越逼真的图片瞒过判别器,让判别器判定生成的图片是真实的图片。
            # 实现方法:把判别器的结果输入到bce_loss,用1作为标签,看判别器把生成的图片判别为真的损失。
            g_loss = bce_loss(fake_output,
                              torch.ones_like(fake_output))
            g_loss.backward()
            gen_optim.step()

            # 计算一个epoch的损失
            with torch.no_grad(): #  禁止梯度计算和参数更新
                d_epoch_loss +=d_loss
                g_epoch_loss +=g_loss
        # 计算整体loss每个epoch的平均Loss
        with torch.no_grad(): #  禁止梯度计算和参数更新
            d_epoch_loss /= count
            g_epoch_loss /= count
            D_loss.append(d_epoch_loss)
            G_loss.append(g_epoch_loss)
            print('Epoch:', epoch+1)
            print(f'd_epoch_loss={d_epoch_loss}')
            print(f'g_epoch_loss={g_epoch_loss}')
            # 将16个长度为100的噪音输入到生成器并画图
            gen_img_plot(gen,test_input)

开始训练

'''开始计时'''
start_time = time.time()

'''开始训练'''
test_input = torch.randn(16,100,device=device) # 生成16个 长度为100的正太分布随机数。放到GPU中 作为输入
print(test_input)
num_epoch = 50
train(num_epoch,test_input)
# 保存训练50次的参数
torch.save(gen.state_dict(),'gen_weights.pth')
torch.save(dis.state_dict(),'dis_weights.pth')

'''计时结束'''
end_time = time.time()
run_time = end_time - start_time
# 将输出的秒数保留两位小数
if int(run_time)<60:
    print(f'{round(run_time,2)}s')
else:
    print(f'{round(run_time/60,2)}minutes')

结果可视化

在这里插入图片描述

加载训练好的参数

gen.load_state_dict(torch.load('/opt/software/computer_vision/codes/My_codes/paper_codes/GAN/weights/gen_weights.pth'))

用训练好的生成器生成图片并画图

test_new_input = torch.randn(16,100,device=device) # 生成16个 长度为100的正太分布随机数。放到GPU中 作为输入
gen_img_plot(gen,test_new_input)

在这里插入图片描述
GAN的生成是随机的,不同的噪声,生成不同的数字

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/928922.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

docker compose iceberg 快速体验

https://iceberg.apache.org/spark-quickstart/#docker-compose port&#xff1a;8888

回归预测 | MATLAB实现GA-APSO-IBP改进遗传-粒子群算法优化双层BP神经网络多输入单输出回归预测

回归预测 | MATLAB实现GA-APSO-IBP改进遗传-粒子群算法优化双层BP神经网络多输入单输出回归预测 目录 回归预测 | MATLAB实现GA-APSO-IBP改进遗传-粒子群算法优化双层BP神经网络多输入单输出回归预测效果一览基本介绍模型描述程序设计参考资料 效果一览 基本介绍 MATLAB实现GA-…

DPU在东数西算背景下如何赋能下一代算力基础设施 中科驭数在未来网络发展大会论道

以ChatGPT为代表的人工智能大模型的快速发展&#xff0c;对网络信息技术创新发展提出了新的挑战&#xff0c;我国东数西算重大工程也在加速布局。以确定性网络、算力网络为代表的未来网络核心技术&#xff0c;正成为决定未来经济和产业发展的关键。 8月23日&#xff0c;第七届…

学生分班查询系统的创建与使用指南

开学季&#xff0c;负责分班工作的老师们又面临一个难题&#xff1a;如何公布分班结果&#xff1f;将结果放在学校官网上可能会让很多无关人员看到&#xff0c;而不放则会导致家长们纷纷打电话来询问。那么&#xff0c;有没有一种方法可以让家长们自行查看分班结果呢&#xff1…

【SLAM】光流 - LK光流 - 金字塔分层LK光流

在SLAM的视觉里程计中&#xff0c;比较常用的就是特征点法和直接法。而直接法中&#xff0c;光流则是其中的重点内容&#xff0c;比如LSD-SLAM中就使用到了光流的方法。本文将会就光流的理论原理、公式推导进行详细的剖析&#xff0c;以帮助读者深刻地理解。 光流算法 光流是关…

每日一练 | 华为认证真题练习Day103

1、网络设备发送的IPv6报文时&#xff0c;会首先将报文长度和NTU值进行对比&#xff0c;如果大于MTU值&#xff0c;则直接丢弃。 A. 对 B. 错 2、路由器接口输出信息如下&#xff0c;则此接口可以接收哪些组播地址的数据&#xff1f; &#xff08;多选&#xff09; A. FF02::…

中国储能行业研究报告,光伏和风电领域装机量迅速增长

随着科学技术的进步&#xff0c;储能工业对我们的生活产生了深远的影响。电池技术的突破使得手机使用寿命更长&#xff0c;家庭储能系统使得能源管理更加智能和高效。人们通过对于储能的需求进行不断发展增长&#xff0c;将目光投向更环保可持续的解决问题方案。这个行业的发展…

计算机丢失msvcp140.dll是什么意思,要怎么处理呢?

今天&#xff0c;我将和大家探讨一个关于计算机的问题——“计算机丢失msvcp140.dll是什么意思&#xff0c;要怎么处理呢&#xff1f;”这个问题可能会在很多使用计算机的朋友中遇到。希望通过今天的演讲&#xff0c;能够帮助大家解决这个困扰。 首先&#xff0c;我们来了解一…

DevOps中的持续测试优势和工具

持续测试 DevOps中的持续测试是一种软件测试类型&#xff0c;它涉及在软件开发生命周期的每个阶段测试软件。持续测试的目标是通过早期测试和经常测试来评估持续交付过程的每一步的软件质量。 DevOps中的持续测试流程涉及开发人员、DevOps、QA和操作系统等利益相关者。 持续…

CC++ 常用技巧

C 中的C C 是面向过程的是把整个大程序分为一个个的子函数&#xff1b;C 是面向对象的是把整个程序划分为一个个的类。C 是完全兼容C 的&#xff0c;C 是C 的子集&#xff0c;C 是C 的超集。C 又对C 做了很多补充和提升&#xff0c;因此使用C 会比使用纯C 更方便。混用C和C&am…

《软件开发的201个原则》阅读笔记 120-161条

目录 使用有效的测试完成度标准 原则122 达成有效的测试覆盖 原则123 不要在单元测试之前集成 原则 124 测量你的软件 原则125 分析错误的原因 对错不对人 原则127 好的管理比好的技术更重要 使用恰当的方法 原则 129 不要相信你读到的一切 原则130 理解客户的优先级 原…

千人千面的分析?SpeedBI数据可视化工具也很擅长

SpeedBI数据可视化工具可以实现千人千面的分析&#xff0c;通过个性化的数据展示和交互式分析功能&#xff0c;让每个人都可以根据自己的需求和业务背景进行数据分析和可视化。 SpeedBI数据可视化工具支持多维自助分析&#xff0c;可以帮助用户深入探索和分析数据。以下是Spee…

超店有数最新报告!美国TikTok小店全新洗牌?搏一把的机会到了

据传&#xff0c;TikTok美国市场的半闭环模式将于8月底关闭&#xff0c;其将在美国全力发展全闭环。也就是说&#xff0c;想要继续在TikTok美区卖货&#xff0c;必须开通TikTok小店&#xff0c;官方不给放外链了。 如果消息属实&#xff0c;全闭环模式开启&#xff0c;美国Tik…

抖音电商,从消费者体验中做增量

夜晚总是最容易emo&#xff0c;也最容易冲动的时候。 王雪临睡前刷着抖音&#xff0c;看到一家化妆品品牌在直播&#xff0c;刚好最近她想买抗老精华&#xff0c;点进去听主播小姐姐介绍一番后下了单。第二天早上起来犹豫要不要退货&#xff0c;再货比三家时&#xff0c;手机收…

stm32之DHT11

今天&#xff0c;记录一下DHT11&#xff0c;涉及到了单总线协议&#xff0c;所以先花点时间谈论一下单总线协议&#xff08;DS18B20也是用的单总线&#xff09;。 单总线协议 单总线技术的通信协议 可能这时序图就是个例子&#xff0c;ds18b20的时序图与DHT11的时序图也是不一…

服务器中了mkp勒索病毒该怎么办?勒索病毒解密,数据恢复

mkp勒索病毒算的上是一种比较常见的勒索病毒类型了。它的感染数量上也常年排在前几名的位置。所以接下来就由云天数据恢复中心的技术工程师来对mkp勒索病毒做一个分析&#xff0c;以及中招以后应该怎么办。 一&#xff0c;中了mkp勒索病毒的表现 桌面以及多个文件夹当中都有一封…

mysql基础——认识索引

一、介绍 “索引”是为了能够更快地查询数据。比如一本书的目录&#xff0c;就是这本书的内容的索引&#xff0c;读者可以通过在目录中快速查找自己想要的内容&#xff0c;然后根据页码去找到具体的章节。 二、优缺点 优势&#xff1a;以快速检索&#xff0c;减少I/O次数&am…

TMP: 利用std::tuple完成运行期的if...else替换

code client code 参考链接&#xff1a; std::tuple std::tuple_size std::tuple_element

接口测试-快问快答你能做对几道【含答案】

1、做接口测试当请求参数多时tps下降明显&#xff0c;此接口根据参数从redis中获取数据&#xff0c;每个参数与redis交互一次&#xff0c;当一组参数是tps5133&#xff0c;五组参数是tps1169&#xff0c;多次交互影响了处理性能&#xff0c;请详细阐述如何改进增进效果的方案。…

AD(第二部分---绘制原理图库及编译检查)

设计电路-----器件选型----绘制原理图----->先有"BOOM"&#xff0c;后更改AD封装 10.元件的放置&#xff1a; 当有多个元件库&#xff0c;选择某一个时&#xff0c;需要点击右下角"Panels"&#xff0c;之后点击Components。如下图&#xff1a; 之后双击…