[PyTorch][chapter 54][Variational Auto-Encoder 实战]

news2025/1/10 16:41:05

前言:

   
 

这里主要实现: Variational Autoencoders (VAEs) 变分自动编码器
其训练效果如下

 

训练的过程中要注意调节forward 中的kle ,调参。

整个工程两个文件:

    vae.py

   main.py

目录:

  1.      vae
  2.       main

一  vae

  文件名: vae.py

   作用:   Variational Autoencoders (VAE)

 训练的过程中加入一些限制,使它的latent space规则一点呢。于是就引入了variational autoencoder(VAE),它被定义为一个有规律地训练以避免过度拟合的Autoencoder,可以确保潜在空间具有良好的属性从而实现内容的生成。
variational autoencoder的架构和Autoencoder差不多,区别在于不再是把输入当作一个点,而是把输入当成一个分布。

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 14:19:19 2023

@author: chengxf2
"""

import torch
from torch import nn

#ae: AutoEncoder

class VAE(nn.Module):
    
    def __init__(self,hidden_size=20):
        
        super(VAE, self).__init__()
        
        self.encoder = nn.Sequential(
            nn.Linear(in_features=784, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=hidden_size),
            nn.ReLU()
            )
         # hidden [batch_size, 10]
         
        h_dim = int(hidden_size/2)
        self.hDim = h_dim

        self.decoder = nn.Sequential(
             nn.Linear(in_features=h_dim, out_features=64),
             nn.ReLU(),
             nn.Linear(in_features=64, out_features=128),
             nn.ReLU(),
             nn.Linear(in_features=128, out_features=256),
             nn.ReLU(),
             nn.Linear(in_features=256, out_features=784),
             nn.Sigmoid()
             )
        
        
    def forward(self, x):
            '''
            param x:[batch, 1,28,28]
            return 
        
            '''
      
            batchSz= x.size(0)
            #flatten
            x = x.view(batchSz, 784)
            
            #encoder
            h= self.encoder(x)
     
            #在给定维度上对所给张量进行分块,前一半的神经元看作u, 后一般的神经元看作sigma
            u, sigma = h.chunk(2,dim=1)
            
            #Reparameterize trick:
            #randn_like:产生一个正太分布 ~ N(0,1)
            #h.shape [batchSize,self.hDim]
            h = u+sigma* torch.randn_like(sigma)
           
            #kld :1e-8 防止sigma 平方为0
            kld = 0.5*torch.sum(
                torch.pow(u,2)+
                torch.pow(sigma,2)-
                torch.log(1e-8+torch.pow(sigma,2))-
                1
                )
            
            #MSE loss 是平均loss, 所以kld 也要算一个平均值
            kld = kld/(batchSz*32*32)
            xHat =   self.decoder(h)
            
            #reshape
            xHat = xHat.view(batchSz,1,28,28)
            
            return xHat,kld
        
    


二 main

文件名: main.py

作用: 训练,测试数据集

 

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 14:24:10 2023

@author: chengxf2
"""

import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import time
from torch import optim,nn
from vae import VAE
import visdom





def main():
   
   batchNum = 32
   lr = 1e-3
   epochs = 20
   device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
   torch.manual_seed(1234)
   viz = visdom.Visdom()
   viz.line([0],[-1],win='train_loss',opts =dict(title='train acc'))

    
   

   tf= transforms.Compose([ transforms.ToTensor()])
   mnist_train = datasets.MNIST('mnist',True,transform= tf,download=True)
   train_data = DataLoader(mnist_train, batch_size=batchNum, shuffle=True)
   
   mnist_test = datasets.MNIST('mnist',False,transform= tf,download=True)
   test_data = DataLoader(mnist_test, batch_size=batchNum, shuffle=True)
   global_step =0

   
   

   
  
   model =VAE().to(device)
   criteon = nn.MSELoss().to(device) #损失函数
   optimizer = optim.Adam(model.parameters(),lr=lr) #梯度更新规则
   
   print("\n ----main-----")
   for epoch in range(epochs):
       
       start = time.perf_counter()
       for step ,(x,y) in enumerate(train_data):
           #[b,1,28,28]
           x = x.to(device)
           x_hat,kld = model(x)
           
           loss = criteon(x_hat, x)
           
           if kld is not None:
              
               
               elbo = -loss -1.0*kld
               loss = -elbo
           #backprop
           optimizer.zero_grad()
           loss.backward()
           optimizer.step()
           viz.line(Y=[loss.item()],X=[global_step],win='train_loss',update='append')
           global_step +=1



    
       end = time.perf_counter()    
       interval = int(end - start)
  
       print("epoch: %d"%epoch, "\t 训练时间 %d"%interval, '\t 总loss: %4.7f'%loss.item(),"\t KL divergence: %4.7f"%kld.item())
       
       x,target = iter(test_data).next()
       x = x.to(device)
       with torch.no_grad():
           x_hat,kld = model(x)
       
       tip = 'hat'+str(epoch)
       viz.images(x,nrow=8, win='x',opts=dict(title='x'))
       viz.images(x_hat,nrow=8, win='x_hat',opts=dict(title=tip))
           
           
           
           
   

if __name__ == '__main__':
    
    main()

 参考:

 课时118 变分Auto-Encoder实战-2_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/956649.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

明年起,企业数据将作为资产被纳入会计报表

数据,是数字化经济时代的生产要素,是企业重要的资产,是企业发展经营的重要依据。为了规范企业数据资源相关会计处理,强化相关会计信息披露,近日财政部制定印发了《企业数据资源相关会计处理暂行规定》(以下…

RealSense D455启动教程

环境: ubuntu20.04 ros:noetic 视觉传感器:Intel RealSense D455 通过命令安装不成功后改为下面源码安装 1. 安装Intel RealSense SDK 2.0 1.1源码安装 1. 下载源码git clone https://github.com/IntelRealSense/librealsense cd librealsense…

OpenCV c++ 使用imshow显示灰色窗口

OpenCV使用imshow显示灰色窗口 原因是使用了system(‘pause’);函数,只需要将该函数去掉,使用opencv中的对应函数 waitKey(0) 即可实现同样效果。 system(“pause”); 改为: cv::waitKey(0); 显示效果:

【初识Git工具】Git工具的基本介绍

【初识Git工具】Git工具的基本介绍 一、什么是Git?1.1 Git简介1.2 Git和SVN区别1.3 常用的Git工具二、Git的起源三、Git的优点四、Git的架构五、Git的基本概念5.1 仓库(Repository)5.2 版本(Commit)5.3 分支(Branch)5.4 合并(Merge)5.5 标签(Tag)六、Git的基本使用命…

引用(个人学习笔记黑马学习)

1、引用的基本语法 #include <iostream> using namespace std;int main() {int a 10;//创建引用int& b a;cout << "a " << a << endl;cout << "b " << b << endl;b 100;cout << "a "…

大模型综述论文笔记6-15

这里写自定义目录标题 KeywordsBackgroud for LLMsTechnical Evolution of GPT-series ModelsResearch of OpenAI on LLMs can be roughly divided into the following stagesEarly ExplorationsCapacity LeapCapacity EnhancementThe Milestones of Language Models Resources…

Scala的特质trait与java的interface接口的区别,以及Scala特质的自身类型和依赖注入

1. Scala的特质trait与java接口的区别 Scala中的特质&#xff08;trait&#xff09;和Java中的接口&#xff08;interface&#xff09;在概念和使用上有一些区别&#xff1a; 默认实现&#xff1a;在Java中&#xff0c;接口只能定义方法的签名&#xff0c;而没有默认实现。而在…

Android基础之Activity生命周期(一)

Activity是Android四大组件之一、称为之首也恰如其分。 Activity直接翻译为中文叫活动。在Android系统中Activity就是我看到的一个完整的界面。 界面中看到的TextView(文字)、Button(按钮)、ImageView(图片)都是需要Activity来承载的。 总结一句话,Activity负责界面的呈…

面试官如何考察与CAP相关的理论?

在互联网技术面试中&#xff0c;考察分布式技术已经是面试的标配了。很多招聘信息中&#xff0c;你能发现&#xff0c;一线互联网公司在对候选人的要求中都有“分布式系统设计”这一关键词。无论你是程序员&#xff0c;还是架构师&#xff0c;都要掌握分布式系统设计。 案例背…

深度探索JavaScript中的原型链机制

&#x1f3c6;作者简介&#xff0c;黑夜开发者&#xff0c;全栈领域新星创作者✌&#xff0c;CSDN博客专家&#xff0c;阿里云社区专家博主&#xff0c;2023年6月csdn上海赛道top4。 &#x1f3c6;数年电商行业从业经验&#xff0c;历任核心研发工程师&#xff0c;项目技术负责…

手游排行前十名,手游排行榜2023前十名

今天为大家带来手游排行前十名&#xff0c;如今流行的手机游戏专注于在画面和游戏性方面为玩家提供更逼真、更流畅的游戏体验。在画面方面&#xff0c;手游开发商经常使用先进的游戏引擎和技术来提高游戏的图形质量和细节&#xff0c;以及增加游戏的动态照明和物理效果&#xf…

day-37 代码随想录算法训练营(19)贪心part06

738.单调递增的数字 思路&#xff1a;在给的数字中找到第一个开始递减的两个数字 &#xff1b; 将前一个数字减1 &#xff1b; 后面的数字全部变为最大值9 968.监控二叉树 思路&#xff1a;分三种状态&#xff1a;0无覆盖 1有监控 2有覆盖 分四种情况&#xff1a;1.两…

一、C#—概述环境安装(1)

&#x1f33b;&#x1f33b; 目录 一、 C#概述1.1 为啥学习C#1.2 TIBOE编程语言排行榜1.3 IEEE编程语言排行榜1.4 什么是C#1.5 C#创始人1.6 C#发展历史1.7 C#特点1.8 C#与Java1.9 .NET Framework1.10 C# 与 .NET Framework1.11 C#得应用领域1.12 C#能做什么 二、开发环境得安装…

go学习part21 Redis和Go(2)

1.三方库安装 309_尚硅谷_Go连接到Redis_哔哩哔哩_bilibili 借鉴&#xff1a; Golang 安装 Redis_go fiber 安装redis_柒柒伍贰玖。的博客-CSDN博客 三方redis库已经迁移到以下网址&#xff0c;go get github.com/gomodule/redigo/redis gomodule/redigo: Go client for Red…

【前端学习记录】neffos插件与控制台交互

背景 最近项目上有个需求需要用到websocket&#xff0c;于是就学了一下关于websocket的使用方法。不过由于后台使用的框架限制&#xff0c;需要前后端一起使用neffos插件&#xff0c;中间踩了很多的坑&#xff0c;这里简单记录一下。 websocket WebSocket 是一种在客户端和服…

计算机视觉-YOYO-

目录 计算机视觉-YOYO 目标检测发展历程 区域卷积神经网络(R-CNN) Fast R-CNN Mask R-CNN模型 比如SSD、YOLO(1, 2, 3)、R-FCN 目标检测基础概念 边界框、锚框和交并比 边界框&#xff08;bounding box&#xff09; 锚框&#xff08;Anchor box&#xff09; 交并比 …

win10 ping不通 Docker ip(解决截图)

背景&#xff1a; win10下载了docker desktop就是这个图&#xff0c;然后计划做一个springboot连接docker。 docker部署springboot :docker 部署springboot(成功、截图)_總鑽風的博客-CSDN博客 问题&#xff1a;spring boot部署docker后&#xff0c;docker接口通了&#xff0…

并发编程的故事——Java线程

Java线程 文章目录 Java线程一、线程创建二、线程运行三、线程运行四、主线程和守护线程五、线程的五种状态六、线程的六种状态七、烧水泡茶案例 一、线程创建 创建线程方法一&#xff1a; Thread重写run方法 Slf4j(topic "c.MyTest1") public class MyTest1 {publ…

Linux系统Ubuntu配置Docker详细流程

本文介绍在Linux操作系统Ubuntu的18.04及以上版本中&#xff0c;配置开源容器化平台和工具集Docker的详细方法&#xff1b;其中&#xff0c;我们以配置Docker平台的核心组件之一——Docker Engine为例来详细介绍。 首先&#xff0c;大家需要明确&#xff0c;我们常说的Docker&a…

设计模式-原型模式详解

文章目录 前言理论基础1. 原型模式定义2. 原型模式角色3. 原型模式工作过程4. 原型模式的优缺点 实战应用1. 原型模式适用场景2. 原型模式实现步骤3. 原型模式与单例模式的区别 原型模式的变体1. 带有原型管理器的原型模式2. 懒汉式单例模式的原型模式实现3. 细粒度原型模式 总…