边写代码边学习之RNN

news2024/12/23 18:26:32

1. 什么是 RNN

循环神经网络(Recurrent Neural Network,RNN)是一种以序列数据为输入来进行建模的深度学习模型,它是 NLP 中最常用的模型。其结构如下图:

在这里插入图片描述

 x是输入,h是隐层单元,o为输出,L为损失函数,y为训练集的标签.
这些元素右上角带的t代表t时刻的状态,其中需要注意的是,因策单元h在t时刻的表现不仅由此刻的输入决定,还受t时刻之前时刻的影响。V、W、U是权值,同一类型的权连接权值相同。
有了上面的理解,前向传播算法其实非常简单,对于t时刻:
                                       h ^{(t)} =\phi (Ux^{(t)} +Wh^{(t-1)} +b)

其中\phi ()为激活函数,一般来说会选择tanh函数,b为偏置。
t时刻的输出就更为简单:
                                                     o^{(t)} =Vh ^{(t)} +c
最终模型的预测输出为:
                                                          \hat y^{(t)} =\sigma (o^{(t)} )
其中\sigma为激活函数,通常RNN用于分类,故这里一般用softmax函数。

2. 实验代码

2.1. 搭建一个只有一层RNN和Dense网络的模型。

def simple_rnn_layer():

    # Create a dense layer with 10 output neurons and input shape of (None, 20)
    model = Sequential()
    model.add(SimpleRNN(units=3, input_shape=(3, 2),))  # 3 units in the RNN layer, input_shape=(timesteps, features)
    model.add(Dense(1))  # Output layer with one neuron

    # Print the summary of the dense layer
    print(model.summary())
if __name__ == '__main__':
    simple_rnn_layer()

输出

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 simple_rnn (SimpleRNN)      (None, 3)                 18        
                                                                 
 dense (Dense)               (None, 1)                 4         
                                                                 
=================================================================
Total params: 22
Trainable params: 22
Non-trainable params: 0
_________________________________________________________________
None

2.2. 验证RNN里的逻辑

写代码验证这个过程,看看结果是不是一样的。

import keras.optimizers.optimizer
import numpy as np
from keras.models import Sequential
from keras.layers import SimpleRNN, Dense
def change_weight():
    # Create a simple Dense layer
    rnn_layer = SimpleRNN(units=3, input_shape=(3, 2), activation=None, return_sequences=True)

    # Simulate input data (batch size of 1 for demonstration)
    input_data = np.array([
                [[1.0, 2], [2, 3], [3, 4]],
                [[5, 6], [6, 7], [7, 8]],
                [[9, 10], [10, 11], [11, 12]]
        ])

    # Pass the input data through the layer to initialize the weights and biases
    _ = rnn_layer(input_data)

    # Access the weights and biases of the dense layer
    kernel, recurrent_kernel, biases = rnn_layer.get_weights()

    # Print the initial weights and biases
    print("recurrent_kernel:", recurrent_kernel) # (3,3)
    print('kernal:',kernel) #(2,3)
    print('biase: ',biases) # (3)

    kernel = np.array([[1, 0, 2], [2, 1, 3]])
    recurrent_kernel = np.array([[1, 2, 1.0], [1, 0, 1], [0, 1, 0]])
    biases = np.array([0, 0, 1.0])

    rnn_layer.set_weights([kernel, recurrent_kernel, biases])
    print(rnn_layer.get_weights())

    test_data = np.array([
        [[1.0, 3], [1, 1], [2, 3]]
    ])

    output = rnn_layer(test_data)

    print(output)

if __name__ == '__main__':
    change_weight()

输出结果如下:可以看到结果是我手算的是一致的。

recurrent_kernel: [[ 0.06973135  0.40464386  0.9118119 ]
 [ 0.6186313  -0.7345941   0.27868783]
 [ 0.7825809   0.5446422  -0.3015495 ]]
kernal: [[-0.48868906  0.52718353 -0.08321357]
 [-1.0569452  -0.9872779   0.72809434]]
biase:  [0. 0. 0.]
[array([[1., 0., 2.],
       [2., 1., 3.]], dtype=float32), array([[1., 2., 1.],
       [1., 0., 1.],
       [0., 1., 0.]], dtype=float32), array([0., 0., 1.], dtype=float32)]
tf.Tensor(
[[[ 7.  3. 12.]
  [13. 27. 16.]
  [48. 45. 54.]]], shape=(1, 3, 3), dtype=float32)

2.3 代码实现一个简单的例子

import keras.optimizers.optimizer
import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import SimpleRNN, Dense

# Sample sequential data
# Each sequence has three timesteps, and each timestep has two features
data = np.array([
    [[1, 2], [2, 3], [3, 4]],
    [[5, 6], [6, 7], [7, 8]],
    [[9, 10], [10, 11], [11, 12]]
])


print('data.shape= ',data.shape)
# Define the RNN model
model = Sequential()
model.add(SimpleRNN(units=4, input_shape=(3, 2), name="simpleRNN"))  # 4 units in the RNN layer, input_shape=(timesteps, features)
model.add(Dense(1, name= "output"))  # Output layer with one neuron

# Compile the model
model.compile(loss='mse', optimizer=keras.optimizers.Adam(learning_rate=0.01))

# Print the model summary
model.summary()

before_RNN_weight = model.get_layer("simpleRNN").get_weights()
print('before train ', before_RNN_weight)

# Train the model
model.fit(data, np.array([[10], [20], [30]]), epochs=2000, verbose=1)

RNN_weight = model.get_layer("simpleRNN").get_weights()
print('after train ', len(RNN_weight),)

for i in range(len(RNN_weight)):
    print('====',RNN_weight[i].shape, RNN_weight[i])


# Make predictions
predictions = model.predict(data)
print("Predictions:", predictions.flatten())

代码输出

data.shape=  (3, 3, 2)
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 simpleRNN (SimpleRNN)       (None, 4)                 28        
                                                                 
 output (Dense)              (None, 1)                 5         
                                                                 
=================================================================
Total params: 33
Trainable params: 33
Non-trainable params: 0
_________________________________________________________________
before train  [array([[-0.00466371,  0.53100157,  0.5298798 ,  0.05514288],
       [-0.08896947,  0.43185067,  0.7861788 , -0.80616236]],
      dtype=float32), array([[-0.10712242, -0.03620092, -0.02182053, -0.9933471 ],
       [-0.6549012 , -0.02620655,  0.7532524 ,  0.05503315],
       [-0.01986913,  0.9989996 ,  0.02001702, -0.03470401],
       [-0.74781984,  0.00159313, -0.657065  ,  0.09502006]],
      dtype=float32), array([0., 0., 0., 0.], dtype=float32)]
2023-08-05 16:02:44.111298: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
Epoch 1/2000
....
Epoch 1999/2000
1/1 [==============================] - 0s 11ms/step - loss: 0.0071
Epoch 2000/2000
1/1 [==============================] - 0s 13ms/step - loss: 0.0070
after train  3
==== (2, 4) [[ 0.27645147  0.6025058   1.6083356  -0.38382724]
 [ 0.11586202  0.32901326  1.4760928  -1.2268958 ]]
==== (4, 4) [[-0.99628973 -2.444563    1.7412992  -1.5265529 ]
 [ 0.80340594  0.9488743   2.44552    -0.7439341 ]
 [-0.1827681  -1.3091801   1.547736   -0.6644555 ]
 [-0.5724374   2.3090494  -2.1779017   0.35992467]]
==== (4,) [-0.40184066 -1.2391611   0.33460653 -0.29144585]
1/1 [==============================] - 0s 78ms/step
Predictions: [10.000422 19.999924 29.85534 ]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/840362.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

有哪些好用的AI绘画网站?

随着人工智能技术的发展,人工智能绘画工具逐渐成为数字艺术领域的热门话题。人工智能绘画工具是利用深度学习和其他技术来模拟绘画过程和效果的工具,可以帮助用户快速创作高质量的艺术作品。除了Midjourney、除了openai等流行的AI绘画工具外,…

Flutter游戏引擎Flame系列笔记 - 1.Flame引擎概述

Flutter游戏引擎Flame系列笔记 1.Flame引擎概述 - 文章信息 - Author: 李俊才(jcLee95) Visit me at: https://jclee95.blog.csdn.netEmail: 291148484163.com. Shenzhen ChinaAddress of this article:https://blog.csdn.net/qq_28550263/article/details/132119035 【介绍】…

【微信小程序创作之路】- 小程序远程数据请求、获取个人信息

【微信小程序创作之路】- 小程序远程数据请求、获取个人信息 第七章 小程序远程数据请求、获取个人信息 文章目录 【微信小程序创作之路】- 小程序远程数据请求、获取个人信息前言一、远程数据请求1.本地环境2.正式域名 二、获取用户个人信息1.展示当前用户的身份信息2.获取用…

Vue电商项目--导航守卫

导航守卫理解 导航 守卫 导航:表示路由正在发送改变,进行路由跳转 守卫:你把它当中‘紫禁城守卫’ 全局守卫:你项目中,只要路由变化,守卫就能监听到。 举例:紫禁城【皇帝,太子】…

sk_buff操作函数学习

一. 前言 内核提供了大量实用的操作sk_buff的函数,在开发网络设备驱动程序和修改网络协议栈代码时需要用到。这些函数从功能上可以分为三类:创建,释放和复制socket buffer;操作sk_buff结构中的参数和指针;管理socket b…

XML 学习笔记 7:XSD

本文章内容参考自: W3school XSD 教程 Extensible Markup Language (XML) 1.0 (Second Edition) XML Schema 2001 XML Schema Part 2: Datatypes Second Edition 文章目录 1、XSD 是什么2、XSD 内置数据类型 - built-in datatypes2.1、基本数据类型 19 种2.1.1、基本…

从0到1自学网络安全(黑客)【附学习路线图+配套搭建资源】

前言 网络安全产业就像一个江湖,各色人等聚集。相对于欧美国家基础扎实(懂加密、会防护、能挖洞、擅工程)的众多名门正派,我国的人才更多的属于旁门左道(很多白帽子可能会不服气),因此在未来的…

【STM32】小电流FOC驱控一体板(开源)

FOC驱控一体板 主控芯片stm32f103c8t6 驱动芯片drv8313 三相电流采样 根据B站一个UP主的改的(【【自制】年轻人的第一块FOC驱动器】),大多数元器件是0805,实验室具备且便于自己动手焊接 。 晶振用的是无源晶振,体…

Bias Tee理论到实践

Bias Tees是将直流电压施加到必须传递RF/微波信号的任何组件的必不可少的组件,最常见的是需要DC电源的RF放大器。 对于窄带应用,bias tee设计和结构是简单的,只需要注意组件的自谐振频率。 然而,对于宽带应用,bias t…

基于瑞萨RA6M5的环境监测系统设计

基于瑞萨RA6M5的环境监测系统设计 1. 设计简介 本项目初步设计思路是打算以瑞萨单片机作为控制和数据处理的单元,使用温湿度,光照传感器去监测周围的环境参数,在屏幕上完成传感器数据和相关信息的显示。同时,使用WIFI无线模组与…

用于大型图像模型的 CNN 内核的最新内容

一、说明 由于OpenAI的ChatGPT的巨大成功引发了大语言模型的繁荣,许多人预见到大图像模型的下一个突破。在这个领域,可以提示视觉模型分析甚至生成图像和视频,其方式类似于我们目前提示 ChatGPT 的方式。 用于大型图像模型的最新深度学习方法…

Pytohn将matplotlib嵌入到tkinter中

文章目录 matplotlib窗口组成tkinter布局嵌入图像 matplotlib窗口组成 tkinter是Python标准库中自带的GUI工具,使用十分方便,如能将matplotlib嵌入到tkinter中,就可以做出相对专业的数据展示系统,很有竞争力。 在具体实现之前&a…

FTP使用教程

FTP使用教程 目录 一.FTP简介二.FTP搭建三.FTP使用 一.FTP简介 FTP中文为文件传输协议,简称为文传协议。它也是一个应用程序,不同的操作系统有不同的FTP应用程序,这些应用程序都遵守同一种协议以…

六寸相纸打印拼图 - opencv

准备自己打印一些照片,三寸相纸性价比低,只好买六寸来拼四张然后裁剪,不过并没有搜到提供了这个功能的工具,想想代码应该很简单,所以就造轮子了。可能其实有但是我没搜到。 轮子在这里: https://github.co…

ArraySetter

简介​ 用来展示属性类型为数组的 setter 展示​ 配置示例​ "setter": {"componentName": "ArraySetter","props": {"itemSetter": {"componentName": "ObjectSetter","props": {"c…

React 论文《ReAct: Synergizing Reasoning and Acting in Language Models》阅读笔记

文章目录 1. 简介论文摘要翻译动机和主要贡献 2. REACT : SYNERGIZING *RE*ASONING *ACT*ING3. KNOWLEDGE-INTENSIVE REASONING TASKS3.1 设置3.2 方法3.3 结果和观察 4. 决策任务5. 参考资料 1. 简介 论文摘要翻译 虽然大型语言模型(LLM)在自然语言理…

医疗实施-集成平台下门诊就诊流程详解

目录 集成平台下门诊就诊流程详解1.患者建档2. 门诊挂号3. 医生就诊4.处方开立5.费用收取、6、科室执行医嘱集成平台下门诊就诊流程详解 这篇文章是考虑医院使用了集成平台之后,门诊就诊流程详解。与我的文章《医疗实施-门诊就诊流程详解》的大致一样,供学有余力的人阅读。 …

图解java.util.concurrent并发包源码系列——深入理解ReentrantLock,看完可以吊打面试官

图解java.util.concurrent并发包源码系列——深入理解ReentrantLock,看完可以吊打面试官 ReentrantLock是什么,有什么作用ReentrantLock的使用ReentrantLock源码解析ReentrantLock#lock方法FairSync#tryAcquire方法NonfairSync#tryAcquire方法 Reentrant…

SpringBoot笔记:SpringBoot 集成 Dataway

文章目录 1、什么是 Dataway?2、主打场景3、技术架构4、整合SpringBoot4.1、maven 依赖4.2、初始化脚本4.3、整合 SpringBoot 5、Dataway 接口管理6、Mybatis 语法支持7、小结 1、什么是 Dataway? 官网地址:https://www.hasor.net/docs/guides/quickstart Da…

连通块是什么

刷题的时候遇到一个名词概念,连通块是什么? 在图论中,无向图中的连通块(也叫作连通分量)是指原图的一个子图(即该子图只包含原图中的部分或全部顶点及边),该子图任意两个顶点都能通…