强化学习代码规划之深度学习预备

news2025/1/12 3:56:27

现在到了自动编码器和解码器,同样,先练几遍代码,再去理解

import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import numpy as np


# torch.manual_seed(1)    # reproducible

# Hyper Parameters
EPOCH = 10
BATCH_SIZE = 64
LR = 0.005         # learning rate
DOWNLOAD_MNIST = False
N_TEST_IMG = 5

# Mnist digits dataset
train_data = torchvision.datasets.MNIST(
    root='./mnist/',
    train=True,                                     # this is training data
    transform=torchvision.transforms.ToTensor(),    # Converts a PIL.Image or numpy.ndarray to
                                                    # torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]
    download=DOWNLOAD_MNIST,                        # download it if you don't have it
)

# plot one example
print(train_data.train_data.size())     # (60000, 28, 28)
print(train_data.train_labels.size())   # (60000)
plt.imshow(train_data.train_data[2].numpy(), cmap='gray')
plt.title('%i' % train_data.train_labels[2])
plt.show()

# Data Loader for easy mini-batch return in training, the image batch shape will be (50, 1, 28, 28)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)


class AutoEncoder(nn.Module):
    def __init__(self):
        super(AutoEncoder, self).__init__()

        self.encoder = nn.Sequential(
            nn.Linear(28*28, 128),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, 12),
            nn.Tanh(),
            nn.Linear(12, 3),   # compress to 3 features which can be visualized in plt
        )
        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.Tanh(),
            nn.Linear(12, 64),
            nn.Tanh(),
            nn.Linear(64, 128),
            nn.Tanh(),
            nn.Linear(128, 28*28),
            nn.Sigmoid(),       # compress to a range (0, 1)
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded


autoencoder = AutoEncoder()

optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)
loss_func = nn.MSELoss()

# initialize figure
f, a = plt.subplots(2, N_TEST_IMG, figsize=(5, 2))
plt.ion()   # continuously plot

# original data (first row) for viewing
view_data = train_data.train_data[:N_TEST_IMG].view(-1, 28*28).type(torch.FloatTensor)/255.
for i in range(N_TEST_IMG):
    a[0][i].imshow(np.reshape(view_data.data.numpy()[i], (28, 28)), cmap='gray'); a[0][i].set_xticks(()); a[0][i].set_yticks(())

for epoch in range(EPOCH):
    for step, (x, b_label) in enumerate(train_loader):
        b_x = x.view(-1, 28*28)   # batch x, shape (batch, 28*28)
        b_y = x.view(-1, 28*28)   # batch y, shape (batch, 28*28)

        encoded, decoded = autoencoder(b_x)

        loss = loss_func(decoded, b_y)      # mean square error
        optimizer.zero_grad()               # clear gradients for this training step
        loss.backward()                     # backpropagation, compute gradients
        optimizer.step()                    # apply gradients

        if step % 100 == 0:
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy())

            # plotting decoded image (second row)
            _, decoded_data = autoencoder(view_data)
            for i in range(N_TEST_IMG):
                a[1][i].clear()
                a[1][i].imshow(np.reshape(decoded_data.data.numpy()[i], (28, 28)), cmap='gray')
                a[1][i].set_xticks(()); a[1][i].set_yticks(())
            plt.draw(); plt.pause(0.05)

plt.ioff()
plt.show()

# visualize in 3D plot
view_data = train_data.train_data[:200].view(-1, 28*28).type(torch.FloatTensor)/255.
encoded_data, _ = autoencoder(view_data)
fig = plt.figure(2); ax = Axes3D(fig)
X, Y, Z = encoded_data.data[:, 0].numpy(), encoded_data.data[:, 1].numpy(), encoded_data.data[:, 2].numpy()
values = train_data.train_labels[:200].numpy()
for x, y, z, s in zip(X, Y, Z, values):
    c = cm.rainbow(int(255*s/9)); ax.text(x, y, z, s, backgroundcolor=c)
ax.set_xlim(X.min(), X.max()); ax.set_ylim(Y.min(), Y.max()); ax.set_zlim(Z.min(), Z.max())
plt.show()

好,当练习了几遍代码后,对整体有点了解了,然后再理解每一步是什么作用。

问:编码器和解码器设计的作用是什么?
答:编码器和解码器在AutoEncoder中的设计有以下作用:

编码器(Encoder):编码器负责将输入数据映射到一个低维度的表示,也称为编码或隐藏层。通过降低输入数据的维度,编码器可以提取出数据的关键特征,并将其压缩成更紧凑的表示形式。这种压缩表示通常具有良好的特征表达能力,可以用于后续任务,如降维、特征提取或数据可视化。

解码器(Decoder):解码器负责从编码器获得的低维度表示中重构原始输入数据。它将压缩表示解码成与原始输入具有相同维度的输出。解码器的目标是尽可能准确地重建原始数据,以便最小化重构误差。通过解码器的工作,AutoEncoder可以通过低维表示重新生成与原始数据相似的数据,实现数据的重建或还原。

编码器和解码器的结合使得AutoEncoder能够学习到数据的压缩表示并用于重构,通过这种方式实现了一种自监督学习的方式。在训练过程中,AutoEncoder的目标是最小化重构误差,使得解码器能够尽可能准确地重建原始输入,同时编码器能够学习到数据的有用特征。

再问:这样降维然后又增加维度,不就是等于什么也没做吗?
再答:实际上,AutoEncoder的目标并不仅仅是将数据降维然后再增加维度,而是通过压缩和解压缩过程来学习数据的有用表示。

虽然编码器将数据压缩到较低维度的表示,但这个表示通常不仅仅是原始数据的简单副本。通过限制编码器的容量,AutoEncoder被迫学习到数据的最重要特征,并忽略一些冗余和噪音。因此,编码器产生的低维表示往往是数据的更高级别的抽象。

解码器的作用是将这个抽象的低维表示解压缩回原始数据空间。尽管维度增加,但解码器在这个过程中通过学习从低维表示中恢复原始数据的映射来捕捉数据的结构和模式。解码器的目标是最小化重构误差,以尽量准确地还原原始数据。

因此,AutoEncoder通过编码器和解码器的组合,可以学习到数据的紧凑表示,该表示捕捉了数据中的关键特征。这种低维表示可以用于诸如数据压缩、特征提取、数据可视化和异常检测等任务。虽然编码器和解码器都涉及维度的变化,但它们的目的是通过学习有用的数据表示来提供更高级别的抽象和数据重建,而不仅仅是简单地降低和增加维度。

打印出来了3D的分类图片
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/545486.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

文心一言和ChatGPT最全对比

文心一言和ChatGPT都是基于深度学习技术的自然语言处理模型,有各自的优势和使用场景,无法简单地比较 ChatGPT 和文心一言哪一个功能更强大,它们各自具有优势和局限性,需要根据具体需求进行选择,以下一些具体对比&#…

HttpRunner 接口自动化测试进阶

这里我们介绍一下通过调试源码的方式来做接口测试: 1、获取源码 github下载或pycharm的git checkout功能去获取,这里直接下载 通过ide打开解压后的源码包 源码包含三层目录结构:docs帮助文档、httprunner核心库、tests自测文件 2、源码结…

小白量化《穿云箭集群量化》(8) 实盘核聚变氢弹策略

小白量化《穿云箭集群量化》(8) 核聚变策略 上一篇介绍了超级订单SuperOrder功能在股票上买入策略,这篇介绍MetaTrader5期货外汇的双向交易策略。 交易策略比较有名的是马丁策略,马丁策略是单向策略。 我们设计了双向策略原子弹策…

新来的实习生太牛了,还是我们太弱了?...

前几天有个朋友向我哭诉,说她在公司工作(软件测试)了7年了,却被一个实习生代替了,该何去何从? 这是一个值得深思的问题,作为职场人员,我们确实该思考,我们的工作会被实习…

Qt布局管理器

一、布局管理器 1.1、布局管理器的作用 布局管理器是摆放控件的辅助工具,主要解决组件的位置和大小无法自适应父窗口变化的问题,主要功能如下: 自动调整控件的位置,包括控件之间的间距、对齐等当用户调整窗口大小时,位…

【洛谷】P1404 平均数

【洛谷】P1404 平均数 题目描述 给一个长度为 n n n 的数列,我们需要找出该数列的一个子串,使得子串平均数最大化,并且子串长度 ≥ m \ge m ≥m。 输入格式 第一行两个整数 n n n 和 m m m。 接下来 n n n 行,每行一个整数 …

激光点云3D目标检测算法之CenterPoint

激光点云3D目标检测算法之CenterPoint 本文首发于公众号【DeepDriving】,欢迎关注。 前言 CenterPoint是CVPR 2021的论文《Center-based 3D Object Detection and Tracking》中提出的一个激光点云3D目标检测与跟踪算法框架,与以往算法不同的是&#xff…

一大波特斯拉人形机器人上线,马斯克震撼官宣2款新车!

来源 | 新智源 ID | AI-era 【导读】这次特斯拉股东日,虽没有新车,但马斯克确定Cybertruck今年一定会来。 特斯拉股东日,依旧没有新车。 万众瞩目的马斯克登台继续画饼,「我不官宣新车,不过新车年销量会超过500万」…

【云原生】k8sPod基础概念

k8sPod基础概念 一、Pod概述1、pod概念2、Pod资源限制 二、Pod的两种使用方式三、资源共享1、创建Pod的方式2、Pod功能 四、底层容器Pause1、Pause共享资源2、Pause主要功能3、Pod与Pause结构设计用意 五、镜像的拉取策略1、Pod容器镜像拉取策略2、Pod重启策略 六、容器的分类1…

信息收集-ip地址

1、cdn网络 CDN(Content Delivery Network)网络是一种分布式网络架构。它通过将内容(如网页、图片、视频等)缓存到公共的服务器上,以便更快速、更可靠地交付给用户所在的位置。CDN解决了Internet中的一些固有瓶颈和性…

【ROS】Ubuntu22.04安装ROS2(Humble Hawksbill)

0、版本说明 Ubuntu22.04对应的ROS2的版本为Humble Hawksbill(ros-humble) 如果不是在Ubuntu22.04中安装ROS,请参考下面Ubuntu和ROS的版本对应关系 1、更新apt包列表 $ sudo apt update2、设置编码 将ubuntu环境语言编码设置为en_US en_…

appium自动化测试实战详解及框架搭建

目录 一、Appium的介绍 二、Appium实战(以真机为例) 一、Appium的介绍 Appium是一款开源的自动化测试工具,其支持iOS和安卓平台上的原生的,基于移动浏览器的,混合的应用。 1、 使用appium进行自动化测试的好…

【C++】容器篇(一)—— vector 的基本概述以及模拟实现

前言: 在之前,我们已经对 string类进行了基本的概述,并且手动的实现了string类中常用的接口函数。本期,我将带领大家学习的是STL库中的一个容器 -- vector 的学习。相比于之前的string类,本期的 vector 相对来说实现起…

CSAPP复习(三)

CH1: 但是如果问什么时候 得到二进制文件 那就是汇编的时候 CH2 sizeof 的返回值是一个无符号数 然后i-D自动默认是一个无符号数 无符号数不能出现负数 所以出现了一个循环 所以永远不可能循环完成。 CH7链接 什么是静态库?使用静态库的优点是什么? …

在windows内使用virtualbox搭建安卓x86--以及所遇到的问题解决--3

一.ARM兼容包的植入 1.下载arm包: 2.安装arm兼容包 3.验证arm兼容包是否移植成功 二.触屏无效 三.玩游戏卡顿严重 一.ARM兼容包的植入 在AndroidX86系统内大部分应用(国内)并没有适配X86架构,安装基于arm架构的应用会出现报错的情况,如遇到此问题可…

【Linux网络】Linux防火墙

Linux防火墙 一 、Linux包过滤防火墙概述1.1iptables概述1.2netfitel与iptables的关系1.3四表五链1.3.1 四表1.3.2五链1.3.3数据包到达防火墙时,规则表之间的优先顺序**1.3.4规则链之间的匹配顺序** 二、iptables的安装与配置方法2.1iptables的安装2.2iptables的配置…

MySQL数据库基础3-基础查询

文章目录 基础查询(单表)替换查询结果排序筛选分页结果更新表删除数据截断(清空)表聚合函数分组查询 基础查询(单表) 创建表、单行插入、多行插入就不重复介绍了。 替换 当我们的程序每天都会产生大量的数据,而这些数据都是前一天或者再之前的数据更新产生&#…

0基础学习VR全景 平台篇第26章:热点功能-3D物体/空间模型

大家好,欢迎观看蛙色VR官方系列——后台使用课程! 本期为大家带来蛙色VR平台,热点功能—3D物体/空间模型操作。 热点,指在全景作品中添加各种类型图标的按钮,引导用户通过按钮产生更多的交互,增加用户的多…

opencv_c++学习(十四)

一、图像直方图的统计与绘制 如果直方图各个数字之间分布比较均衡,则其对比度就相对较低,反之亦然。同时也可以通过直方图可以大致了解图像的亮暗等。 calcHist(const Mat * images, int nimages, const int * channels, lnputArray mask, OutputArray…

【Python Xpath】零基础也能轻松掌握的学习路线与参考资料

Python是一种面向对象的编程语言。Xpath是一种在XML文档中定位信息的方法。XPath是一种语言,可以用于xml和html文档中选择和查找节点。在Python中,我们可以使用xpath来解析html页面,从而提取所需的数据。 Python xpath学习路线: …