深度学习 RNN循环神经网络原理与Pytorch正余弦值预测

news2024/12/25 22:33:26

深度学习 RNN循环神经网络原理与Pytorch正余弦值预测

  • 一、前言
  • 二、序列模型
  • 三、不含序列关联的神经网络
  • 四、包含隐藏状态的卷积神经网络
  • 五、正余弦预测实战
  • 六、参考资料

一、前言

前面我们学习了前馈神经网络、卷积神经网络,它们有一个特点,就是每次输出跟上一次结果没有关联。但在一个句子中,每个词的顺序搭配是存在一定联系的,这个时候我们就需要考虑上一次提取的特征对本次输出的影响。这就是我们今天要学的循环神经网络(RNN),也叫递归神经网络,RNN被广泛地应用于自然语言处理(NLP)等领域。

二、序列模型

我们来看一个例子:

我昨天上学迟到了,老师批评了____。

空格里这个词最有可能是『我』,而不太可能是『小明』,甚至是『吃饭』。

这是由上下文推导出来的,这种输出与上下文相互关联的模型,叫做序列模型。

序列模型能够应用在许多领域,例如:

  • 语音识别
  • 音乐发生器
  • 情感分类
  • DNA序列分析
  • 机器翻译
  • 视频动作识别
  • 命名实体识别

三、不含序列关联的神经网络

在这里插入图片描述
为简化描述,我们不考虑偏置 b b b,如上图所示是包含一个隐藏层的神经网络。 X X X表示输入、 O O O表示输出; U U U是输入层到隐藏层的权重矩阵, V V V是隐藏层到输出层的权重矩阵。

设隐藏层的激活函数为 f f f、输出层的激活函数为 g g g,则有:
H = f ( U X ) O = g ( V H ) H=f(UX) \\O=g(VH) H=f(UX)O=g(VH)

四、包含隐藏状态的卷积神经网络

隐藏层的作用,其实就是对输入进行特征值提取,比如卷积神经网络中的卷积层就是对图像边缘的提取。如果说上一次的特征,会对本次特征提取造成一定影响,那怎么表示呢?

我们引入权重参数 W W W H t − 1 H_{t-1} Ht1表示上次特征,用 W H t − 1 WH_{t-1} WHt1表示上次特征对本次的影响程度。那么就有本次特征 H t = f ( U X t + W H t − 1 ) H_t=f(UX_t+WH_{t-1}) Ht=f(UXt+WHt1)
本次特征的值不仅取决于本次输入 X t X_t Xt,还受上次特征 H t − 1 H_{t-1} Ht1的影响。

这就是RNN的算法思想,用下图表示:
在这里插入图片描述

五、正余弦预测实战

import torch
import torch.nn as nn
import numpy as np
np.set_printoptions(suppress=True) #numpy不使用科学计数法

steps=1000   #迭代次数
learning_rate=0.01  #学习率
time_step=10    #步数大小
input_size=1    #输入特征数量
hidden_size=32  #隐藏层特征数量


class MyModel(nn.Module):
    
    def __init__(self):
        super().__init__()

        self.rnn=nn.RNN(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True
        )
        self.out=nn.Linear(hidden_size, 1)

    def forward(self,x,h_state):
        r_out,h_state=self.rnn(x,h_state)

        outs = []
        for time_step in range(r_out.size(1)):    # 计算每个时间步的输出
            outs.append(self.out(r_out[:, time_step, :]))
        
        return torch.stack(outs, dim=1), h_state

plt_steps=[]
plt_loss=[]

h_state = None

model=MyModel()
#损失函数
cost=nn.MSELoss()
#迭代优化器
optmizer=torch.optim.SGD(model.parameters(),lr=learning_rate)

step_now,step_x,sin_y,cos_y=None,None,None,None 
for step in range(steps):
    step_now=step
    step_x=np.linspace(step*np.pi,(step+1)*np.pi,time_step,dtype=np.float32) #起始值、结束值、个数
    sin_y=np.sin(step_x)
    cos_y=np.cos(step_x)

    x = torch.from_numpy(sin_y[np.newaxis, :, np.newaxis])    # shape (batch, time_step, input_size)
    y = torch.from_numpy(cos_y[np.newaxis, :, np.newaxis])

    pre_y,h_state=model(x,h_state)

    h_state = h_state.data

    #计算损失值
    loss=cost(pre_y,y)

    #在反向传播前先把梯度清零
    optmizer.zero_grad()

    #反向传播,计算各参数对于损失loss的梯度
    loss.backward()

    #根据刚刚反向传播得到的梯度更新模型参数
    optmizer.step()

    plt_steps.append(step)
    plt_loss.append(loss.item())

    #打印损失值
    if step%100==0:
        print('step:',step,'loss:',loss.item())
    
#绘制迭代次数与损失函数的关系
import matplotlib.pyplot as plt
plt.plot(plt_steps,plt_loss)

运行结果:

step: 0 loss: 0.5253313779830933
step: 100 loss: 0.1194605678319931
step: 200 loss: 0.0004494489112403244
step: 300 loss: 0.0004530779551714659
step: 400 loss: 0.00045654349378310144
step: 500 loss: 0.00045824996777810156
step: 600 loss: 0.00045904534636065364
step: 700 loss: 0.0004583548288792372
step: 800 loss: 0.00045726861571893096
step: 900 loss: 0.00045428838348016143

在这里插入图片描述
预测下一段数据结果:

step_x=np.linspace((step_now+1)*np.pi,(step_now+2)*np.pi,time_step,dtype=np.float32) #起始值、结束值、个数
sin_y=np.sin(step_x)
cos_y=np.cos(step_x)

x = torch.from_numpy(sin_y[np.newaxis, :, np.newaxis])    # shape (batch, time_step, input_size)
y = torch.from_numpy(cos_y[np.newaxis, :, np.newaxis])

pre_y,h_state=model(x,h_state)

plt.plot(step_x,sin_y,label='input (sin)')
plt.plot(step_x,cos_y,label='target (cos)')
plt.plot(step_x,pre_y.data.numpy().flatten(),label='pre_y')
plt.legend() #展示标签
plt.show()

运行结果:
在这里插入图片描述

六、参考资料

《零基础入门深度学习(5) - 循环神经网络》
《深度学习(五) - 序列模型》
《一文搞懂RNN(循环神经网络)基础篇》
《【Pytorch教程】:RNN 循环神经网络 (回归)》

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/62412.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

HTML旅游景点网页作业制作——旅游中国11个页面(HTML+CSS+JavaScript)

👨‍🎓学生HTML静态网页基础水平制作👩‍🎓,页面排版干净简洁。使用HTMLCSS页面布局设计,web大学生网页设计作业源码,这是一个不错的旅游网页制作,画面精明,排版整洁,内容…

精品基于Javaweb的酒店民宿管理推荐平台SSM

《基于Javaweb的酒店民宿管理推荐平台》该项目含有源码、论文等资料、配套开发软件、软件安装教程、项目发布教程等 使用技术: 开发语言:Java 框架:ssm 技术:JSP JDK版本:JDK1.8 服务器:tomcat7 数据…

老司机发车了,CountDownLatch:等与不等都在你

哈喽大家好,我是阿Q。 前几天我们把 ReentrantLock的原理 进行了详细的讲解,不熟悉的同学可以翻看前文,今天我们介绍另一种基于 AQS 的同步工具——CountDownLatch。 CountDownLatch 被称为倒计时器,也叫闭锁,是 juc…

[论文精读|顶刊论文]Relational Triple Extraction: One Step is Enough

2022.5.11 |IJCAI-2022|华中科技大学|2022年SOTA| 原文链接 Relational Triple Extraction: One Step is Enough 过去的步骤: 寻找头尾实体的边界位置(实体识别)将特定令牌串联成三元组&…

[附源码]Python计算机毕业设计Django区域医疗服务监管可视化系统

项目运行 环境配置: Pychram社区版 python3.7.7 Mysql5.7 HBuilderXlist pipNavicat11Djangonodejs。 项目技术: django python Vue 等等组成,B/S模式 pychram管理等等。 环境需要 1.运行环境:最好是python3.7.7,…

分享107个小清新,总有一款适合您

PPT链接:https://pan.baidu.com/s/1WqaR_29avEgq46iTSLKfmw?pwd5r81 提取码:5r81 源码下载链接:ppt.rar - 蓝奏云 采集的参数 page_count 1 # 每个栏目开始业务content"text/html; charsetgb2312"base_url "https://sc…

Python可视化招聘信息聚合系统 (附源码)!

前言 基于数据技术的互联网行业招聘信息聚合系统,本系统以Python为核心,依托web展示,所有功能在网页就可以完成操作,爬虫、分析、可视化、互动独立成模块,互通有无。 依托python的丰富库实现,爬虫使用Req…

详解设计模式:备忘录模式

详解设计模式:备忘录模式 备忘录模式(Memento Pattern)也被称为快照模式(Snapshot Pattern)、Token 模式(Token Pattern),是在 GoF 23 种设计模式中定义了的行为型模式。 备忘录模式…

阿里云存储解决方案,助力轻舟智航“将无人驾驶带进现实”

轻舟智航介绍 轻舟智航是一家以“将无人驾驶带进现实”为使命的自动驾驶通用解决方案公司,依赖双擎战略,一方面主张以高性价比的前装量产方案,致力于打造L4级体验的城市高速NOA方案,满足不同客户不同等级的自动驾驶量产需求。另一…

Lottie 动画导出为 GIF/MP4 以及与 QML 集成演示

获取 Lottie 动画文件 lottiefiles 是一个很好的网站, 从上面可以下载到别人分享的 lottie 动画文件. 我们可以下载到多种格式, 下面分别讲解每个格式的下载和适用情景. 下载 JSON 源文件 这是体积最小的格式, 一般在 10kb ~ 100kb 之间. 考虑到 lottiefiles 的服务器在国外…

锂热电池检测设备 你一定没见过这种检测方式!

项目需求 用户希望纳米Namisoft帮他们设计开发一款系统,要求系统软件安装在PC控制装置上,系统通过使用USB、RS232、LAN通讯接口实现对锂电池测试过程中所用到的仪器(内阻测试仪、扫码枪、触摸显示器和电源模块等)进行软件控制&…

浸没式冷却-散热技术新趋势,一起学Flotherm电子元器件散热仿真

作者:Billy,仿真秀专栏作者 随着电子元器件功率的上升,散热成为技术发展的瓶颈之一。单纯的风冷在一些情况下无法满足散热需求,直接式液冷和间接式液冷因其可以提供更大量级的对流换热系数,带走更多的热量&#xff0c…

基于智能优化算法PSO/GWO/AFO+柔性车间生产调度(Matlab代码实现)

目录 1 柔性车间生产调度 2 运行结果 3 参考文献 4 Matlab代码实现 1 柔性车间生产调度 随着经济全球化的不断加深和市场竞争的日益严峻,传统的单一车间制造模式已经无法满足我国制造业的生产需求,分布式生产制造模式已经成为企业提高生产竞争力的重要手段。由于不同工厂之…

[附源码]计算机毕业设计校友社交系统Springboot程序

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

目标检测算法——3D公共数据集汇总 2(附下载链接)

>>>深度学习Tricks&#xff0c;第一时间送达<<< &#x1f384;&#x1f384;近期&#xff0c;小海带在空闲之余&#xff0c;收集整理了一批3D公共数据集供大家参考。 整理不易&#xff0c;小伙伴们记得一键三连喔&#xff01;&#xff01;&#xff01;&…

微服务自动化【Docker-Compose】

目录 1. docker-compose 2. docker-compose安装与配置 3. docker-compose.yml 配置文件基本介绍 3.1 version: 指定 docker-compose.yml 文件的写法格式 3.2 services&#xff1a;多个容器集合 4. docker-compose 基本指令 5. docker-compose 网络 5.1 指定网络模式 …

不懂业务不清楚指标?这40套可视化大屏模板,让你突破职场天花板

报表可以说是中国职场的一大特色&#xff0c;不少职场人需要每天做各种报表给领导或者业务决策者看&#xff0c;为此甚至诞生了不少的“表哥表姐”。但很多人在做报表的时候其实并不懂业务&#xff0c;需要找业务确定业务指标才做的下去。 今天我就分享40多个报表模板&#xf…

Spark 3.0 - 10.Ml 常用 Sample 采样方法

目录 一.引言 二.数据准备 三.随机采样 Sample 四.按权重拆分 randomSplit 五.分层采样 sampleByKey 六.总结 一.引言 使用 Spark 进行机器学习、数据分析等项目时&#xff0c;常常需要对数据进行采样&#xff0c;下面介绍三种最常用的采样方法&#xff1a; A.随机采样:…

Vue3 学习笔记 —— 自动导入 Vue3 APIs、v-model

目录 1. 自动导入 Vue3 APIs —— unplugin-auto-import/vite 2. v-model 2.1 相较于 Vue2&#xff0c;Vue3 做出了哪些变化&#xff1f; 2.2 绑定一个 v-model 2.2.1 父组件 2.2.2 子组件 2.3 绑定多个 v-model 2.3.1 父组件 2.3.2 子组件 2.4 v-model 中的自定义修…

Android Material Design之ShapeableImageView(十三)

效果图 资源引入 implementation com.google.android.material:material:1.4.0属性 属性描述android:id控件idandroid:layout_width控件长度android:layout_height控件高度app:shapeAppearance控件外观样式app:strokeWidth画笔粗度app:strokeColor画笔颜色android:src图像资源…