RNN循环神经网络

news2024/12/22 16:42:19

目录

一、卷积核与循环核

二、循环核

1.循环核引入

2.循环核:循环核按时间步展开。

3.循环计算层:向输出方向生长。

4.TF描述循环计算层

三、TF描述循环计算

四、RNN使用案例

1.数据集准备

 2.Sequential中RNN

3.存储模型,acc和lose可视化曲线和测试交互窗口

五、Embedding编码

Embedding改进上述案例

六、LSTM

1.循环时间核计算过程​编辑

 2.TF描述LSTM层

3.代码

七、GRU

1.循环时间核计算过程

2.TF描述GRU层

3.代码

总结


一、卷积核与循环核

卷积核:参数空间共享,卷积层提取空间信息。

循环核:参数时间共享,循环层提取时间信息

前向传播:只有ht会变化

反向传播:Wxh、Whh、Why才会变换

二、循环核

1.循环核引入

2.循环核:循环核按时间步展开。

3.循环计算层:向输出方向生长。

4.TF描述循环计算层

# 如何使用
tf.keras.layers.SimpleRNN(记忆体个数,activation=‘激活函数’ ,
return_sequences=是否每个时刻输出ht到下一层)
# 说明
activation=‘激活函数’ (不写,默认使用tanh)
return_sequences=True 各时间步输出ht
return_sequences=False 仅最后时间步输出ht(默认)
例:SimpleRNN(3, return_sequences=True

三、TF描述循环计算

举例说明

 计算预测abcde

四、RNN使用案例

1.数据集准备

input_word = "abcde"
w_to_id = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4}  # 单词映射到数值id的词典
id_to_onehot = {0: [1., 0., 0., 0., 0.], 1: [0., 1., 0., 0., 0.], 2: [0., 0., 1., 0., 0.], 3: [0., 0., 0., 1., 0.],
                4: [0., 0., 0., 0., 1.]}  # id编码为one-hot

x_train = [id_to_onehot[w_to_id['a']], id_to_onehot[w_to_id['b']], id_to_onehot[w_to_id['c']],
           id_to_onehot[w_to_id['d']], id_to_onehot[w_to_id['e']]]
y_train = [w_to_id['b'], w_to_id['c'], w_to_id['d'], w_to_id['e'], w_to_id['a']]
# 打乱顺序
np.random.seed(7)
np.random.shuffle(x_train)
np.random.seed(7)
np.random.shuffle(y_train)
tf.random.set_seed(7)

 2.Sequential中RNN

# Sequential层中的循环序列Rnn,时间序列
model = tf.keras.Sequential([
    SimpleRNN(3), # 存储单元
    Dense(5, activation='softmax')
])

model.compile(optimizer=tf.keras.optimizers.Adam(0.01),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

checkpoint_save_path = "./checkpoint/rnn_onehot_1pre1.ckpt"

if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True,
                                                 monitor='loss')  # 由于fit没有给出测试集,不计算测试集准确率,根据loss,保存最优模型

history = model.fit(x_train, y_train, batch_size=32, epochs=100, callbacks=[cp_callback])

model.summary()

3.存储模型,acc和lose可视化曲线和测试交互窗口

# print(model.trainable_variables)
file = open('./weights.txt', 'w')  # 参数提取
for v in model.trainable_variables:
    file.write(str(v.name) + '\n')
    file.write(str(v.shape) + '\n')
    file.write(str(v.numpy()) + '\n')
file.close()

###############################################    show   ###############################################

# 显示训练集和验证集的acc和loss曲线
acc = history.history['sparse_categorical_accuracy']
loss = history.history['loss']

plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.title('Training Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.title('Training Loss')
plt.legend()
plt.show()

############### predict #############

preNum = int(input("input the number of test alphabet:"))
for i in range(preNum):
    alphabet1 = input("input test alphabet:")
    alphabet = [id_to_onehot[w_to_id[alphabet1]]]
    # 使alphabet符合SimpleRNN输入要求:[送入样本数, 循环核时间展开步数, 每个时间步输入特征个数]。此处验证效果送入了1个样本,送入样本数为1;输入1个字母出结果,所以循环核时间展开步数为1; 表示为独热码有5个输入特征,每个时间步输入特征个数为5
    alphabet = np.reshape(alphabet, (1, 1, 5))
    result = model.predict([alphabet])
    pred = tf.argmax(result, axis=1)
    pred = int(pred)
    tf.print(alphabet1 + '->' + input_word[pred])

五、Embedding编码

目的:每次时间预测都要进行编码,而且编码很麻烦,并且如果简单的话站的内存就比较多,所以就引入Embedding

独热码:数据量大 过于稀疏,映射之间是独立的,没有表现出关联性

Embedding:是一种单词编码方法,用低维向量实现了编码, 这种编码通过神经网络训练优化,能表达出单词间的相关性。

tf.keras.layers.Embedding(词汇表大小,编码维度)

Embedding改进上述案例

# Sequential有变化
model = tf.keras.Sequential([
    Embedding(5, 2),
    SimpleRNN(3),
    Dense(5, activation='softmax')
])

#  输入有变化
     alphabet = [w_to_id[alphabet1]]
    # 使alphabet符合Embedding输入要求:[送入样本数, 循环核时间展开步数]。
    # 此处验证效果送入了1个样本,送入样本数为1;输入1个字母出结果,循环核时间展开步数为1。
    alphabet = np.reshape(alphabet, (1, 1))

六、LSTM

1.循环时间核计算过程

 2.TF描述LSTM层

tf.keras.layers.LSTM(记忆体个数,return_sequences=是否返回输出) return_sequences=True 各时间步输出ht return_sequences=False 仅最后时间步输出ht(默认

3.代码

model = tf.keras.Sequential([
    LSTM(80, return_sequences=True),
    Dropout(0.2),
    LSTM(100),
    Dropout(0.2),
    Dense(1)
])

七、GRU

1.循环时间核计算过程

2.TF描述GRU层

tf.keras.layers.GRU(记忆体个数,return_sequences=是否返回输出) return_sequences=True 各时间步输出ht return_sequences=False 仅最后时间步输出ht(默认

3.代码

model = tf.keras.Sequential([
    GRU(80, return_sequences=True),
    Dropout(0.2),
    GRU(100),
    Dropout(0.2),
    Dense(1)
])

总结

  • 本文主要借鉴:mooc曹健老师的《人工智能实践:Tensorflow笔记》
  • RNN 是最简单的循环神经网络,它的优点是结构简单,易于实现,但是也有缺点,比如梯度消失或爆炸、难以处理长期依赖等。
  • LSTM 是一种改进的 RNN,它的优点是能够避免梯度消失和长期依赖问题,学习更长的序列,但是也有缺点,比如参数较多,计算复杂度高。
  • GRU 是一种简化的 LSTM,它的优点是参数较少,计算速度快,但是也有缺点,比如表达能力可能不如 LSTM 强 。
  • 选择情况:一般来说,LSTM 和 GRU 的表现要优于 RNN

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/948265.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MAC M2芯片执行yolov8 + deepsort 实现目标跟踪

MAC M2芯片执行yolov8 deepsort 实现目标跟踪 MAC M2 YoloX bytetrack实现目标跟踪 实验结果 MAC mps显存太小了跑不动 还是得用服务器跑 需要实验室的服务器跑 因为网上花钱跑4天太贵了!!! 步骤过程尝试: 执行mot17 数据集 …

FSPI的PCB设计

FSPI是一种灵活的串行接口控制器,RK3588芯片中有1个FSPI控制器,可用来连接FSPI设备。 RK3588 FSPI 控制器有如下特点: 1)支持串行NOR Flash,串行Nand Flash; 2)支持SDR模式; 3&am…

Android屏幕显示 android:screenOrientation configChanges 处理配置变更 代码中动态切换横竖屏

显示相关 屏幕朝向 https://developer.android.com/reference/android/content/res/Configuration.html#orientation 具体区别如下: activity.getResources().getConfiguration().orientation获取的是当前设备的实际屏幕方向值,可以动态地根据设备的旋…

【STM32】学习笔记(EXTI)-江科大

EXTI外部中断 中断:在主程序运行过程中,出现了特定的中断触发条件(中断源),使得CPU暂停当前正在运行的程序,转而去处理中断程序,处理完成后又返回原来被暂停的位置继续运行 中断优先级&#x…

【Python数据分析】数据分析之numpy基础

实验环境:建立在Python3的基础之上 numpy提供了一种数据类型,提供了数据分析的运算基础,安装方式 pip install numpy导入numpy到python项目 import numpy as np本文以案例的方式展示numpy的基本语法,没有介绍语法的细枝末节&am…

电商数仓项目需求及架构设计

一、项目需求 1.用户行为数据采集平台搭建 2.业务数据采集平台搭建 3.数仓维度建模 4.统计指标 5.即席查询工具,随时进行指标分析 6.对集群性能进行监控,发生异常时报警(第三方信息) 7.元数据管理 8.质量监控 9.权限管理&#xff…

streamlit-API

介绍 安装 pip install streamlit运行 streamlit run your_script.py [-- script args]数据流 多页应用程序 API 文本元素 数据元素 图标元素 输入小部件 媒体元素 布局和容器 聊天元素 st.chat_message st.chat_input 显示进度和状态 控制流 占位符/帮助/选项 图表改变 会…

SpringBoot工具类—基于定时器完成文件清理功能

直接复制粘贴既可!! import org.springframework.scheduling.annotation.Scheduled; import org.springframework.stereotype.Component; import java.io.File; import java.time.LocalDate; import java.time.LocalDateTime; import java.time.ZoneOff…

云职达(上海)岗前实训基地(上海云职达):致力为企业提供良好的数字化解决方案

上海云职达全称:云职达上海信息科技有限公司,是一家致力于推动算力产业发展的企业。随着数字经济时代的到来,算力作为数字产业化和产业数字化转型的关键支撑,已经成为推进中国式现代化的重要驱动力量。云职达深入理解算力产业的重…

星辰天合荣获“2023年度优秀光伏行业数字化供应商”

8 月 28 日,由 OFweek 维科网及旗下权威的光伏专业媒体-维科网光伏共同举办的“OFweek 2023(第十四届)太阳能光伏产业大会暨光伏行业年度颁奖典礼”在深圳成功举办。 星辰天合凭借在光伏领域的优秀智能存储解决方案,以及大量的应用…

Rdedis 持久化

Redis 是内存数据库,如果不将内存中的数据库状态保存到磁盘,那么一旦服务器进程退出,服务器中的数据库状态也会消失。所以 Redis 提供了持久化功能! 一、RDB(Redis DataBase) 1.1 概念 在指定的时间间隔内…

冠达管理:2023股票交易新规则详解?股票手续费包括哪些?

投资者进行股票投资时不是随便就可以进行生意的,需求恪守一定的生意规则,才干顺利成交。那么2023股票生意新规则详解?股票手续费包含哪些?下面就由冠达管理为大家分析: ​ 2023股票生意新规则详解? 1、约…

EDFHG-04-200-3C2-XY-31T001电液比例大流量调速阀放大器

EDFHG-03-100-3C40-XY-30T、EDFHG-03-100-3C2-XY-30T、EDFHG-04-140-3C40-XY-30T、EDFHG-04-140-3C2-XY-30T、EDFHG-06-140-3C40-XY-30T、EDFHG-06-140-3C2-XY-30T、EDFHG-04-200-3C2-XY-31T001、EDFHG-06-400-3C2-XY-31T001、EDFHG-06-400-3C40-XY-31T001电液比例换向调速阀采…

vue3 ref reactive响应式数据 赋值的问题

文章目录 vue3 ref reactive响应式数据 赋值的问题场景1:将响应式数据赋值请求后的数据错误示范:直接赋值正确写法 场景2:响应式数据解构之后失去响应式原因分析解决办法 toRefs/toRef方法创建ref引用对象 vue3 ref reactive响应式数据 赋值的问题 doing…

云备份——项目介绍

一,项目基本介绍 自动将本地计算机上指定文件夹中需要备份的文件上传备份到服务器中。并且能够随时通过浏览器进行查看并且下载,其中下载过程支持断点续传功能,而服务器也会对上传文件进行热点管理,将非热点文件进行压缩存储&…

在springboot中配置mybatis(mybatis-plus)mapper.xml扫描路径的问题

我曾经遇到过类似问题: mybatis-plus的mapper.xml在src/main/java路径下如何配置pom.xml和application.yml_idea 把mapper文件放到java下如何配置_梓沂的博客-CSDN博客 当时只是找到解决问题的办法,但对mybatis配置来龙去脉并未深入了解,所…

无意间发现这款可以免费制作3D翻页电子画册的网站

在博主努力的搜寻下,无意间发现这个网站,可以免费制作3D翻页电子画册。使用这个网站非常简单,只需上传你想要展示的图片和添加相应的文字,然后选择合适的模板和风格。接下来,就会自动转化成漂亮的3D翻页画册 工具嘛&am…

《安富莱嵌入式周报》第321期:开源12导联便携心电仪,PCB AI设计,150M示波器差分探头,谷歌全栈环境IDX,微软在Excel推出Python

周报汇总地址:嵌入式周报 - uCOS & uCGUI & emWin & embOS & TouchGFX & ThreadX - 硬汉嵌入式论坛 - Powered by Discuz! 视频版: https://www.bilibili.com/video/BV1ju4y1D7A8/ 《安富莱嵌入式周报》第321期:开源12导…

天津web前端培训机构 Web开发是做什么的?

随着互联网和移动互联网的快速发展,越来越多的企业开始注重自身网站和应用程序的用户体验和设计,而这些方面正是Web前端开发人员所擅长的领域。Web前端不仅招聘市场需求量大,还有一个重要的原因就是,入行门槛低,入门简…

A+CLUB管理人支持计划第七期 | 吾执投资

免责声明 本文内容仅对合格投资者开放! 私募基金的合格投资者是指具备相应风险识别能力和风险承担能力,投资于单只私募基金的金额不低于100 万元且符合下列相关标准的单位和个人: (一)净资产不低于1000 万元的单位&…