[PyTorch][chapter 53][Auto Encoder 实战]

news2025/1/10 11:15:58

前言:

     结合手写数字识别的例子,实现以下AutoEncoder

     ae.py:  实现autoEncoder 网络

     main.py: 加载手写数字数据集,以及训练,验证,测试网络。

左图:原图像

右图:重构图像

 ----main-----

 每轮训练时间 : 91
0 loss: 0.02758789248764515

 每轮训练时间 : 95
1 loss: 0.024654878303408623

 每轮训练时间 : 149
2 loss: 0.018874473869800568

目录:

     1: AE 实现

     2: main 实现


一  ae(AutoEncoder) 实现

  文件名: ae.py

               模型的搭建

   注意点:

            手写数字数据集 提供了 标签y,但是AutoEncoder 网络不需要,

它的标签就是输入的x, 需要重构本身

自编码器(autoencoder, AE)是一类在半监督学习和非监督学习中使用的人工神经网络(Artificial Neural Networks, ANNs),其功能是通过将输入信息作为学习目标,对输入信息进行表征学习(representation learning) [1-2]  。

自编码器包含编码器(encoder)和解码器(decoder)两部分 [2]  。按学习范式,自编码器可以被分为收缩自编码器(contractive autoencoder)、正则自编码器(regularized autoencoder)和变分自编码器(Variational AutoEncoder, VAE),其中前两者是判别模型、后者是生成模型 [2]  。按构筑类型,自编码器可以是前馈结构或递归结构的神经网络。

自编码器具有一般意义上表征学习算法的功能,被应用于降维(dimensionality reduction)和异常值检测(anomaly detection) [2]  。包含卷积层构筑的自编码器可被应用于计算机视觉问题,包括图像降噪(image denoising) [3]  、神经风格迁移(neural style transfer)等 [4]  。

   

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 14:19:19 2023

@author: chengxf2
"""

import torch
from torch import nn

#ae: AutoEncoder

class AE(nn.Module):
    
    def __init__(self,hidden_size=10):
        
        super(AE, self).__init__()
        
        self.encoder = nn.Sequential(
            nn.Linear(in_features=784, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=hidden_size),
            nn.ReLU()
            )
         # hidden [batch_size, 10]

        self.decoder = nn.Sequential(
             nn.Linear(in_features=hidden_size, out_features=64),
             nn.ReLU(),
             nn.Linear(in_features=64, out_features=128),
             nn.ReLU(),
             nn.Linear(in_features=128, out_features=256),
             nn.ReLU(),
             nn.Linear(in_features=256, out_features=784),
             nn.Sigmoid()
             )
        
        
    def forward(self, x):
            '''
            param x:[batch, 1,28,28]
            return 
        
            '''
      
            m= x.size(0)
            
            x = x.view(m, 784)
            
            hidden= self.encoder(x)
            x =   self.decoder(hidden)
            
            #reshape
            x = x.view(m,1,28,28)
            
            return x
        
    


二 main 实现

  文件名: main.py

  作用:

      加载数据集

     训练模型

     测试模型泛化能力

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 14:24:10 2023

@author: chengxf2
"""

import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import time
from torch import optim,nn
from ae import AE
import visdom





def main():
   
   batchNum = 32
   lr = 1e-3
   epochs = 20
   device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
   torch.manual_seed(1234)
   viz = visdom.Visdom()
   viz.line([0],[-1],win='train_loss',opts =dict(title='train acc'))

    
   

   tf= transforms.Compose([ transforms.ToTensor()])
   mnist_train = datasets.MNIST('mnist',True,transform= tf,download=True)
   train_data = DataLoader(mnist_train, batch_size=batchNum, shuffle=True)
   
   mnist_test = datasets.MNIST('mnist',False,transform= tf,download=True)
   test_data = DataLoader(mnist_test, batch_size=batchNum, shuffle=True)
   global_step =0

   
   

   
  
   model =AE().to(device)
   criteon = nn.MSELoss().to(device) #损失函数
   optimizer = optim.Adam(model.parameters(),lr=lr) #梯度更新规则
   
   print("\n ----main-----")
   for epoch in range(epochs):
       
       start = time.perf_counter()
       for step ,(x,y) in enumerate(train_data):
           #[b,1,28,28]
           x = x.to(device)
           x_hat = model(x)
           
           loss = criteon(x_hat, x)
           
           #backprop
           optimizer.zero_grad()
           loss.backward()
           optimizer.step()
           viz.line(Y=[loss.item()],X=[global_step],win='train_loss',update='append')
           global_step +=1



    
       end = time.perf_counter()    
       interval = end - start
       print("\n 每轮训练时间 :",int(interval))
       print(epoch, 'loss:',loss.item())
       
       x,target = iter(test_data).next()
       x = x.to(device)
       with torch.no_grad():
           x_hat = model(x)
       
       tip = 'hat'+str(epoch)
       viz.images(x,nrow=8, win='x',opts=dict(title='x'))
       viz.images(x_hat,nrow=8, win='x_hat',opts=dict(title=tip))
           
           
           
           
   

if __name__ == '__main__':
    
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/949469.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

DHCP 服务器部署

| DHCP - - > Dynamic Host Configuration Protocol 动态主机配置协议 背景 任何一个需要上网的设备,都必须得有IP地址,子网掩码,网关,等等网络参数。比如:手机,电脑,智能手表&#xff0c…

【SpringBoot学习笔记】02.静态资源与首页订制

静态资源 Spring Boot 通过 MVC 的自动配置类 WebMvcAutoConfiguration 为这些 WebJars 前端资源提供了默认映射规则,部分源码如下。 jar包: JAR 文件就是 Java Archive File,顾名思意,它的应用是与 Java 息息相关的,…

IDEA 报 Cannot resolve symbol ‘HttpServletResponse‘ 解决

springboot2版本换成springboot3之后,代码这里突然报红了, 首先要淡定,把原先Import的引入删掉,重新引入试试呢,是不是很简单哈哈。 原来,springboot3的路径是: import jakarta.servlet.http…

Docker之私有仓库 RegistryHarbor

目录 一、Docker私有仓库(Registry) 1.1 Registry的介绍 二、搭建本地私有仓库 2.1首先下载 registry 镜像 2.2在 daemon.json 文件中添加私有镜像仓库地址 2.3运行 registry 容器 2.4Docker容器的重启策略 2.5为镜像打标签 2.6上传到私有仓库 2…

【车载雷达信号处理】利用sinc函数实现扣点

针对信号处理流程中多次FFT输出的频谱结果,在特殊的场景下,可能存在针对某一特定频点的固定"虚警",所以针对某一个特定频点进行“扣点”的操作是常有的信号处理流程需求。不仅如此,针对最大能量值的扣点也能在不适合使用…

文件修改时间能改吗?怎么改?

文件修改时间能改吗?怎么改?修改时间是每个电脑文件具备的一个属性,它代表了这个电脑文件最后一次的修改时间,是电脑系统自动赋予文件的,相信大家都应该知道。我们右击鼠标某个文件,然后点击弹出菜单里面的…

并发编程(四大函数接口) 06 详细讲解

四大函数接口 函数接口:接口中只有一个方法 Function Function函数型接口,有一个输入参数,有一个输出只要是函数型接口可以用Lambda表达式简化 函数函数型接口,有一个输入参数,有一个输出只要是函数型接口可以用lamb…

并发容器11

一 JDK 提供的并发容器总结 JDK 提供的这些容器大部分在 java.util.concurrent 包中。 ConcurrentHashMap: 线程安全的 HashMap CopyOnWriteArrayList: 线程安全的 List,在读多写少的场合性能非常好,远远好于 Vector. ConcurrentLinkedQueue: 高效的并…

element 级联选择框偏移

如图所示,选择之后,位置跑到了左上角 添加:append-to-body"false",在弹出框的定位出现问题时,可将该属性设置为 false

(第六天)初识Spring框架-SSM框架的学习与应用(Spring + Spring MVC + MyBatis)-Java EE企业级应用开发学习记录

SSM框架的学习与应用(Spring Spring MVC MyBatis)-Java EE企业级应用开发学习记录(第六天)初识Spring框架 ​ 昨天我们已经把Mybatis框架的基本知识全部学完,内容有Mybatis是一个半自动化的持久层ORM框架,深入学习编写动态SQL&a…

淘宝京东1688商品价格监控(电商价格监测API接口系列)

淘宝价格监控,电商价格监测软件目前市面有一款是最火的,针对各个平台都可以进行价格监测,很多实用功能,今天我们就来介绍一下品牌卫士这款价格24小时监测软件都有哪些功能。 一,覆盖全网全平台 天猫淘宝、闲鱼、京东、…

RK3588平台驱动调试篇 [ GPIO篇 ] - RK3588-对GPIO的操作控制

1. 简介 RK3588从入门到精通本⽂介绍Linux操作gpio⽅法开发板:ArmSoM-W3 2. GPIO配置 Rockchip Pin的ID按照 控制器(bank)端口(port)索引序号(pin) 组成 2.1 GPIO驱动介绍 驱动包括Pinctrl驱动( drivers/pinctrl/pinctrl-rockchip.c ) 和…

servlet初体验之环境搭建!!!

我们需要用到tomcat服务器,咩有下载的小伙伴看过来:如何正确下载tomcat???_明天更新的博客-CSDN博客 1. 创建普通的Java项目,并在项目中创建libs目录存放第三方的jar包。 建立普通项目 创建libs目录存放第三…

2023_Spark_实验三:基于IDEA开发Scala例子

一、创建一个空项目&#xff0c;作为整个项目的基本框架 二、创建SparkStudy模块&#xff0c;用于学习基本的Spark基础 三、创建项目结构 1、在SparkStudy模块下的pom.xml文件中加入对应的依赖&#xff0c;并等待依赖包下载完毕。 在pom.xml文件中加入对应的依赖 ​<!-- S…

模拟4~20ma电流输出的设计

文章目录 1. 原理2. 使用GP8102S或GP8212S进行设计2.1 共地型设计2.2 共源型设计2.3 其它电流需求 3. 隔离光耦电源连接方案4. 利用GP8102S实现0-40V 的可编程电压输出 1. 原理 4 ~ 20ma电流输出的目的不用多说&#xff0c;今天就简单聊一下4 ~ 20ma电流输出是怎么设计出来的&…

【AI】数学基础——概率论

随着联结主义学派的兴起&#xff0c;概率统计已经取代了数理逻辑&#xff0c;成为了人工智能研究的主流工具 数理统计的关注点是 无处不在的可能性 对随机事件发生的可能性进行规范的数学描述是概率论的公理化过程 频率学派认为先验分布式固定的&#xff0c;模型参数靠最大似…

C++day6

1. #include <iostream>using namespace std; class Animal { private:int id; public:Animal(){}Animal(int id):id(id){}virtual void show(){cout << "动物园门牌号:" << id << endl;} }; class Houzi:public Animal { private:string n…

用变压器实现德-英语言翻译【02/8】: 位置编码

一、说明 本文是“用变压器实现德-英语言翻译”系列的第二篇。它从头开始引入位置编码。然后&#xff0c;它解释了 PyTorch 如何实现位置编码。接下来是变压器实现。 二、技术背景 位置编码用于为序列中的每个标记或单词提供相对位置。阅读句子时&#xff0c;每个单词都依赖于它…

日本核污染水排海,有必要囤盐吗?

据央视新闻24日报道&#xff0c;当地时间8月24日13时&#xff0c;日本福岛第一核电站启动污水排海。消息一出&#xff0c;全球哗然。虽然事情已经过去了几天&#xff0c;但是&#xff0c;随着这一举动&#xff0c;大家就乱了阵脚&#xff0c;恐惧者有之&#xff0c;辱骂者有之&…

Nginx从入门到精通(超级详细)

文章目录 一、什么是Nginx1、正向代理2、反向代理3、负载均衡4、动静分离 二、centos7环境安装Nginx1、安装依赖2、下载安装包3、安装4、启动5、停止 三、Nginx核心基础知识1、nginx核心目录2、常用命令3、默认配置文件讲解4、Nginx虚拟主机-搭建前端静态服务器5、使用nignx搭建…