RNN循环递归网络讲解与不掉包python实现

news2025/1/11 20:51:56

1.算法简介

参考论文:Elman J L. Finding structure in time[J]. Cognitive science, 1990, 14(2): 179-211.,谷歌被引次数超16000!

说到循环递归结构就不得不提到其鼻祖RNN网络。首先我们先对RNN有个初步的概念:想象一下,你正在阅读一本非常吸引人的小说。每次你翻开新的一页,你的大脑不仅会处理这一页上的内容,还会结合之前读过的所有内容来理解故事的情节、人物关系和背景设定。这就是一种“记忆”和“上下文理解”的过程,因为你大脑中的信息不是孤立的,而是连续且相互关联的。RNN就是模仿了这种“记忆”功能的神经网络。在传统的神经网络中,假设每个输入都是独立的,没有前后联系。但RNN不同,它专门设计用来处理序列数据——也就是那些按顺序排列,其中每个元素都与前后的元素有关联的数据,比如时间序列数据(股价、温度变化)、自然语言(句子、对话)等。在RNN中,有一个特殊的循环连接,让信息能够从前一个时间点传递到下一个时间点。就像你在读书时,上一页的信息会影响你对下一页的理解。RNN的这个特性让它能够记住前面的信息,这样当处理后续数据时,它就能够利用这些记忆来做出更好的决策或预测。有了一个初步概念之后,我们现在来具体的讲一讲RNN到底是什么以及其是怎么运行的。
在这里插入图片描述

2. RNN算法原理

在这里插入图片描述
RNN的核心思想是利用序列中的时序信息。与传统的前馈神经网络不同,RNN引入了循环连接,使网络能够保留之前时间步的信息。如上面的GIF图所示(引自:RNN):
其中X表示当前时刻的输入,以一个天气预报任务为例,我们将使用过去3天(前天、昨天和今天)的温度(T’)、湿度(H)和风速(W)来预测下一天的温度(T)。对于这个任务而言,我们的输入数据是[ X T t − 2 X_{T_{t-2}} XTt2, X T t − 1 X_{T_{t-1}} XTt1, X T t X_{T_{t}} XTt, X H t − 2 X_{H_{t-2}} XHt2, X H t − 1 X_{H_{t-1}} XHt1, X H t X_{H_{t}} XHt, X W t − 2 X_{W_{t-2}} XWt2, X W t − 1 X_{W_{t-1}} XWt1, X W t X_{W_{t}} XWt]即X是一个1*9的输入。我们的输出是明天的温度 y T t + 1 y_{T_{t+1}} yTt+1。即我们需要建一个RNN模型 f ( x ) f(x) f(x):
y T t + 1 = f ( X T t − 2 , X T t − 1 , X T t , X H t − 2 , X H t − 1 , X H t , X W t − 2 , X W t − 1 , X W t ) y_{T_{t+1}}=f(X_{T_{t-2}},X_{T_{t-1}},X_{T_{t}},X_{H_{t-2}},X_{H_{t-1}},X_{H_{t}},X_{W_{t-2}},X_{W_{t-1}},X_{W_{t}}) yTt+1=f(XTt2,XTt1,XTt,XHt2,XHt1,XHt,XWt2,XWt1,XWt)
那么具体该怎么实现这个模型呢?待我一步步分解:

2.1 RNN基本结构介绍

RNN的核心是其循环结构,它允许信息在序列中传递。对于我们的天气预报任务,RNN的基本结构可以表示为:
h t = t a n h ( W h h ∗ h t − 1 + W x h ∗ x t + b h ) h_t = tanh(W_hh * h_{t-1} + W_xh * x_t + b_h) ht=tanh(Whhht1+Wxhxt+bh)
y t = W h y ∗ h t + b y y_t = W_hy * h_t + b_y yt=Whyht+by
其中:

  • h t h_t ht 是当前时间步的隐藏状态
  • x t x_t xt 是当前时间步的输入
  • W h h W_hh Whh, W x h W_xh Wxh, W h y W_hy Why 是权重矩阵,也就是网络的可学习参数,是未知量。
  • b h b_h bh, b y b_y by 是偏置项
  • tanh 是激活函数

2.2 计算流程

对于我们的天气预报任务,计算流程如下:

a) 初始化隐藏状态 h 0 h_0 h0(通常为零向量)

b) 对于每个时间步 t = 1 to 3:
h t − 2 = h 0 h_{t-2}=h_0 ht2=h0
h t − 1 = t a n h ( W h h ∗ h t − 2 + W x h ∗ [ X T t − 1 , X H t − 1 , X W t − 1 ] + b h ) h_{t-1} = tanh(W_{h}h * h_{t-2}+ W_xh*[X_{T_{t-1}}, X_{H_{t-1}}, X_{W_{t-1}}]+ b_h) ht1=tanh(Whhht2+Wxh[XTt1,XHt1,XWt1]+bh)
h t = t a n h ( W h h ∗ h t − 2 + W x h ∗ [ X T t , X H t , X W t ] + b h ) h_{t} = tanh(W_{h}h * h_{t-2}+ W_xh*[X_{T_t}, X_{H_t}, X_{W_t}]+ b_h) ht=tanh(Whhht2+Wxh[XTt,XHt,XWt]+bh)
c) 最后,我们使用最终的隐藏状态来预测明天的温度:
y T t + 1 = W h y ∗ h t + b y y_{T_{t+1}} = W_hy * h_t + b_y yTt+1=Whyht+by
然后我手搓了一个RNN代码便于各位理解:

class RNN:
    def __init__(self, input_size, hidden_size, output_size):
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        # 初始化权重
        self.Wxh = np.random.randn(hidden_size, input_size) * 0.01
        self.Whh = np.random.randn(hidden_size, hidden_size) * 0.01
        self.Why = np.random.randn(output_size, hidden_size) * 0.01

        # 初始化偏置
        self.bh = np.zeros((hidden_size, 1))
        self.by = np.zeros((output_size, 1))

        # 初始化Adam优化器
        self.optimizer = Adam({
            'Wxh': self.Wxh, 'Whh': self.Whh, 'Why': self.Why,
            'bh': self.bh, 'by': self.by
        })

    def forward(self, inputs):
        h = np.zeros((self.hidden_size, 1))
        self.last_inputs = inputs
        self.last_hs = {0: h}

        # 前向传播
        for t, x in enumerate(inputs):
            h = np.tanh(np.dot(self.Wxh, x) + np.dot(self.Whh, h) + self.bh)
            self.last_hs[t + 1] = h

        y = np.dot(self.Why, h) + self.by
        return y, h

    def backward(self, d_y):
        n = len(self.last_inputs)

        # 初始化梯度
        d_Wxh = np.zeros_like(self.Wxh)
        d_Whh = np.zeros_like(self.Whh)
        d_Why = np.zeros_like(self.Why)
        d_bh = np.zeros_like(self.bh)
        d_by = np.zeros_like(self.by)

        d_h = np.dot(self.Why.T, d_y)

        # 反向传播
        for t in reversed(range(n)):
            temp = (1 - self.last_hs[t + 1] ** 2) * d_h
            d_Wxh += np.dot(temp, self.last_inputs[t].T)
            d_Whh += np.dot(temp, self.last_hs[t].T)
            d_bh += temp

            d_h = np.dot(self.Whh.T, temp)

        d_Why = np.dot(d_y, self.last_hs[n].T)
        d_by = d_y

        # 使用Adam优化器更新参数
        self.optimizer.step({
            'Wxh': d_Wxh, 'Whh': d_Whh, 'Why': d_Why,
            'bh': d_bh, 'by': d_by
        })

3.完整训练代码

import numpy as np


class Adam:
    def __init__(self, params, lr=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
        self.params = params
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.m = {k: np.zeros_like(v) for k, v in params.items()}
        self.v = {k: np.zeros_like(v) for k, v in params.items()}
        self.t = 0

    def step(self, grads):
        self.t += 1
        for k in self.params.keys():
            self.m[k] = self.beta1 * self.m[k] + (1 - self.beta1) * grads[k]
            self.v[k] = self.beta2 * self.v[k] + (1 - self.beta2) * (grads[k] ** 2)
            m_hat = self.m[k] / (1 - self.beta1 ** self.t)
            v_hat = self.v[k] / (1 - self.beta2 ** self.t)
            self.params[k] -= self.lr * m_hat / (np.sqrt(v_hat) + self.epsilon)


class RNN:
    def __init__(self, input_size, hidden_size, output_size):
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        # 初始化权重
        self.Wxh = np.random.randn(hidden_size, input_size) * 0.01
        self.Whh = np.random.randn(hidden_size, hidden_size) * 0.01
        self.Why = np.random.randn(output_size, hidden_size) * 0.01

        # 初始化偏置
        self.bh = np.zeros((hidden_size, 1))
        self.by = np.zeros((output_size, 1))

        # 初始化Adam优化器
        self.optimizer = Adam({
            'Wxh': self.Wxh, 'Whh': self.Whh, 'Why': self.Why,
            'bh': self.bh, 'by': self.by
        })

    def forward(self, inputs):
        h = np.zeros((self.hidden_size, 1))
        self.last_inputs = inputs
        self.last_hs = {0: h}

        # 前向传播
        for t, x in enumerate(inputs):
            h = np.tanh(np.dot(self.Wxh, x) + np.dot(self.Whh, h) + self.bh)
            self.last_hs[t + 1] = h

        y = np.dot(self.Why, h) + self.by
        return y, h

    def backward(self, d_y):
        n = len(self.last_inputs)

        # 初始化梯度
        d_Wxh = np.zeros_like(self.Wxh)
        d_Whh = np.zeros_like(self.Whh)
        d_Why = np.zeros_like(self.Why)
        d_bh = np.zeros_like(self.bh)
        d_by = np.zeros_like(self.by)

        d_h = np.dot(self.Why.T, d_y)

        # 反向传播
        for t in reversed(range(n)):
            temp = (1 - self.last_hs[t + 1] ** 2) * d_h
            d_Wxh += np.dot(temp, self.last_inputs[t].T)
            d_Whh += np.dot(temp, self.last_hs[t].T)
            d_bh += temp

            d_h = np.dot(self.Whh.T, temp)

        d_Why = np.dot(d_y, self.last_hs[n].T)
        d_by = d_y

        # 使用Adam优化器更新参数
        self.optimizer.step({
            'Wxh': d_Wxh, 'Whh': d_Whh, 'Why': d_Why,
            'bh': d_bh, 'by': d_by
        })


# 生成模拟数据
def generate_data(num_samples, time_steps):
    X = np.random.rand(num_samples, time_steps, 3)  # 3个特征:温度、湿度、风速
    y = np.sum(X[:, :, 0], axis=1) / 3 + np.random.normal(0, 0.1, num_samples)  # 简单地用平均温度加噪声作为目标
    return X, y.reshape(-1, 1)


# 数据标准化
def normalize(data):
    return (data - np.mean(data)) / np.std(data)


# 生成训练和测试数据
X_train, y_train = generate_data(1000, 3)
X_test, y_test = generate_data(200, 3)

# 标准化数据
X_train_norm = normalize(X_train)
y_train_norm = normalize(y_train)
X_test_norm = normalize(X_test)
y_test_norm = normalize(y_test)

# 初始化RNN
input_size = 3
hidden_size = 64
output_size = 1
rnn = RNN(input_size, hidden_size, output_size)


# 训练函数
def train(rnn, X, y, epochs):
    for epoch in range(epochs):
        total_loss = 0
        for i in range(len(X)):
            inputs = [X[i][t].reshape(-1, 1) for t in range(3)]
            target = y[i]

            # 前向传播
            output, _ = rnn.forward(inputs)

            # 计算损失
            loss = np.sum((output - target) ** 2)
            total_loss += loss

            # 反向传播
            d_y = 2 * (output - target)
            rnn.backward(d_y)

        if epoch % 10 == 0:
            print(f"Epoch {epoch}, Loss: {total_loss / len(X)}")


# 初始化RNN
input_size = 3
hidden_size = 64
output_size = 1
rnn = RNN(input_size, hidden_size, output_size)

# 训练模型
train(rnn, X_train_norm, y_train_norm, epochs=100)


# 评估函数
def evaluate(rnn, X, y):
    total_loss = 0
    for i in range(len(X)):
        inputs = [X[i][t].reshape(-1, 1) for t in range(3)]
        target = y[i]

        output, _ = rnn.forward(inputs)
        loss = np.sum((output - target) ** 2)
        total_loss += loss

    return total_loss / len(X)


# 评估模型
test_loss = evaluate(rnn, X_test_norm, y_test_norm)
print(f"Test Loss: {test_loss}")


# 进行预测
def predict(rnn, X):
    inputs = [X[t].reshape(-1, 1) for t in range(3)]
    output, _ = rnn.forward(inputs)
    return output


# 示例预测
sample_data = X_test_norm[0]
prediction = predict(rnn, sample_data)
print(f"Sample input: {X_test[0]}")
print(f"Normalized prediction: {prediction[0][0]}")

# 反标准化预测结果
mean_y = np.mean(y_train)
std_y = np.std(y_train)
denormalized_prediction = prediction[0][0] * std_y + mean_y
print(f"Denormalized prediction: {denormalized_prediction}")
print(f"Actual value: {y_test[0][0]}")

创作不易,烦请各位观众老爷给个三连,小编在这里跪谢了!
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1938537.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[紧急!!!]20240719全球Windows10/11蓝屏问题,CrowdStrike导致的错误解决方案

文章目录 前言一、CrowdStrike是什么?二、PC解决方式(网路上大神的方式,虚拟机测试过)1.Windows PC 上 CrowdStrike BSOD 问题的官方解决方法:2.阻止CrowdStrick启动-命令行法3.阻止CrowdStrick启动-注册表法 三、AWS …

基于Matlab的数据可视化

基于Matlab的数据可视化 一、二维图形的绘制(一)基本图形函数(1)plot函数(2)fplot函数(3)其他坐标系的二维曲线 (二)图形属性设置(1)线…

对某次应急响应中webshell的分析

文章前言 在之前处理一起应急事件时发现攻击者在WEB应用目录下上传了webshell,但是webshell似乎使用了某种加密混淆手法,无法直观的看到其中的木马连接密码,而客户非要让我们连接webshell来证实此文件为后门文件且可执行和利用(也是很恼火&a…

数据结构与算法04二叉树|二叉排序树|AVL树

目录 一、二叉树(binary tree) 1、二叉树常见术语 2、二叉树常用的操作 2.1、初始化:与链表十分相似,先创建节点,然后构造引用/指针关系. 2.2、插入和删除操作 3、常见二叉树类型 3.1、满二叉树 3.2、完全二叉树(complete b…

跳跃游戏Ⅱ - vector

55. 跳跃游戏 - 力扣&#xff08;LeetCode&#xff09; class Solution { public:bool canJump(vector<int>& nums) {int n nums.size();int reach 0;for(int i 0; i < n; i){if(i > reach){return false;}reach max(inums[i], reach);}return true;} }; …

SpringBoot3 + Vue3 学习 Day 2

登入接口 和 获取用户详细信息的开发 学习视频登入接口的开发1、登入主逻辑2、登入认证jwt 介绍生成 JWT① 导入依赖② 编写代码③ 验证JWT 登入认证接口的实现① 导入 工具类② controller 类实现③ 存在的问题及优化① 编写拦截器② 注册拦截器③ 其他接口直接提供服务 获取用…

JVM(day4)类加载机制

类加载过程 加载 通过一个类的全限定名来获取定义此类的二进制字节流。 将这个字节流所代表的静态存储结构转化为方法区的运行时数据结构。 在内存中生成一个代表这个类的java.lang.Class对象&#xff0c;作为方法区这个类的各种数据的访问入口。 验证 文件格式验证 元数…

LeetCode做题记录(第二天)647. 回文子串

题目&#xff1a; 647. 回文子串 标签&#xff1a;双指针 字符串 动态规划 题目信息&#xff1a; 思路一&#xff1a;暴力实现 我们直接for套for分割成一个个子串再判断&#xff0c;如果子串是回文子串&#xff0c;就1&#xff0c;最后得出结果 代码实现&#xff1a; cl…

C语言实例-约瑟夫生者死者小游戏

问题&#xff1a; 30个人在一条船上&#xff0c;超载&#xff0c;需要15人下船。于是人们排成一队&#xff0c;排队的位置即为他们的编号。报数&#xff0c;从1开始&#xff0c;数到9的人下船&#xff0c;如此循环&#xff0c;直到船上仅剩15人为止&#xff0c;问都有哪些编号…

Missing script:‘dev‘

场景&#xff1a; npm run dev 原因&#xff1a;没有安装依赖&#xff0c;可用镜像安装&#xff08;详见下图ReadMe 蓝色字体&#xff09;&#xff0c;没安装依赖可从package-lock.json文件是否存在看出&#xff0c;存在则有依赖 解决&#xff1a;

KMP算法(算法篇)

算法之KMP算法 KMP算法 概念&#xff1a; KMP算法是用于解决字符串匹配的问题的算法&#xff0c;也就是有一个文本串和一个模式串&#xff0c;求解这个模式串是否在文本串中出现或者匹配。相对于暴力求解&#xff0c;KMP算法使用了前缀表来进行匹配&#xff0c;充分利用了之…

【Vue3】从零开始编写项目

【Vue3】从零开始编写项目 背景简介开发环境开发步骤及源码总结 背景 随着年龄的增长&#xff0c;很多曾经烂熟于心的技术原理已被岁月摩擦得愈发模糊起来&#xff0c;技术出身的人总是很难放下一些执念&#xff0c;遂将这些知识整理成文&#xff0c;以纪念曾经努力学习奋斗的…

神经网络模型实现(训练、测试)

目录 一、神经网络骨架&#xff1a;二、卷积操作&#xff1a;三、卷积层&#xff1a;四、池化层&#xff1a;五、激活函数&#xff08;以ReLU为例&#xff09;&#xff1a;六、模型搭建&#xff1a;七、损失函数、梯度下降&#xff1a;八、模型保存与加载&#xff1a;九、模型训…

Linux下安装JDK、Tomact、MySQL以及Nginx的超详细步骤

目录 1、为什么安装这些软件 2、安装软件的方式 3、安装JDK 3.1 下载Linux版本的JDK 3.2 将压缩包拖拽到Linux系统下 3.3 解压jdk文件 3.4 修改文件夹名字 3.5 配置环境变量 4、安装Tomcat 4.1 下载Tomcat 4.2 将Tomcat放入Linux系统并解压&#xff0c;步骤如上面的…

MenuToolButton自绘控件,带下拉框的QToolButton,附源码

MenuToolButton自绘控件&#xff0c;带下拉框的QToolButton 效果 下拉样式可自定义 跟随QToolButton的Qt::ToolButtonStyle属性改变图标文字样式 使用示例 正常UI文件创建QToolButton然后提升&#xff0c;或者直接代码创建都可以。 // 创建一个 QList 对象来存储 QPixm…

JDK、JRE、JVM的区别java的基本数据类型

说一说JDK、JRE、JVM的区别在哪&#xff1f; JDK&#xff1a; Java Delopment kit是java工具包&#xff0c;包含了编译器javac&#xff0c;调试器&#xff08;jdb&#xff09;以及其他用于开发和调试java程序的工具。JDK是开发人员在开发java应用程序时候所需要的的基本工具。…

10道JVM经典面试题

1、 JVM中&#xff0c;new出来的对象是在哪个区&#xff1f; 2、 说说类加载有哪些步骤&#xff1f; 3、 JMM是什么&#xff1f; 4、 说说JVM内存结构&#xff1f; 5、 MinorGC和FullGC有什么区别&#xff1f; 6、 什么是STW? 7、 什么情况下会发生堆/栈溢出&#xff1f…

【高中数学/对数函数】log_x_x+1与(x+1)/x,log_x+1_x与x/(x+1)的图线有着惊人的相似性

【图像】 褐线与蓝线&#xff0c;黄线与绿线&#xff0c;只是像左右平移了一样。 【生成图像的代码】 <!DOCTYPE html> <html lang"utf-8"> <meta http-equiv"Content-Type" content"text/html; charsetutf-8"/> <head>…

大模型学习笔记十二:AI产品部署

文章目录 一、如何选择GPU和云服务器厂商&#xff0c;追求最高性价比1&#xff09;根据场景选择GPU2&#xff09;训练或微调所需显卡&#xff08;以Falcon为例子&#xff09;3&#xff09;服务器价格计算器 二、全球大模型了解1&#xff09;llm所有模型2&#xff09;模型综合排…

基于Python+Django,开发的一个在线教育系统

一、项目简介 使用Python的web框架Django进行开发的一个在线教育系统&#xff01; 二、所需要的环境与组件 Python3.6 Django1.11.7 Pymysql Mysql pure_pagination DjangoUeditor captcha xadmin crispy_forms 三、安装 1. 下载项目后进入项目目录cd Online-educ…