【Python深度学习】RNN循环神经网络结构讲解及序列回归问题实战(图文解释 附源码)

news2025/1/19 23:14:02

需要全部代码请点赞关注收藏后评论区留言私信~~~

循环神经网络

循环神经网络(Recurrent Neural Network,RNN)是用于对序列的非线性特征进行学习的深度神经网络。循环神经网络的输入是有前后关联关系的序列。

循环神经网络可以用来解决与序列有关的问题,如序列回归、序列分类和序列标注等任务。序列的回归问题,如气温、股票价格的预测问题,它的输入是前几个气温、股票价格的值,输出的是连续的预测值。序列的分类问题,如影评的正负面分类、垃圾邮件的检测,它的输入是影评和邮件的文本,输出的是预定的有限的离散的标签值。序列的标注问题,如自然语言处理中的中文分词和词性标注,循环神经网络可处理传统机器学习中的隐马尔可夫模型、条件随机场等模型胜任的标注任务。

类似隐马尔可夫链,把循环神经网络基本结构的中间部分称为隐层,向量s标记了隐层的状态。隐层的输出有两个,一个是y,另一个反馈到自身。到自身的反馈将与下一步的输入共同改变隐层的状态s。因此,隐层的输入也有两个,分别是当前输入x和来自自身的反馈(首步没有来自自身的反馈)。

 输入样本的观测序列有两个分量x^(1),x^(2),即每次输入的步长数为2。观测序列的分量是3维的向量。隐状态是一个2维的向量s。输出是1维的标量,分别是y^(1),y^(2)。

TensorFlow2中Keras的SimpleRNN的类原型

tf.keras.layers.SimpleRNNCell(
    units, activation='tanh', use_bias=True,
    kernel_initializer='glorot_uniform',
    recurrent_initializer='orthogonal',
    bias_initializer='zeros', kernel_regularizer=None,
    recurrent_regularizer=None, bias_regularizer=None, kernel_constraint=None,
    recurrent_constraint=None, bias_constraint=None, dropout=0.0,
    recurrent_dropout=0.0, **kwargs
)

参数units设定该单元的状态向量s的维数。参数use_bias设定是否使用阈值参数θ。

用SimpleRnnCell来模拟循环神经网络前向传播

import tensorflow as tf
 
# (批大小, 步长数, 序列分量维数)
batch_size = 1
time_step = 2
step_dim = 3
hidden_dim = 2  # 隐状态维度
 
s0 = tf.constant([[0.0, 0.0]]) # 第1步输入的隐状态
x1 = tf.constant([[0.1, 0.2, 0.3]]) # 第1步输入的序列分量
simpleRnnCell = tf.keras.layers.SimpleRNNCell(hidden_dim , use_bias=False)
out1,s1 = simpleRnnCell(x1, [s0]) # 将当前步的x和上一步的隐状态输入到单元中,产生第1步的输出和隐状态
print("out1:", out1)
print("s1:", s1)
>>> out1: tf.Tensor([[-0.05700448  0.2253606 ]], shape=(1, 2), dtype=float32)
     s1: [<tf.Tensor: id=53, shape=(1, 2), dtype=float32, numpy=array([[-0.05700448,  0.2253606 ]], dtype=float32)>]
x2 = tf.constant([[0.2, 0.3, 0.4]]) # 第2步输入的序列分量
out2,s2 = simpleRnnCell(x2, [s1[0]]) # 将当前步的x和上一步的隐状态输入到单元中,产生第2步的输出和隐状态
print("out2:", out2)
print("s2:", s2)
>>> out2: tf.Tensor([[-0.198356    0.54249984]], shape=(1, 2), dtype=float32)
     s2: [<tf.Tensor: id=62, shape=(1, 2), dtype=float32, numpy=array([[-0.198356  ,  0.54249984]], dtype=float32)>]

网络结构

 one to many结构是单输入多输出的结构,可用于输入图片给出文字说明。many to one结构是多输入单输出的结构,可用于文本分类任务,如影评情感分类、垃圾邮件分类等。many to many delay结构也是多输入多输出的结构,但它是有延迟的输出,该结构常用于机器翻译,机器问答等。

序列回归问题实战

该示例是对三角函数的值进行预测,先对sin三角函数值顺序采点,然后用一段值序列来预测紧接的第1个值。

基本结构采用了TensorFlow中Keras的SimpleRNN,它实现了RNN基本单元。它的输入有两个重要的参数:units和input_shape。units是设定该单元的状态向量s的维数,它的大小决定了W矩阵的维度。input_shape设定了输入的序列的长度和每个序列元素的特征数,每个序列元素的特征数和units共同决定了U矩阵的维度。

输入序列的长度决定了SimpleRNN的循环步数,在最后一步,将状态向量s输出到一个全连接层,该连接层输出为1维的预测值,因此V矩阵的维度是units×1。

预测结果如下

 

 部分代码如下

import numpy as np
np.random.seed(0)

def myfun(x):
    '''目标函float):自变量
    output:函数值'''
    return np.sin(x)

x = np.linspace(0,15, 150)
y = myfun(x) + 1 + np.random.random(size=len(x)) * 0.3 - 0.15

input_len = 10

train_x = []
train_y = []
for i in range(len(y)-input_len):
    train_data = []
    for j in range(input_len):
        train_data.append([y[i+j]])
    train_x.append(train_data)
    train_y.append((y[i+input_len]))
import tensorflow as tf

model = tf.keras.Sequential()
model.add(tf.keras.layers.SimpleRNN(100, return_sequences=False, 
                    activation='relu',
                    input_shape=(input_len, 1)))
model.add(tf.keras.layers.Dense(1))
model.add(tf.keras.layers.Activation("relu"))
model.compile(lopochs=10, batch_size=10, verbose=1)

import matplotlib.pyplot as plt
plt.rcParams['axes.unicode_minus']=False
plt.rc('font', family='SimHei', size=13)
#plt.scatter(x, y, color="black", linewidth=1)
y0 = myfun(x) + 1
plt.plot(x, y0, color="red", linewidth=1)
y1 = model.predict(train_x)
plt.plot(x[input_len:], y1, "b--", linewidth=1)
plt.show()

创作不易 觉得有帮助请点赞关注收藏~~~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/106124.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

python数据预处理时缺失值的不同处理方式!

在使用python做数据分析的时候&#xff0c;经常需要先对数据做统一化的处理&#xff0c;缺失值的处理是经常会使用到的。 一般情况下&#xff0c;缺失值的处理要么是删除缺失数据所在的行&#xff0c;要么就是对缺失的单元格数据进行填充。 今天介绍的是使用差补法/均值/固定…

CSS -- 08. 移动WEB开发之流式布局

文章目录移动WEB开发之流式布局1 移动端基础1.1 浏览器现状1.2 手机屏幕现状1.3 移动端调试方法2 视口2.1 布局视口 layout viewport2.2 视觉视口 visual viewport2.3 理想视口 ideal viewport2.4 meta视口标签3 二倍图3.1 物理像素&物理像素比3.2 多倍图3.3 背景缩放 back…

clickhouse-部署指南(亲测超详细)

文章目录环境要求测试环境Tar方式安装clickhouse用户授权访问控制服务配置启动连接测试TABIX webUI性能测试环境要求 ClickHouse可以在任何具有x86_64&#xff0c;AArch64或PowerPC64LE CPU架构的Linux&#xff0c;FreeBSD或Mac OS X上运行 官方预构建的二进制文件通常针对x86…

prometheus+grafana对数据库mysql监控

安装 mysql docker run --name mysql-test -p MYSQL ROOT_PASSWORD123456 -p23306:3306 mysql:5.7.26启动镜像的时候 已经是把docker容器内部的3306端口 映射到本机了 直接通过navicat连上 进入mysql容器内部 docker exec -it 容器id /bin/bash, 启动 mysql. docker start d…

33.flink cdc 实时数据同步利器

什么是flink cdc? 对很多初入门的人来说是无法理解cdc到底是什么个东西。 有这样一个需求&#xff0c;比如在mysql数据库中存在很多数据&#xff0c;但是公司要把mysql中的数据同步到数据仓库(starrocks), 数据仓库你可以理解为存储了各种各样来自不同数据库中表。 数据的同步…

2-2-3-7、FutureCompletableFuture详解

Runnable 线程的任务接口&#xff0c;用于定义被执行任务方法的抽象&#xff0c;是函数式接口&#xff08;仅存在一个需要实现方法的接口&#xff09;&#xff0c;其方法为run方法通过对并发编程中java线程的了解&#xff0c;Thread调用start方法&#xff0c;最后操作系统会通过…

Confluence 调整会话超时(session timeout)

文章目录前言一、概括二、实际场景应用1.更改空闲超时2.更改记住我 cookie 的生命周期3.在用户通过身份验证后的某个时间强制注销用户总结前言 在 Confluence 中有两个会话 Cookie&#xff1a; JSESSIONID: 由 Tomcat 使用和管理。 默认情况下&#xff0c;这被视为会话 cooki…

类与对象(中)

类与对象类的6个默认成员函数构造函数概念特性析构函数概念特性拷贝构造函数概念特性赋值运算符重载运算符重载赋值运算符重载前置 后置 重载const成员函数取地址及const取地址操作符重载类的6个默认成员函数 当类中没有任何成员时&#xff0c;称作空类 但是呢&#xff0c;编译…

Docker使用(容器、镜像相关命令)

虚拟化 在计算机中&#xff0c;虚拟化&#xff08;英语&#xff1a;Virtualization&#xff09;是一种资源管理技术&#xff0c;是将计算机的各种实体资源&#xff0c;如服务器、网络、内存及存储等&#xff0c;予以抽象、转换后呈现出来&#xff0c;打破实体结构间的不可切割…

虹科方案|将以太网连接添加到Dell EMC PowerVault™ ML3 SAS库

一、Dell EMC 和 ATTO 磁带解决方案 Dell EMC 和 ATTO 提供了业界唯一的商用解决方案&#xff0c;可将高速以太网连接添加 到标准 SAS LTO 磁带驱动器。ATTO XstreamCORE ET 8200 智能网桥允许您使用 iSCSI 和 iSER 协议通过标准以太网远程连接到 SAS 磁带驱动器。当与采用最新…

花 2 个月备战字节跳动Java岗,3 轮面试拿下 60W Offer

最近收到一位刚入职字节的 Java 工程师朋友投稿——以下内容来自其亲身经历&#xff0c;某双非硕士拿到 字节 60W offer &#xff0c;感谢他的走心分享&#xff08;文末附硬货笔记&#xff09; PART1&#xff1a;个人情况简介 菜 J 一枚&#xff0c;本硕都是计算机&#xff08…

[论文阅读] SqueezeSeg V1

文章目录1. 主要思想2. 具体方法2.1 数据处理方式2.2 网络架构3. 实验支撑4. 总结启示5. 相关文献paper 原论文的链接 code: 源代码链接 paper全称&#xff1a;SqueezeSeg: Convolutional Neural Nets with Recurrent CRF for Real-Time Road-Object Segmentation from 3D LiDA…

【02】FreeRTOS获取10.4.6源码+移植到STM32F407步骤

目录 1.获取FreeRTOS源码 1.1 FreeRTOS官网下载步骤 1.2FreeRTOS源码内容 1.3FreeRTOS内核文件 1.3.1Demo文件夹 1.3.2Source文件夹 2.FreeRTOS移植 2.1添加FreeRTOS源码 2.1.1复制FreeRTOS源码 2.1.2将文件添加到工程 2.1.3添加头文件路径 2.2添加FreeRTOS.h 2.3修改SYS…

vpp process类型节点调度过程

vpp节点类型 VLIB_NODE_TYPE_PROCESS&#xff1a;process类型节点可以被挂起也可以被恢复&#xff0c;main线程上调度 &#xff08;免费订阅,永久学习&#xff09;学习地址: Dpdk/网络协议栈/vpp/OvS/DDos/NFV/虚拟化/高性能专家-学习视频教程-腾讯课堂 process节点注册 pro…

【MC】新加载器 Quilt 好用吗?和 Fabric 相比好在哪?

在今年四月 (2022/4/20) &#xff0c;一个船新加载器 Quilt 发布了第一个测试版。 Quilt officially entered its first beta today, attracting an influx of new users and an amazing amount of support and positive feedback. By the end of the day, Quilt was happily l…

Go语言设计与实现 -- 字符串

Go语言的字符串与Java和python是一样的。具有不可变性。是一个只读的字节数组&#xff0c;如图所示。 因为Go的字符串具有不可变性&#xff0c;所以我们只能通过string和[]byte类型之间反复转换实现修改。 将这一段内存复制到栈上将变量的类型转换成[]byte后并修改字节数据将修…

功能上新|使用 Excel 低门槛进行指标分析!

Kyligence Zen 功能上新啦&#xff01;用户不仅可以在 Kyligence Zen 中定义、分析和管理指标&#xff0c;还可直接使用 Excel 插件来分析 Kyligence Zen 中已经定义好的指标&#xff0c;学习无门槛&#xff0c;上手更轻松&#xff01;欢迎访问 http://zen.kyligence.io 申请免…

实验二A 图像的空域(源代码一站式复制粘贴)

实验二A 图像的空域一、实验目的二、实验原理三、实验内容与要求四、实验的具体实现一、实验目的 1.掌握图像滤波的基本定义及目的。 2.理解空间域滤波的基本原理及方法。 3.掌握进行图像的空域滤波的方法。 二、实验原理 1.空域增强 空域滤波是在图像空间中借助模板对图像进…

阳哥JUC并发编程之AQS后篇全网最详细源码笔记

文章目录AQS后序课程笔记AQS源码ReentryLock锁的原理分析公平锁以及非公平锁源码详解Aquire方法调用原码流程分析第一步、tryAquire第二步、addwrite第三步&#xff1a;aquireQueuedAQS释放锁的过程第一步、释放锁第二步进入aquireQueueAQS异常情况下走Cancel流程分析第一种队尾…

ECharts项目实战:全球GDP数据可视化

【课程简介】 可视化是前端里一个几乎可以不用写网页&#xff0c;但又发展得非常好的方向。在互联网产品里&#xff0c;无论是C端中常见的双十一购物节可视化大屏&#xff0c;还是B端的企业中后台管理系统都离不开可视化。国家大力推动的智慧城市、智慧社区中也有很多可视化的…