PyTorch深度学习实战(25)——自编码器

news2024/12/23 12:55:38

PyTorch深度学习实战(25)——自编码器

    • 0. 前言
    • 1. 自编码器
    • 2. 使用 PyTorch 实现自编码器
    • 小结
    • 系列链接

0. 前言

自编码器 (Autoencoder) 是一种无监督学习的神经网络模型,用于数据的特征提取和降维,它由一个编码器 (Encoder) 和一个解码器 (Decoder) 组成,通过将输入数据压缩到低维表示,然后再重构出原始数据。在本节中,我们将学习如何使用自编码器,以在低维空间表示图像,学习以较少的维度表示图像有助于修改图像,可以利用低维表示来生成新图像。

1. 自编码器

我们已经学习了通过输入图像及其相应标签训练模型来对图像进行分类,进行分类的前提是是拥有带有类别标签的数据集。假设数据集中没有图像对应的标签,如果需要根据图像的相似性对图像进行聚类,在这种情况下,自编码器可以方便地识别和分组相似的图像。
自动编码器将图像作为输入,将其存储在低维空间中,并尝试通过解码过程输出相同图像,而不使用其他标签,因此 AutoEncoder 中的 Auto 表示能够再现输入。但是,如果我们只需要简单的在输出中重现输入,就不需要神经网络了,只需要将输入简单地原样输出即可。自编码器的作用在于它能够以较低维度对图像信息进行编码,因此称为编码器(将图像信息编码至较低维空间中),因此,相似的图像具有相似的编码。此外,解码器致力于根据编码矢量重建原始图像,以尽可能重现输入图像:

自编码器

假设模型输入图像是 MNIST 手写数字图像,模型输出图像与输入图像相同。最中间的网络层是编码层,也称瓶颈层 (bottleneck layer),输入和瓶颈层之间发生的操作表示编码器,瓶颈层和输出之间的操作表示解码器。
通过瓶颈层,我们可以在低维空间中表示图像,也可以重建原始图像,换句话说,利用自编码器中的瓶颈层能够解决识别相似图像以及生成新图像的问题,具体而言:

  • 具有相似瓶颈层值(编码表示,也称潜编码)的图像可能彼此相似
  • 通过改变瓶颈层的节点值,可以改变输出图像。

2. 使用 PyTorch 实现自编码器

本节中,使用 PyTorch 构建自编码器,我们使用 MNIST 数据集训练此网络,MNIST 数据集中是一个手写数字的图像数据集,包含了 6 万个 28x28 像素的训练样本和 1 万个测试样本。

(1) 导入相关库并定义设备:

from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import make_grid
import numpy as np
from matplotlib import pyplot as plt
device = 'cuda' if torch.cuda.is_available() else 'cpu'

(2) 指定图像转换方法:

img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
    transforms.Lambda(lambda x: x.to(device))
])

通过以上代码,将图像转换为张量,对其进行归一化,然后将其传递到设备中。

(3) 创建训练和验证数据集:

trn_ds = MNIST('MNIST/', transform=img_transform, train=True, download=True)
val_ds = MNIST('MNIST/', transform=img_transform, train=False, download=True)

(4) 定义数据加载器:

batch_size = 256
trn_dl = DataLoader(trn_ds, batch_size=batch_size, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

(5) 定义网络架构,在 __init__ 方法中定义了使用编码器-解码器架构的 AutoEncoder 类,以及瓶颈层的维度,latent_dimforward 方法,并打印模型摘要信息。

定义 AutoEncoder 类和包含编码器、解码器以及瓶颈层维度的 __init__ 方法:

class AutoEncoder(nn.Module):
    def __init__(self, latent_dim):
        super().__init__()
        self.latend_dim = latent_dim
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128), nn.ReLU(True),
            nn.Linear(128, 64), nn.ReLU(True), 
            #nn.Linear(64, 12),  nn.ReLU(True), 
            nn.Linear(64, latent_dim))
        self.decoder = nn.Sequential(
            #nn.Linear(latent_dim, 12), nn.ReLU(True),
            nn.Linear(latent_dim, 64), nn.ReLU(True),
            nn.Linear(64, 128), nn.ReLU(True), 
            nn.Linear(128, 28 * 28), nn.Tanh())

定义前向计算方法 forward

    def forward(self, x):
        x = x.view(len(x), -1)
        x = self.encoder(x)
        x = self.decoder(x)
        x = x.view(len(x), 1, 28, 28)
        return x

打印模型摘要信息:

from torchsummary import summary
model = AutoEncoder(3).to(device)
print(summary(model, (1,28,28)))

模型架构信息输出如下:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                  [-1, 128]         100,480
              ReLU-2                  [-1, 128]               0
            Linear-3                   [-1, 64]           8,256
              ReLU-4                   [-1, 64]               0
            Linear-5                    [-1, 3]             195
            Linear-6                   [-1, 64]             256
              ReLU-7                   [-1, 64]               0
            Linear-8                  [-1, 128]           8,320
              ReLU-9                  [-1, 128]               0
           Linear-10                  [-1, 784]         101,136
             Tanh-11                  [-1, 784]               0
================================================================
Total params: 218,643
Trainable params: 218,643
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.02
Params size (MB): 0.83
Estimated Total Size (MB): 0.85
----------------------------------------------------------------

从前面的输出中,可以看到 Linear: 2-5 层是瓶颈层,将每张图像都表示为一个 3 维向量;此外,解码器使用瓶颈层中的 3 维向量重建原始图像。

(6) 定义函数在批数据上训练模型 train_batch()

def train_batch(input, model, criterion, optimizer):
    model.train()
    optimizer.zero_grad()
    output = model(input)
    loss = criterion(output, input)
    loss.backward()
    optimizer.step()
    return loss

(7) 定义在批数据上进行模型验证的函数 validate_batch()

@torch.no_grad()
def validate_batch(input, model, criterion):
    model.eval()
    output = model(input)
    loss = criterion(output, input)
    return loss

(8) 定义模型、损失函数和优化器:

model = AutoEncoder(3).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-5)

(9) 训练模型:

num_epochs = 20
train_loss_epochs = []
val_loss_epochs = []
for epoch in range(num_epochs):
    N = len(trn_dl)
    trn_loss = []
    val_loss = []
    for ix, (data, _) in enumerate(trn_dl):
        loss = train_batch(data, model, criterion, optimizer)
        pos = (epoch + (ix+1)/N)
        trn_loss.append(loss.item())
    train_loss_epochs.append(np.average(trn_loss))

    N = len(val_dl)
    for ix, (data, _) in enumerate(val_dl):
        loss = validate_batch(data, model, criterion)
        pos = epoch + (1+ix)/N
        val_loss.append(loss.item())
    val_loss_epochs.append(np.average(val_loss))

(10) 可视化训练期间模型的训练和验证损失随时间的变化情况:

epochs = np.arange(num_epochs)+1
plt.plot(epochs, train_loss_epochs, 'bo', label='Training loss')
plt.plot(epochs, val_loss_epochs, 'r-', label='Test loss')
plt.title('Training and Test loss over increasing epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid('off')
plt.show()

模型监测
(11) 使用测试数据集 val_ds 验证模型:

for _ in range(5):
    ix = np.random.randint(len(val_ds))
    im, _ = val_ds[ix]
    _im = model(im[None])[0]
    plt.subplot(121)
    # fig, ax = plt.subplots(1,2,figsize=(3,3)) 
    plt.imshow(im[0].detach().cpu(), cmap='gray')
    plt.title('input')
    plt.subplot(122)
    plt.imshow(_im[0].detach().cpu(), cmap='gray')
    plt.title('prediction')
plt.show()

重建图像
我们可以看到,即使瓶颈层只有三个维度,网络也可以非常准确地重现输入,但是图像并不像预期的那样清晰,主要是因为瓶颈层中的节点数量过少。具有不同瓶颈层大小 (2351050) 的网络训练后,可视化重建的图像如下所示:

def train_aec(latent_dim):
    model = AutoEncoder(latent_dim).to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-5)

    num_epochs = 20
    train_loss_epochs = []
    val_loss_epochs = []

    for epoch in range(num_epochs):
        N = len(trn_dl)
        trn_loss = []
        val_loss = []
        for ix, (data, _) in enumerate(trn_dl):
            loss = train_batch(data, model, criterion, optimizer)
            pos = (epoch + (ix+1)/N)
            trn_loss.append(loss.item())
        train_loss_epochs.append(np.average(trn_loss))

        N = len(val_dl)
        trn_loss = []
        val_loss = []
        for ix, (data, _) in enumerate(val_dl):
            loss = validate_batch(data, model, criterion)
            pos = epoch + (1+ix)/N
            val_loss.append(loss.item())
        val_loss_epochs.append(np.average(val_loss))
    epochs = np.arange(num_epochs)+1
    plt.plot(epochs, train_loss_epochs, 'bo', label='Training loss')
    plt.plot(epochs, val_loss_epochs, 'r-', label='Test loss')
    plt.title('Training and Test loss over increasing epochs')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid('off')
    plt.show()
    return model

aecs = [train_aec(dim) for dim in [50, 2, 3, 5, 10]]

for _ in range(10):
    ix = np.random.randint(len(val_ds))
    im, _ = val_ds[ix]
    plt.subplot(1, len(aecs)+1, 1)
    plt.imshow(im[0].detach().cpu(), cmap='gray')
    plt.title('input')
    idx = 2
    for model in aecs:
        _im = model(im[None])[0]
        plt.subplot(1, len(aecs)+1, idx)
        plt.imshow(_im[0].detach().cpu(), cmap='gray')
        plt.title(f'prediction\nlatent-dim:{model.latend_dim}')
        idx += 1
plt.show()

请添加图片描述

随着瓶颈层中向量维度的增加,重建图像的清晰度逐渐提高。

小结

自编码器是一种无监督学习的神经网络模型,用于数据的特征提取和降维。它由编码器和解码器组成,通过将输入数据压缩到低维表示,并尝试重构出原始数据来实现特征提取和数据的降维。自编码器的训练过程中,目标是最小化输入数据与重构数据之间的重建误差,以使编码器捕捉到数据的关键特征。自编码器在无监督学习和深度学习中扮演着重要的角色,能够从数据中学习有用的特征,并为后续的机器学习任务提供支持。

系列链接

PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(2)——PyTorch基础
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(4)——常用激活函数和损失函数详解
PyTorch深度学习实战(5)——计算机视觉基础
PyTorch深度学习实战(6)——神经网络性能优化技术
PyTorch深度学习实战(7)——批大小对神经网络训练的影响
PyTorch深度学习实战(8)——批归一化
PyTorch深度学习实战(9)——学习率优化
PyTorch深度学习实战(10)——过拟合及其解决方法
PyTorch深度学习实战(11)——卷积神经网络
PyTorch深度学习实战(12)——数据增强
PyTorch深度学习实战(13)——可视化神经网络中间层输出
PyTorch深度学习实战(14)——类激活图
PyTorch深度学习实战(15)——迁移学习
PyTorch深度学习实战(16)——面部关键点检测
PyTorch深度学习实战(17)——多任务学习
PyTorch深度学习实战(18)——目标检测基础
PyTorch深度学习实战(19)——从零开始实现R-CNN目标检测
PyTorch深度学习实战(20)——从零开始实现Fast R-CNN目标检测
PyTorch深度学习实战(21)——从零开始实现Faster R-CNN目标检测
PyTorch深度学习实战(22)——从零开始实现YOLO目标检测
PyTorch深度学习实战(23)——使用U-Net架构进行图像分割
PyTorch深度学习实战(24)——从零开始实现Mask R-CNN实例分割

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1306459.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

智慧燃气让城市能源系统高效运行

关键词:智慧燃气、燃气数字化、智慧燃气平台、智慧燃气解决方案、智慧燃气系统 随着我国城镇燃气行业的发展,燃气行业管理及服务从简单的手工运作阶段迈入数字燃气阶段,大量采用信息化手段管理燃气业务,智慧燃气应运而生。它既是…

互动时代的新趋势

数字人直播是指通过数字化技术将虚拟人物以真实的方式呈现在观众面前,并实现与观众的实时互动的一种新型娱乐形式。近年来,随着科技的发展和社交媒体的兴起,数字人直播越来越受到人们的关注,并成为当今互动时代的新趋势。 首先&a…

centos7x 安装支持gpu驱动的docker

1、卸载以前版本的驱动 sudo /usr/bin/nvidia-uninstall2、先安装基础项 yum install kernel kernel-devel gcc make -yyum install kernel kernel-devel gcc gcc-c make -y 3、禁用驱动源 nouveau echo "blacklist nouveau " >>/etc/modp…

移动CRM:智能化助力,销售更高效

商机稍纵即逝,及时把握机会才能促成更多交易。想要在出差途中也能掌握公司运营情况?得益于移动互联网的快速发展,现在您可以利用移动CRM客户管理系统,实时管理线索,查看销售数据。本文将向您介绍,移动CRM能…

热分析和报告软件 :FLIR Thermal Studio Pro Crack

FLIR 热工作室套件。FLIR Thermal Studio Pro FLIR Thermal Studio Suite 可帮助用户简化检查、组织数据并管理数千张热图像和视频——无论他们使用的是手持式热像仪、无人机系统 (UAS)、声学成像相机还是光学气体成像(OGI)相机。该订阅软件提供了关键组…

Qt提升绘制效率,绘制加速。

在我们绘制一些复杂逻辑且数据量巨大的图形时,经常会出现流畅性问题,本文就是来进行讲解如何提升绘制效率的。 实现思路: 场景1:我们绘制多个静态图形和绘制一张图片哪个更快。很明显绘制多个图形的时候要慢很多。所以我们将多个图…

Web漏洞分析-文件解析及上传(上)

随着互联网的迅速发展,网络安全问题变得日益复杂,而文件解析及上传漏洞成为攻击者们频繁攻击的热点之一。本文将深入研究文件解析及上传漏洞,通过对文件上传、Web容器IIS、命令执行、Nginx文件解析漏洞以及公猫任意文件上传等方面的细致分析&…

* demo、源码、桌面端软件

demo demohttps://bidding-m.gitee.io/maptalks-test-next/#/ 源码 源码https://gitee.com/bidding-M/maptalks-test-next 桌面端 桌面端https://gitee.com/bidding-M/map-collection/blob/master/apps/maptalks-win-0.0.1-x64.exe

中洺科技-数据标注创就业浪潮

平凡的事不能普通的做那么就有所不同。这也是我们在这波新的AI浪潮中做副业、创业时应该思考的。 正如一些业内从业者所说,数据标注行业本身的利润其实非常有限。如果现在招聘专职数据标注人员的成本太高,只做标注项目的公司就太难立足了,数据…

微信小程序(一) —— 常见组件

文章目录 🎀项目基本组成结构📢常见的视图容器类组件viewscroll-viewswiper和swiper-item使用viewscroll-viewswiper和swiper-itemswiper标签属性 🌶️常用的基础内容组件textrich-text 📮其他常用组件buttonimagenavigator &…

简易的JS逆向解码

在实战的漏洞挖掘中阅读JS有以下几个作用: 1.JS中存在插件名字,根据插件找到相应的漏洞直接使用 通过控制台大致阅读网站JS代码发现此网页引用了北京的一家公司的代码,并且使用了h-net的框架,接下来我们可以百度这家公司或者是这…

基于C/C++的rapidxml加载xml大文件 - 上部分翻译

RAPIDXML手册 版本 1.13 版权所有 (C) 2006, 2009 Marcin Kalicinski有关许可证信息,请参阅随附的文件许可证 .txt。 目录 1. 什么是 RapidXml? 1.1 依赖性和兼容性1.2 字符类型和编码1.3 错误处理1.4 内存分配1.5 …

HTML---基础

文章目录 目录 文章目录 前言 一.HTML概述 二.HTML相关概念 HTML作用域 HTML标签 HTML转译字符 总结 前言 一.HTML概述 HTML(超文本标记语言)是一种用于创建网络页面的标记语言。它以标记的形式编写,该标记描述了文档的结构和内容。HTML…

黑苹果之网卡篇

今天主要来聊一下黑苹果如何选择网卡 黑苹果对硬件的要求比较局限,可以用的无线网卡也比较少,而且很多也已经涨价,这里就结合自己的使用经验,简单分享几款比较好用的黑苹果无线网卡方案。 一、BCM94360系列 如果想要稳定、省心&am…

【K8S 系列】认识k8s、k8s架构

一、什么是k8s? Kubernetes 简称 k8s,是支持云原生部署的一个平台,k8s 本质上就是用来简化微服务的开发和部署的,用于自动化部署、扩展和管理容器化应用的开源容器编排技术。对于传统的docker其实也提供了容器编排的技术docker-compose&…

gdb本地调试版本移植至ARM-Linux系统

移植ncurses库 本文使用的ncurses版本为ncurses-5.9.tar.gz 下载地址:https://ftp.gnu.org/gnu/ncurses/ncurses-5.9.tar.gz 1. 将ncurses压缩包拷贝至Linux主机或使用wget命令下载并解压 tar-zxvf ncurses-5.9.tar.gz 2. 解压后进入到ncurses-5.9目录…

【人工智能Ⅰ】实验8:DBSCAN聚类实验

实验8 DBSCAN聚类实验 一、实验目的 学习DBSCAN算法基本原理,掌握算法针对不同形式数据如何进行模型输入,并结合可视化工具对最终聚类结果开展分析。 二、实验内容 1:使用DBSCAN算法对iris数据集进行聚类算法应用。 2:使用DBS…

论文阅读——GroupViT

GroupViT: Semantic Segmentation Emerges from Text Supervision 一、思想 把Transformer层分为多个组阶段grouping stages,每个stage通过自注意力机制学习一组tokens,然后使用学习到的组tokens通过分组模块Grouping Block融合相似的图片tokens。通过这…

CSPNet: A New Backbone that can Enhance Learning Capability of CNN(2019)

文章目录 -Abstract1 Introduction2 Related workformer work 3 Method3.1 Cross Stage Partial Network3.2 Exact Fusion Model 4 Experiments5 Conclusion 原文链接 源代码 - 梯度信息重用(有别于冗余的梯度信息)可以减少计算量和内存占用提高效率&am…

二、结合各种图形库实现各种demo(11-20)

demo地址https://bidding-m.gitee.io/maptalks-test-next/#/ 11、isects 12、right click menu 13、infoWindow 14、image marker 15、multi image marker 16、vector-marker-fill 17、line-gradient-arrow 18、rotated-marker-with-line 19、smoothness-line 20、polygon 欢迎…