深度学习5 -- 循环神经网络(代码实现篇+付详细流程文件)

news2025/2/10 9:20:30

引言

  • 本文是使用pytorch对循环神经网络RNN(Recurrent Neural Network)的代码实现,作为之前介绍RNN原理的一个代码补充。

  • RNN原理介绍

  • 本文代码相关介绍相对较为详细,也为自己的一个学习过程,对RNN的理解还是比较浅显,有错误的地方欢迎指正。

简述RNN结构

  • 详细原理介绍可以参考上述链接,此处简述RNN结构实为方便理解后面代码分析部分。
  • 单向循环神经网络
    • RNN的应用场景一般是当前输入与前一个输入是有联系的,所以下图x部分的参数会与X_t-1有关
      • x:数据输入
      • u:输入层到隐藏层的权重
      • s:隐藏层的输出结果
      • v:隐藏层到输出层的权重
      • w:上一次的值S_t-1作为这一次输入的权重矩阵
      • 关于数学计算公式
  • 双向循环神经网络
    • 单项循环神经网络只能作为与前一个数据建立连接,双向循环神经网络则可以同后一个数据建立连接
    • 当然这个隐藏层只有一层,也可以多加几层构成深度循环神经网络

RNN实例

  • 关于实现RNN的实例,我觉得有一个比较简单但是又比较符合RNN使用场景的序列数据例子,那就是正弦和余弦函数。
    • 该例子来自参考资料1
    • 以sin函数值作为输入,其对应的cos函数值作为输出,在相同sin值的情况下会对应不同的cos值的情况,这就是因为输出结果不仅要看输入数据,还要依赖前后值的信息,且FC,CNN就不适合该例子了。

Pytorch中RNN函数torch.nn.RNN()参数介绍

  • 参数其实主要写input_size和hidden_size,其他的参数使用默认的即可,当然有特殊需要再设置

  • 其重要参数解析如下

    参数含义
    input_size输入RNN的维度/输入x的特征数量
    hidden_size隐藏层节点数量/隐藏层的特征数量
    num_layersRNN的层数
    nonlinearity指定激活函数使用tanh还是relu。默认是tanh
    bias如果是 False , 那么 RNN 层就不会使用偏置权重 b_ih 和 b_hh, 默认: True
    batch_first如果 True, 那么输入 Tensor 的 shape 应该是 (batch, seq, feature),并且输出也是一样(详见参3)
    dropout如果值非零, 那么除了最后一层外, 其它层的输出都会套上一个 dropout 层
    bidirectional如果 True , 将会变成一个双向 RNN, 默认为 False
  • 主要是介绍了节点数的一些细节 参考资料2

  • 主要介绍了数据维度的一些细节,介绍的挺详细的,可惜我还是没看懂,以后看懂了补上参考资料3

  • 贴两个询问ChatGPT的截图,作为比对
    ChatGPT3.5
    在这里ChatGPT4插入图片描述

RNN代码分析

  • 具体print输出结果以及打印出来的结果见文章末尾的测试文件资源)
  • 测试文件中还包含几个函数测试的结果,用来辅助分析代码
  • 注:个人分析的注释可能也有错误,仅供参考,且数据维度的变换那里目前理解能力有限,思考了很久还是一知半解,待后续有能力完全解析再做补充。
  • 该代码是他人写好的代码,并不是本人的实现代码,主要是确实个人能力不足目前难以自己写出来…
# encoding:utf-8
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch import nn

# 定义RNN模型(可以类别下方RNN简单测试代码理解)
class Rnn(nn.Module):
    def __init__(self, input_size):
        super(Rnn, self).__init__()
        # 定义RNN网络
        ## hidden_size是自己设置的,貌似取值都是32,64,128这样来取值
        ## num_layers是隐藏层数量,超过2层那就是深度循环神经网络了
        self.rnn = nn.RNN(
                input_size=input_size,
                hidden_size=32,
                num_layers=1,
                batch_first=True  # 输入形状为[批量大小, 数据序列长度, 特征维度]
                )
        # 定义全连接层
        self.out = nn.Linear(32, 1)

    # 定义前向传播函数
    def forward(self, x, h_0):
        r_out, h_n = self.rnn(x, h_0)
        # print("数据输出结果;隐藏层数据结果", r_out, h_n)
        # print("r_out.size(), h_n.size()", r_out.size(), h_n.size())
        outs = []
        # r_out.size=[1,10,32]即将一个长度为10的序列的每个元素都映射到隐藏层上
        for time in range(r_out.size(1)):  
            # print("映射", r_out[:, time, :])
            # 依次抽取序列中每个单词,将之通过全连接层并输出.r_out[:, 0, :].size()=[1,32] -> [1,1]
            outs.append(self.out(r_out[:, time, :])) 
            # print("outs", outs)
        # stack函数在dim=1上叠加:10*[1,1] -> [1,10,1] 同时h_n已经被更新
        return torch.stack(outs, dim=1), h_n 

TIME_STEP = 10
INPUT_SIZE = 1
LR = 0.02
model = Rnn(INPUT_SIZE)
print(model)

# 此处使用的是均方误差损失
loss_func = nn.MSELoss()  
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

h_state = None  # 初始化h_state为None

for step in range(300):
    # 人工生成输入和输出,输入x.size=[1,10,1],输出y.size=[1,10,1]
    start, end = step * np.pi, (step + 1)*np.pi
    # np.linspace生成一个指定大小,指定数据区间的均匀分布序列,TIME_STEP是生成数量
    steps = np.linspace(start, end, TIME_STEP, dtype=np.float32) 
    # print("steps", steps)
    x_np = np.sin(steps)
    y_np = np.cos(steps)
    # print("x_np,y_np", x_np, y_np)
    # 从numpy.ndarray创建一个张量 np.newaxis增加新的维度
    x = torch.from_numpy(x_np[np.newaxis, :, np.newaxis])
    y = torch.from_numpy(y_np[np.newaxis, :, np.newaxis])
    # print("x,y", x,y)

    # 将x通过网络,长度为10的序列通过网络得到最终隐藏层状态h_state和长度为10的输出prediction:[1,10,1]
    prediction, h_state = model(x, h_state)
    h_state = h_state.data  
    # 这一步只取了h_state.data.因为h_state包含.data和.grad 舍弃了梯度
    # print("precision, h_state.data", prediction, h_state)
    # print("prediction.size(), h_state.size()", prediction.size(), h_state.size())
    
    # 反向传播
    loss = loss_func(prediction, y)
    optimizer.zero_grad()
    loss.backward()
    # 更新优化器参数
    optimizer.step()

# 对最后一次的结果作图查看网络的预测效果
plt.plot(steps, y_np.flatten(), 'r-')
plt.plot(steps, prediction.data.numpy().flatten(), 'b-')
plt.show()

另外一种构建RNN的方法

  • 除了上面介绍的torch.nn.RNN()外,还有RNNCell方法,使用方法如下:
  • cell = torch.nn.RNNCell(input_size=input_size, hidden_size=hidden_size)
  • hidden = cell(input, hidden)
    • input_size与hidden_size与上面的参数介绍的含义是一致的,但是输入类型不一样。
    • 和RNN不同的是RNN cell要自己写处理序列的循环,个人通俗理解就是,比如要处理3个句子,每个句子10个单词,每个单词用20长度的向量表示,如果使用nn.RNN(),那输出的tensor的shape应该是[10,3,20],而使用nn.RNNCell()需要将序列上的每个时刻分开处理,即送入的tensor的shape是[3,100],然后将该单元运行10次,灵活的代价当然就是比较麻烦。
    • 关于hidden_size输入shape与input_size同样有变化,概括来说就是(可以看上面代码的具体解析那有关于维度的详细分析)
    • input:batch_size*input_size
    • 输入的hidden: batch_size*hidden_size
    • 输出的hidden: batch_size*hidden_size

其他参考资料

  • 数据集为MNIST使用RNN实现(未测试,找资料时候看到了仅供参考)

M1

M2

  • RNN两种实现方式的区别以及代码实现比较

RR1

RR2

  • 本文代码详细测试过程文件(待上传)
    • Github代码文件资源

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/646035.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

迪赛智慧数——饼图(环形饼图):哪个年龄段的人最爱存钱?

效果图 50岁到60岁是存钱黄金10年,你存下钱了吗? 据央行发布的2022年金融统计数据报告显示,全年人民币存款额增加26.26万亿元,其中住户存款增加17.84万亿,刷新历史记录。 2022年在全国2200名40岁以下的人中,90后这一职场主力军…

电压放大器在无损检测中的作用和应用有哪些

电压放大器在无损检测中扮演着重要的角色,可以帮助实现信号的放大和增强,从而提高检测的灵敏度和准确性。下面,我们将详细探讨电压放大器在无损检测中的作用和应用。 图:ATA-2000系列高压放大器 电压放大器是一种用于放大电压信号…

Flink SQL之Interval Joins

1.Interval Joins(区间Join) 区间是双流join的优化,基于处理时间或事件时间,在一定时间区间内数据,相同的key进行join(支持 Batch\Streaming)。Interval Join 可以让一条流去 Join 另一条流中前…

BFT 最前线|北京智源发布悟道3.0大模型;马克龙会见Meta谷歌人工智能专家;马斯克:特斯拉市值未来将超过苹果与沙特阿美总和

文 | BFT机器人 AI视界 TECHNOLOGY NEWS 01 天垓100完成百亿级参数大模型训练 在第五届智源大会AI系统分论坛上,上海天数智芯半导体有限公司对外宣布,在天垓100加速卡的算力集群,基于北京智源人工智能研究院70亿参数的Aquila语言基础模型&am…

flink + Atlas 任务数据血缘调通

据此修改 Flink 源码 版本Flink1.13.5Atlas1.2.0 将 atlas 配置文件打进 flink-bridge;atlas 相关的 jar 放进 flink/lib jar uf flink-bridge-1.2.0.jar atlas-application.properties flink-conf.yaml 注册监听 org.apache.flink.configuration.ExecutionOpti…

6月第2周榜单丨飞瓜数据B站UP主排行榜(哔哩哔哩)发布!

飞瓜轻数发布2023年6月5日-6月11日飞瓜数据UP主排行榜(B站平台),通过充电数、涨粉数、成长指数三个维度来体现UP主账号成长的情况,为用户提供B站号综合价值的数据参考,根据UP主成长情况用户能够快速找到运营能力强的B站…

51、C++ 学习笔记

1、引用类型 引用类型是C引入的新类型,根据汇编的知识进行理解,程序在汇编后,变量名将失去意义,因为汇编码将替换成用内存的(链接地址or运行地址)访问变量。在C/C语言中,用变量名表示变量所占的那块内存,为…

仓储管理小程序开发 实现不同行业不同规模的仓管需求

在电子商务快速发展的时代,仓库管理对于一个企业的经营发展来说至关重要。如今互联网技术深入发展,很多企业都开发了信息化管理系统,仓库管理APP小程序就是企业结合自身的运算法则开发的一款线上应用软件,通过智能智慧仓库内人、物…

网络安全是一个好的专业吗?高考之后怎么选择?

目录 一.始于大学 二.一路成长 三. 如何学习网络安全 学前感言 零基础入门 尾言 本人信息安全专业毕业,在甲方互联网大厂安全部与安全乙方大厂都工作过,有一些经验可以供对安全行业感兴趣的人参考。 或许是因为韩商言让更多人知道了CTF&#xff0…

linuxOPS基础_LAMP开源项目实战

LAMP概述 LAMP:Linux Apache MySQL PHP LAMP 架构(组合) LNMP:Linux Nginx MySQL php-fpm LNMP 架构(组合) LNMPA:Linux Nginx(80) MySQL PHP Apache Nginx 代理方式 Apache&#…

Markdown编辑器使用

这里写自定义目录标题 欢迎使用Markdown编辑器新的改变功能快捷键合理的创建标题,有助于目录的生成如何改变文本的样式插入链接与图片如何插入一段漂亮的代码片生成一个适合你的列表创建一个表格设定内容居中、居左、居右SmartyPants 创建一个自定义列表如何创建一个…

这所Top3顶尖院校,专业课太简单了,比双非还要简单!

一、学校及专业介绍 复旦大学(FDU,简称旦旦),除清北之外的顶尖学府,想必不用我过多介绍,Top3之一(众所周知,Top3有好多所图片,但我心目中的Top3永远是上海交大图片&#…

element-ui中表头添加自定义按钮以及其他自定义展示

可以使用&#xff1a;render-header方法即可 添加一个按钮如下&#xff1a; renderHeader (h) { return ( <div> <span>操作</span> <el-button type"primary" style"margin-left:90px" size"small" icon"el-icon-pl…

在测试外包干了4年,我废了...

外包公司值不值得去&#xff0c;是很多同行关心的话题。在职场一直流传着“外包不被当人看”“外包没有归属感”的言论。 客观来看&#xff0c;外包岗位确实存在一些缺点&#xff0c;比如&#xff1a;公积金&#xff0c;社保缴纳基数低&#xff0c;没有稳定的涨薪通道&#xff…

登录时token的存储

1.token是什么&#xff1f; 是一种身份的标识,比如我们入住一家酒店,他会给我们一张房卡,房卡的期限是有时间限制的,只有持有房卡的人才能入住酒店。 2.jsCookie 使用的方法 下包: npm i jscookie 导入: import Cookiejs from "js-cookie"; 使用: Cookie.js.set…

object类clone、finalize

2 什么是API API&#xff08;Application Programming Interface&#xff0c;应用程序接口&#xff09;是一些预先定义的函数。目的是提供应用程序与开发人员基于某软件可以访问的一些功能集&#xff0c;但又无需访问源码或理解内部工作机制的细节. API是一种通用功能集,有时公…

HTB-OnlyForYou

HTB-OnlyForYou 信息收集立足johnjohn -> root 信息收集 Designed by BootstrapMade. 在他们的TEAM的常见问答里面发现了一个beta产品。 网站首页可以下载疑似源码的文件。 右上角还有两个功能。 一个是上传图片并调整大小。 上传了文件后会跳转到list&#xff0c;选择…

【CV大模型SAM(Segment-Anything)】如何一键分割图片中所有对象?并对不同分割对象进行保存?

之前的文章【CV大模型SAM&#xff08;Segment-Anything&#xff09;】真是太强大了&#xff0c;分割一切的SAM大模型使用方法:可通过不同的提示得到想要的分割目标,中详细介绍了大模型SAM&#xff08;Segment-Anything&#xff09;根据不同的提示方式得到不同的目标分割结果。 …

11. 100ASK-V853-PRO开发板 RGB屏测试指南

100ASK-V853-PRO开发板 RGB屏测试指南 硬件要求&#xff1a; 100ASK-V853-PRO开发板七寸RGB屏 软件要求&#xff1a; 固件下载地址&#xff1a;链接&#xff1a;百度网盘 提取码&#xff1a;sp6a 固件位于资料光盘中的10_测试镜像/1.测试七寸RGB屏/v853_linux_100ask_uart0.…

echarts中国地图使用整理

一、echarts中国地图使用案例 1.准备地图数据china.json ; 需要的添加微信&#xff1a;tianma104&#xff0c;我发你 2.引入jquery&#xff0c;引入eachars 库 <script src"http://xx/ajax/libs/jquery/3.5.1/jquery.min.js"></script> <script s…