循环神经网络笔记

news2024/12/23 18:45:15

循环神经网络学习

RNN训练方法–BPTT

BPTT (Backpropagation Through Time),这是一种用于训练循环神经网络(RNNs)的算法。由于 RNNs 能够处理序列数据,并且在每个时间步上都有内部状态,因此需要一种特殊的方法来计算梯度并更新权重。BPTT 就是这样一种方法,它扩展了标准的反向传播算法,以适应时间序列上的依赖关系。

BPTT 的基本原理

  1. 前向传播:首先,输入序列通过网络进行前向传播。对于每个时间步,RNN 会根据当前的输入和前一时间步的状态计算出新的状态,并可能产生输出。
  2. 损失计算:一旦整个序列被处理完毕,就可以计算总损失。这个损失通常是基于模型在所有时间步上的预测与实际目标之间的差异。
  3. 反向传播:然后,从最后一个时间步开始,计算损失关于每个时间步上的权重的梯度,并将这些梯度回传到前面的时间步。这个过程可以看作是在时间维度上展开 RNN,然后像普通前馈神经网络一样应用反向传播。
  4. 权重更新:最后,使用累积的梯度信息来更新网络中的权重,以便减少下一次迭代时的损失。

RNN存在的问题

  1. 梯度消失/爆炸
    • 在训练过程中,当通过时间反向传播误差时,由于链式法则的应用,梯度会随着时间步长的增加而被反复乘以权重矩阵中的值。如果这些值小于1,则梯度会趋向于0(梯度消失),导致远离当前时间步的信息无法对更新产生影响;如果这些值大于1,则梯度会趋向于无穷大(梯度爆炸),可能导致数值不稳定。
    • 解决方案:使用梯度裁剪来防止梯度爆炸,以及引入更复杂的结构如LSTM或GRU来缓解梯度消失问题。
  2. 长期依赖问题
    • 由于梯度消失,RNN很难学习到远距离的时间依赖关系。也就是说,对于较早时间点的信息,RNN可能无法有效地将其与当前时间点的信息联系起来。
    • 解决方案:LSTM和GRU等架构通过引入记忆单元和门控机制来更好地捕捉长期依赖性。
  3. 计算效率低
  4. 内存消耗
  5. 难以训练
    • 需要采用适当的初始化方法(如Xavier/Glorot初始化)、正则化技术(如Dropout)和优化算法(如Adam)可以帮助改善训练过程。
  6. 固定的上下文窗口

LSTM

长短期记忆网络(Long Short-Term Memory,LSTM)是一种特别设计来解决长期依赖问题的循环神经网络(RNN)架构。

LSTM的核心思想是引入了称为**“细胞状态”**(cell state)的概念,该状态可以在时间步长中被动态地添加或删除信息。能够有效地记住信息并控制何时让信息通过或忘记。

LSTM的结构:

在这里插入图片描述

LSTM有通过精心设计的称作“门”的结构来去除或者增加信息到细胞状态的能力。
门是一种让信息选择式通过的方法。他们包含一个sigmoid神经网络层和一个pointwise乘法操作。

Sigmoid层输出0到1之间的数值,描述每个部分有多少量可以通过。
0代表“不许任何量通过”
1代表“允许任何量通过”
LSTM 拥有三个门,来保护和控制细胞状态。

门控机制

LSTM 单元由几个部分组成,主要包括:

  • 输入门 (Input Gate):决定多少新输入的信息会被存储到单元状态中。
  • 遗忘门 (Forget Gate):决定哪些信息应该从单元状态中被丢弃。
  • 输出门 (Output Gate):基于当前单元状态和输入信息,决定输出什么值。
  • 单元状态 (Cell State):这是LSTM的核心,它是一个贯穿整个序列的信息传递通道,允许信息在整个序列中流动而不受太多干扰。

工作原理

  1. 遗忘门:首先,LSTM会决定要忘记哪些信息。这一步通过一个sigmoid层完成,该层接收上一时间步的隐藏状态 ht−1 和当前输入 xt,然后输出一个介于0和1之间的数字向量。这个向量中的每个元素表示对应位置上的信息被保留的程度(1表示完全保留,0表示完全忘记)。

    在这里插入图片描述

  2. 输入门:接下来,LSTM需要确定新的信息如何加入到单元状态中。这分为两个步骤:

    • 一个新的候选值向量 C~t 通过tanh层生成。
    • 一个sigmoid层(称为输入门)决定了这些候选值中的多少将被添加到单元状态中。

    在这里插入图片描述

  3. 更新单元状态:旧的单元状态 Ct−1 乘以遗忘门的输出,然后加上输入门的结果(经过tanh激活的新候选值与输入门输出的逐元素相乘)。这样就得到了更新后的单元状态 Ct。

    在这里插入图片描述

  4. 输出门:最后,LSTM需要确定要输出什么。这同样分为两步:

    • 一个sigmoid层决定单元状态的哪一部分将被输出。
    • 单元状态经过tanh层处理(将值缩放到-1到1之间),然后与输出门的结果进行逐元素相乘,得到最终的隐藏状态 ht。

    在这里插入图片描述

W f , W i , W o , W c 是权重矩阵 , b f , b i , b 0 , b c 是偏置 σ 表示 s i g m o i d 激活函数, ∗ 表示逐元素乘法 遗忘门: f t = σ ( W f [ h t − 1 , x t ] + b f ) 输出门 : i t = σ ( W i [ h t − 1 , x t ] + b i ) 新候选值 : C ~ t = tanh ⁡ ( W c [ h t − 1 , x t ] + b c ) 更新单元状态 : C t = f t ∗ C t − 1 + i t ∗ C ~ t 输出门 : o t = σ ( W o [ h t − 1 , x t ] + b o ) 最终隐藏状态 : h y = o t ∗ tanh ⁡ ( C t ) \\W_f,W_i,W_o,W_c是权重矩阵,b_f,b_i,b_0,b_c是偏置 \\ \sigma表示sigmoid激活函数, *表示逐元素乘法 \\遗忘门:f_t = \sigma (W_f[h_{t-1},x_t]+b_f) \\输出门:i_t = \sigma(W_i[h_{t-1},x_t]+b_i) \\新候选值:\tilde{C}_t = \tanh(W_c[h_t-1,x_t]+b_c) \\更新单元状态:C_t = f_t*C_{t-1}+i_t*\tilde{C}_t \\输出门:o_t = \sigma(W_o[h_{t-1},x_t]+b_o) \\最终隐藏状态:h_y = o_t*\tanh(C_t) Wf,Wi,Wo,Wc是权重矩阵,bf,bi,b0,bc是偏置σ表示sigmoid激活函数,表示逐元素乘法遗忘门:ft=σ(Wf[ht1,xt]+bf)输出门:it=σ(Wi[ht1,xt]+bi)新候选值:C~t=tanh(Wc[ht1,xt]+bc)更新单元状态:Ct=ftCt1+itC~t输出门:ot=σ(Wo[ht1,xt]+bo)最终隐藏状态:hy=ottanh(Ct)

代码实现

原生代码

import numpy as np
import torch


class LSTM:
    def __init__(self, input_size, hidden_size, output_size):
        # 参数:词向量大小,隐藏层大小, 输出类别
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        # 初始化权重,偏置,把结构的W,U拼接在一起
        self.W_f = np.random.rand(hidden_size, input_size+hidden_size)
        self.b_f = np.random.rand(hidden_size)

        self.W_i = np.random.rand(hidden_size, input_size + hidden_size)
        self.b_i = np.random.rand(hidden_size)

        self.W_c = np.random.rand(hidden_size, input_size + hidden_size)
        self.b_c = np.random.rand(hidden_size)

        self.W_o = np.random.rand(hidden_size, input_size + hidden_size)
        self.b_o = np.random.rand(hidden_size)

        # 输出层
        self.W_y = np.random.rand(output_size, hidden_size)
        self.b_y = np.random.rand(output_size)

    def tanh(self, x):
        return np.tanh(x)

    def sigmoid(self, x):
        return 1/(1+np.exp(-x))

    def forward(self, x):
        # 初始化隐藏状态
        h_t = np.zeros((self.hidden_size,))
        # 初始化细胞状态
        c_t = np.zeros((self.hidden_size,))

        h_states = []  # 存储每一个时间步的隐藏状态
        c_states = []  # 存储每一个时间步的细胞状态

        for t in range(x.shape[0]):
            x_t = x[t]  # 获取当前时间步的输入(一个词向量)
            # 将x_t和h_t进行垂直方向拼接
            x_t = np.concatenate([x_t, h_t])

            # 遗忘门 "dot"迷茫中,这里是点积的效果,(5,7)点积(7,)得到的是(5,)
            f_t = self.sigmoid(np.dot(self.W_f, x_t) + self.b_f)

            # 输出门
            i_t = self.sigmoid(np.dot(self.W_i, x_t) + self.b_i)

            # 候选细胞状态
            c_hat_t = self.tanh(np.dot(self.W_c, x_t) + self.b_c)
            # 更新细胞状态, "*"对应位置直接相乘
            c_t = f_t * c_t + i_t * c_hat_t

            # 输出门
            o_t = self.sigmoid(np.dot(self.W_o, x_t) + self.b_o)
            # 更新隐藏状态
            h_t = o_t * self.tanh(c_t)

            # 保存时间步的隐藏状态和细胞状态
            h_states.append(h_t)
            c_states.append(c_t)

        # 输出层,分类类别
        y_t = np.dot(self.W_y, h_t) + self.b_y
        output = torch.softmax(torch.tensor(y_t), dim=0)

        return np.array(h_states), np.array(c_states), output


# 数据输入
x = np.random.rand(3, 2)
hidden_size = 5

# 实例化模型
lstm = LSTM(2, hidden_size, 6)

h_states, c_states, output = lstm.forward(x)

print("h_states:", h_states)
print("h_states_shape:", h_states.shape)
print("c_states:", c_states)
print("c_states_shape:", c_states.shape)
print("output:", output)
print("output_shape:", output.shape)


基于Pytorch API的代码实现

import torch
import torch.nn as nn


# 定义模型
class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(LSTMModel, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.output_size = output_size
        # 调用接口
        self.lstm = nn.LSTM(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # 初始化隐藏状态和细胞
        h0 = torch.zeros(1, x.size(1), self.hidden_size)
        c0 = torch.zeros(1, x.size(1), self.hidden_size)

        # 前向传播
        # out所有时间步的输出结果,state最后时间步的隐藏状态和细胞状态
        out, state = self.lstm(x, (h0, c0))

        out = out[-1]  # 取最后一个时间步的输出
        output = self.fc(out)

        return output


# 模型参数
seq_size, batch_size, input_size = 5, 4, 3
hidden_size, output_size = 6, 7

model = LSTMModel(input_size, hidden_size, output_size)

# 模拟数据
x = torch.randn(seq_size, batch_size, input_size)

output = model(x)

print(output)
print(output.shape)

序列池化(平均池化和最大池化)

import torch
import torch.nn as nn

# 平均池化
# 输入数据
input_data = torch.randn(2, 3, 4)

# 调用平均池化
avg_pool = nn.AdaptiveAvgPool1d(1)

# 调整形状去匹配池化的输入
input_data = input_data.permute(0, 2, 1)  # (batch,seq,dim)->(batch,dim,seq)

output = avg_pool(input_data)

print(output)
print(output.shape)


# 最大池化
# 输入数据
input_data = torch.randn(2, 3, 4)

# 调用平均池化
max_pool = nn.AdaptiveMaxPool1d(1)

# 调整形状去匹配池化的输入
input_data = input_data.permute(0, 2, 1)  # (batch,seq,dim)->(batch,dim,seq)

output = max_pool(input_data)

print(output)
print(output.shape)


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2167828.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

南京自闭症寄宿学校:打造温馨的第二家

南京自闭症寄宿学校的愿景与广州星贝育园的温馨实践 在探讨自闭症儿童教育的广阔领域中,寄宿制学校以其独特的优势,为这些特殊的孩子提供了全方位的支持与关怀,致力于打造一个温馨如家的第二生活环境。虽然本文的主题是围绕南京自闭症寄宿学…

Chirp通过Sui让IoT世界变得更简单

据估计,未来十年内,联网设备的数量将增长到近400亿台。无论是追踪共享出行车辆的移动、改善食品追溯性、监控制造设施,还是保障家庭安全,物联网 ( Internet of Things,IoT) 对企业和消费者来说都已经成为一项关键技术。…

刷题学习日记 (1) - SWPUCTF

写这篇文章主要是想看看自己一个下午能干啥,不想老是浪费时间了,所以刷多少题我就会写多少题解,使用nss随机刷题,但是今天下午不知道为啥一刷都是SWPUCTF的。 [SWPUCTF 2021 新生赛]gift_F12 控制台ctrlf搜索flag即可&#xff0…

什么是竞争条件?

竞争条件,简单来说就是多个进程同时访问同一个共享资源,导致出现预期结果以外的错误的情况。 出现竞争条件的本质原因是cpu对程序的调度是没有特定规律的,某一时刻cpu处理哪个进程是不确定的。 简单写一个测试程序,先需要子进程和…

ubuntu安装emqx

目录 1.预先下载好emqx压缩包 2.使用tar命令解压 3.进入bin目录 5.放开访问端口18083 6.从通过ip地址访问emqx后台 7.默认用户名密码为admin/public 8.登录后台 9.资源包绑定在此博文可自取 1.预先下载好emqx压缩包 2.使用tar命令解压 sudo tar -xzvf emqx-5.0.8-el8-…

手机轻松解压 RAR 文件指南

手机通常不直接支持 RAR 文件打开,主要有以下几个原因。首先,手机操作系统的设计初衷并非为了处理各种复杂的压缩文件格式。 大多数手机内置的文件管理器主要侧重于管理手机内部存储和常见的文件类型,如图片、音频、视频等。对于像 RAR 这样…

【UR #1】外星人(dp思维技巧)

考虑去除后效性,常用方法排序状态可以直接以答案为状态来判断合法性考虑转移方向,向后转移,选与不选来定向答案 f[i][j]表示前i个数答案为j的方案数 不选i 则加上f[i][j] 的方案数 * (n-i),ai可以在后面随便选。 选…

Python 课程20-Scikit-learn

前言 Scikit-learn 是 Python 中最流行的机器学习库之一,它提供了多种用于监督学习和无监督学习的算法。Scikit-learn 的特点是简单易用、模块化且具有高效的性能。无论是初学者还是专业开发者,都可以借助它进行快速原型设计和模型开发。 在本教程中&a…

为何专利对企业创新与竞争至关重要?

在当今这个技术飞速发展的时代,每一个创新的火花都可能成为推动行业进步的关键力量。然而,创新并非一蹴而就,它需要时间、资金与智慧的共同投入,更需要一套完善的保护机制来确保其成果不被轻易窃取或模仿。这一重任,便…

WebPage-Bootstrap框架(container类,container-fluid类,栅格系统)

1.Bootstrap Bootstrap为页面内容和栅格系统包裹了一个.container容器,框架预先定义类 1.1container类 响应式布局容器的宽度 手机-小于768px 宽度设置100%; 平板-大于等于768px 设置宽度为750px 桌面显示器-大于等于992px 设置宽度 970px 大屏幕显…

医院排班|医护人员排班系统|基于springboot医护人员排班系统设计与实现(源码+数据库+文档)

医护人员排班系统目录 目录 基于springboot医护人员排班系统设计与实现 一、前言 二、系统功能设计 三、系统实现 医护类型管理 排班类型管理 科室信息管理 医院信息管理 医护信息管理 四、数据库设计 1、实体ER图 2、具体的表设计如下所示: 五、核心代码…

“AI+Security”系列第3期(五):AI技术在网络安全领域的本地化应用与挑战

近日,由安全极客、Wisemodel 社区、InForSec 网络安全研究国际学术论坛和海升集团联合主办的“AI Security”系列第 3 期技术沙龙:“AI 安全智能体,重塑安全团队工作范式”活动顺利举行。此次活动吸引了线上线下超过千名观众参与。 在活动中…

shell中对xargs命令传参进行编辑

以文件解压为例,将当前路径下的所有gz文件解压到同名的log文件中,解压命令如下所示: ls *.gz| xargs -n 1 -P 4 -I {} bash -c zcat "{}" > $(echo "{}" | sed "s/gz$/log/g") 执行结果如下图所示&#x…

mamba-yolo模型的深度学习环境配置

本文将介绍如何配置目标检测模型mamba-yolo的深度学习环境 1. 环境要求 Python > 3.9 (本文使用python-3.11) CUDA > 11.6 (本文使用CUDA-11.8) Pytorch > 1.12.1 (本文使用torch-2.4.0) Linu…

【C++】STL标准模板库容器——set

🦄个人主页:修修修也 🎏所属专栏:C ⚙️操作环境:Visual Studio 2022 目录 📌关联式容器set(集合)简介 📌set(集合)的使用 🎏set(集合)的模板参数列表 🎏set(集合)的构造函数 🎏set(集合)的迭代…

JavaScript异步编程:async、await的使用

async 和 await 是在 ECMAScript 2017 (ES7) 中引入的特性,用于处理异步操作。它们允许你以一种更加简洁和同步的方式来编写异步代码。 async 函数表示它会返回一个 Promise,而 await 关键字用于等待一个 Promise 解决。 关于 promise 的详细介绍&#…

蜂窝物联网全网通sim卡切网技术方案软硬件实现教程(设备根据基站信号质量自动切网)

01 物联网系统中为什么要使用三合一卡 三合一卡为用户解决了单一运营商网络无法全覆盖的缺陷,避免再次采购的经济成本以及时间成本和因没有信号设备停止工作造成的损失,保证仅需一次采购并提高设备工作效率和入网活跃度。例如下面地区的设备&#xff0…

Spring Web MVC课后作业

目录 1.加法计算器 2.⽤户登录 3.留⾔板 1.加法计算器 (1)需求分析 加法计算器功能, 对两个整数进⾏相加, 需要客⼾端提供参与计算的两个数, 服务端返回这两个整数计算 的结果。 (2)接⼝定义 请求路径: calc/sum 请…

Java框架学习(mybatis)(01)

简介:以本片记录在尚硅谷学习ssm-mybatis时遇到的小知识 详情移步:想参考的朋友建议全部打开相互配合学习! 官方文档: MyBatis中文网https://mybatis.net.cn/index.html 学习视频: 067-mybatis-介绍和对比_哔哩哔…

Linux本地服务器搭建开源监控服务Uptime Kuma与远程监控实战教程

文章目录 前言**主要功能**一、前期准备本教程环境为:Centos7,可以跑Docker的系统都可以使用本教程安装。本教程使用Docker部署服务,如何安装Docker详见: 二、Docker部署Uptime Kuma三、实现公网查看网站监控四、使用固定公网地址…