深度学习——自编码器AutoEncoder

news2024/10/5 20:24:34

基本概念

概述

自编码器(Autoencoder)是一种无监督学习的神经网络模型,用于学习数据的低维表示。它由编码器(Encoder)和解码器(Decoder)两部分组成,通过将输入数据压缩到低维编码空间,再从编码空间中重构输入数据。

基本结构

自编码器的基本结构如下:
1.编码器(Encoder):接收输入数据,将其映射到低维编码空间。编码器由一系列隐藏层组成,通常逐渐减小维度以进行特征提取和数据压缩。
2.解码器(Decoder):接收编码器的输出,将编码后的数据映射回原始输入空间。解码器的结构与编码器相反,逐渐增加维度并尝试重构原始数据。
3.重构损失(Reconstruction Loss):自编码器的目标是尽可能准确地重构输入数据。因此,使用重构损失函数来衡量原始数据与重构数据之间的差异,如均方误差(MSE)或交叉熵损失。

训练过程

1.将输入数据提供给编码器,获得低维编码。
2.将编码结果传递给解码器,尝试重构输入数据。
3.计算重构损失,并通过反向传播优化网络参数,使重构误差最小化。
重复上述步骤,直到自编码器能够准确地重构输入数据。

应用

1.数据降维:自编码器可以学习数据的低维表示,有助于数据的压缩和降维。
2.特征学习:通过训练自编码器,可以学习到数据的有意义的特征表示,用于后续的监督学习任务。
3.异常检测:自编码器可以学习数据的正常分布,从而用于检测异常或异常数据的重构错误。

详细代码与注释

import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import numpy as np


# torch.manual_seed(1)    # reproducible

# Hyper Parameters
EPOCH = 10
BATCH_SIZE = 64
LR = 0.005         # learning rate
DOWNLOAD_MNIST = True

N_TEST_IMG = 5
# Mnist digits dataset
train_data = torchvision.datasets.MNIST(
    root='./mnist/',
    train=True,                                     # this is training data
    transform=torchvision.transforms.ToTensor(),    # Converts a PIL.Image or numpy.ndarray to
                                                    # torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]
    download=DOWNLOAD_MNIST,                        # download it if you don't have it
)

# plot one example
# 训练数据
print(train_data.train_data.size())     # (60000, 28, 28)
# 训练标签
print(train_data.train_labels.size())   # (60000)
plt.imshow(train_data.train_data[2].numpy(), cmap='gray')
plt.title('%i' % train_data.train_labels[2])
plt.show()

# Data Loader for easy mini-batch return in training, the image batch shape will be (50, 1, 28, 28)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)


class AutoEncoder(nn.Module):
    def __init__(self):
        super(AutoEncoder, self).__init__()

        # 编码器
        self.encoder = nn.Sequential(
            nn.Linear(28*28, 128),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, 12),
            nn.Tanh(),
            nn.Linear(12, 3),   # compress to 3 features which can be visualized in plt
        )
        # 解码器
        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.Tanh(),
            nn.Linear(12, 64),
            nn.Tanh(),
            nn.Linear(64, 128),
            nn.Tanh(),
            nn.Linear(128, 28*28),
            nn.Sigmoid(),       # compress to a range (0, 1)
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded


autoencoder = AutoEncoder()

optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)
loss_func = nn.MSELoss()

# initialize figure
f, a = plt.subplots(2, N_TEST_IMG, figsize=(5, 2))
plt.ion()   # continuously plot

# original data (first row) for viewing
view_data = train_data.train_data[:N_TEST_IMG].view(-1, 28*28).type(torch.FloatTensor)/255.
for i in range(N_TEST_IMG):
    a[0][i].imshow(np.reshape(view_data.data.numpy()[i], (28, 28)), cmap='gray'); a[0][i].set_xticks(()); a[0][i].set_yticks(())

# 训练
for epoch in range(EPOCH):
    for step, (x, b_label) in enumerate(train_loader):
        b_x = x.view(-1, 28*28)   # batch x, shape (batch, 28*28)
        b_y = x.view(-1, 28*28)   # batch y, shape (batch, 28*28)

        encoded, decoded = autoencoder(b_x)

        # 比对解码出来的数据和原始数据,计算loss
        loss = loss_func(decoded, b_y)      # mean square error
        optimizer.zero_grad()               # clear gradients for this training step
        loss.backward()                     # backpropagation, compute gradients
        optimizer.step()                    # apply gradients

        if step % 100 == 0:
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy())

            # plotting decoded image (second row)
            _, decoded_data = autoencoder(view_data)
            for i in range(N_TEST_IMG):
                a[1][i].clear()
                a[1][i].imshow(np.reshape(decoded_data.data.numpy()[i], (28, 28)), cmap='gray')
                a[1][i].set_xticks(())
                a[1][i].set_yticks(())
            plt.draw()
            plt.pause(0.05)

plt.ioff()
plt.show()

# visualize in 3D plot
view_data = train_data.train_data[:200].view(-1, 28*28).type(torch.FloatTensor)/255.
encoded_data, _ = autoencoder(view_data)
fig = plt.figure(2)
ax = Axes3D(fig)
X, Y, Z = encoded_data.data[:, 0].numpy(), encoded_data.data[:, 1].numpy(), encoded_data.data[:, 2].numpy()
values = train_data.train_labels[:200].numpy()
for x, y, z, s in zip(X, Y, Z, values):
    c = cm.rainbow(int(255*s/9)); ax.text(x, y, z, s, backgroundcolor=c)
ax.set_xlim(X.min(), X.max()); ax.set_ylim(Y.min(), Y.max()); ax.set_zlim(Z.min(), Z.max())
plt.show()

运行结果

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/763272.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

小程序使用云函数调用第三方api避免域名备案

小程序使用云函数调用第三方api避免域名备案 在小程序开发中,如果需要调用第三方 API,但由于域名备案的限制无法直接在小程序中使用,我们可以利用云函数来解决这个问题。 1. 开通云开发 使用微信开发者工具打开小程序项目,点击…

IDEA 提交git 之后撤回操作

方式一 1.选择提交记录; 2、 右键git然后选择drop commit; 弊端:会将修改的代码全部进行删除操作 打开 IDEA 的 本地历史记录功能,对修改的内容进行复原 方式二: 1、撤回commit 2、选择项目——>右击git——…

ROS-Qt-转CMake编译以及qmake第三方库添加及其他

Qt 开发ROS 界面的方法 方法2 带ui的工作空间配置(以ROS节点执行) 步骤1 $ mkdir catkin_qt $ cd catkin_qt $ mkdir src $ cd src $ catkin_init_workpasce $ cd .. $ catkin_make $ cd src $ catkin_create_qt_pkg ros_ui roscpp rospy std_msgs $ …

生成式AI时代,亚马逊云科技致力推动技术的普惠,让更多企业受益

当谈及AIGC时, 我们该谈些什么? 生成式AI技术与应用的不断发展,为各个行业都注入了全新的机会与活力。AIGC成为了今年最为激动人心的技术话题。亚马逊云科技也一马当先,在6月27-28日,2023亚马逊云科技中国峰会上分享…

堆的Top-K问题

⭐️ TOP-K问题 TOP-K问题:即求数据结合中前 k k k 个最大的元素或者最小的元素,一般情况数据量都比较大。 比如:专业前10名、世界500强、富豪榜、游戏中前100的活跃玩家等。 如果数据量过大,排序的方式就不太可取了。 思路&…

TikTok引流的新玩法,用AdsPower打开新大陆!

做跨境电商的人都知道,天上不会掉流量,想要产品火,就必须要引流。大佬的一个产品动不动就几千几万的观看量,可是自己的产品却无人问津。他们到底怎么做到的呢?很简单,注册AdsPower和TikTok的账号&#xff0…

el-input输入框的那些事

vue3element-plustses6 此帖只为记录开发中遇到的需求,技术问题,坑 1、文本域禁止自由拉伸 1、文本域禁止自由拉伸 el-input有一个枚举类型的resize属性,控制拉伸,‘none’ | ‘both’ | ‘horizontal’ | ‘vertical’&#xf…

找实拍高清视频网站,我推荐这6个

本期推荐6个高清视频素材网站,视频剪辑、自媒体必备,建议收藏~ 菜鸟图库 https://www.sucai999.com/video.html?vNTYwNDUx 菜鸟图库虽然是一个设计网站,但它还有非常丰富的视频和音频素材,视频素材全部都是高清无水印&#xff0…

暑期代码每日一练Day1:Leetcode415. 字符串相加

题目 415. 字符串相加 分析 题目意思是给你两个纯数字 字符串(表示的是一个有意义的正整数),让你算出这两个字符串表示的数字的和,最后返回以字符串表示的结果,其中的过程不能直接将初始给定的两个字符串直接转化为…

ORB+FLANN

FLANN 代表 近似最近邻的快速库。它包含针对大型数据集中的快速最近邻搜索和高维特征优化的算法集合。对于大型数据集,它比BFMatcher工作得更快。 对于基于 FLANN 的匹配器,我们需要传递两个字典,指定要使用的算法、相关参数等。第一个是Ind…

2024考研408-操作系统 第三章-内存管理 学习笔记

文章目录 一、内存管理基础1.1、内存的基础知识1.1.1、什么是内存?有何作用?1.1.2、进程运行的基本原理1.1.2.1、指令的工作原理1.1.2.2、理解逻辑地址与物理地址1.1.2.3、从写程序到程序运行1.1.2.4、三种链接方式(静态、转入时动态、运行时…

openssl源码编译输出库-guidance-傻瓜式教程

目标: 下载openssl源码 编译输出目标版本,例如使用Android NDK编译输出Android使用的32位的库 1、下载源码 git clone https://github.com/openssl/openssl.git -b openssl-3.0.9 2、 请下载Linux版本的Android NDK 请下载Linux版本的Android NDK, 并完…

没看完这篇文章,别说你会用Ping

中午好,我的网工朋友。 网工生活里每天都和ping打交道,ping来ping去,很多人知道ping,却不知道怎么把ping用出更多花样出来。 今天,我特地给你关于ping命令的使用大全,在更多不同的项目场景里,…

解决访问127.0.0.1时,提示“127.0.0.1 拒绝了我们的连接请求”

目录 问题描述 解决方案 问题描述 我电脑是win10系统,刚刚在访问http://127.0.0.1时,浏览器显示“127.0.0.1 拒绝了我们的连接请求”,为何访问本机IP显示拒绝访问? 解决方案 1. windows徽标 I 打开设置,选择“应用…

3ds max高级教程:创建带有骨骼动画的机器人模型

推荐: NSDT场景编辑器助你快速搭建可二次开发的3D应用场景 然而,下面我们示例机器人腿的第一个版本不是很现实,因为它会像没有肌肉的骨骼结构一样坍塌。在第二个版本中,我们将添加一些机器人“肌肉”,第三个版本将包括…

分布式运用存储系统Ceph

一、ceph的相关知识 1.ceph介绍与简介 Ceph是一个开源的分布式存储解决方案,旨在提供可扩展性、高性能和强大的数据可靠性。它采用了一种分布式对象存储架构,能够同时提供块存储和文件存储的功能。 Ceph使用C语言开发,是一个开放、自我修复和…

Makefile文件编写

文章目录 格式自动检查更新效率变量模式匹配函数clean 格式 目标:依赖 tab 命令 自动检查更新 当有文件发生修改后,重新make会自动对发生修改的依赖进行编译 效率 由于在 make时会进行检查更新,对于有修改的依赖会重新编译,为…

定位理论:引领企业变革的幕后推手

在商业的海洋中,如何能让你的企业像一座明亮的灯塔,独特而引人注目?这就需要我们掌握一种强大的工具——定位理论。那么,定位理论究竟是什么?我们为什么要学习它?它如何能为我们的企业创造价值?今天,让我们一起深入探索定位理…

vue3组件引用使用的坑

今天准备用el-tabs写个页面,发现点击后组件怎么都显示不了,后来才发现是组件引用的原因 这是页面是显示的效果: 乍一看确实是对的。。。 但是当我点击完这三个tab后再重新点击道路管理后,有意思的出现了: one组件消失…

UE4/5AI制作基础AI(适合新手入门,运用黑板,行为树,ai控制器,角色类,任务)

目录 制作流程 第一步:创建资产 然后创建一个AIController 之后创建一个黑板和行为树: 第二步:制作 黑板 行为树 任务 运行行为树 结果 制作流程 第一步:创建资产 第一步直接复制你的人物蓝图,做一个npc&…