【深度学习基础模型】双向循环神经网络(Bidirectional Recurrent Neural Networks, BiRNN)详细理解并附实现代码。

news2024/11/17 19:41:38

【深度学习基础模型】双向循环神经网络(Bidirectional Recurrent Neural Networks, BiRNN)

【深度学习基础模型】双向循环神经网络(Bidirectional Recurrent Neural Networks, BiRNN)详细理解并附实现代码。


文章目录

  • 【深度学习基础模型】双向循环神经网络(Bidirectional Recurrent Neural Networks, BiRNN)
  • 1.双向循环神经网络(Bidirectional Recurrent Neural Networks, BiRNN)原理详解
    • 1.1 BiRNN 概述
    • 1.2 BiRNN 的工作原理
    • 1.3 双向 LSTM 和双向 GRU
    • 1.4 BiRNN 的应用场景
  • 2.Python 实现双向 LSTM(BiLSTM)的实例
    • 2.1BiLSTM 实现及应用实例
    • 2.2代码解释
  • 3.总结


参考地址:https://www.asimovinstitute.org/neural-network-zoo/
论文地址:https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=650093

欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!

1.双向循环神经网络(Bidirectional Recurrent Neural Networks, BiRNN)原理详解

在这里插入图片描述

1.1 BiRNN 概述

双向循环神经网络(BiRNN)是 RNN 的扩展版本,它不仅利用序列的过去信息(正向传递),还利用未来信息(反向传递)。这种双向机制使得 BiRNN 在需要全局上下文理解的任务中更加有效,如自然语言处理和序列标注任务。

1.2 BiRNN 的工作原理

在单向 RNN 中,隐藏状态 ht只依赖于前一个时刻的状态 ht-1。在 BiRNN 中,网络会在两个方向上处理输入序列:

  • **前向传播:**从序列的第一个时间步到最后一个时间步。
  • **反向传播:**从序列的最后一个时间步到第一个时间步。

最终的隐藏状态是这两个方向的输出组合在一起。BiRNN 的公式为:

  • 前向 RNN:

h t → = f ( W x x t + W h h t − 1 → + b h ) \overrightarrow{h_t}=f(W_xx_t+W_h\overrightarrow{h_{t-1}}+b_h) ht =f(Wxxt+Whht1 +bh)

  • 前向 RNN:

h t ← = f ( W x x t + W h h t − 1 ← + b h ) \overleftarrow{h_t}=f(W_xx_t+W_h\overleftarrow{h_{t-1}}+b_h) ht =f(Wxxt+Whht1 +bh)

  • 最终输出:

h t = [ h t → ; h t ← ] h_t=[\overrightarrow{h_t};\overleftarrow{h_t}] ht=[ht ;ht ]

其中 f f f是激活函数, h t → \overrightarrow{h_t} ht h t ← \overleftarrow{h_t} ht 分别表示前向和后向的隐藏状态。

1.3 双向 LSTM 和双向 GRU

双向 LSTM(BiLSTM)和双向 GRU(BiGRU)基于 LSTM 和 GRU,增加了对未来信息的捕捉能力。双向 LSTM 通过在正向和反向同时应用 LSTM 网络结构,进一步提升了处理长序列的能力,适用于机器翻译、文本生成等任务

  • BiLSTM:在 LSTM 的基础上,正向和反向的 LSTM 共同作用,整合全局信息。
  • BiGRU:与 BiLSTM 类似,使用 GRU 作为网络的基本单元。

1.4 BiRNN 的应用场景

双向网络尤其适用于需要对整个序列进行全面理解的任务:

  • 自然语言处理 (NLP): 如命名实体识别、序列标注、文本分类、语言建模等。
  • 语音识别: 通过捕捉前后音素之间的关联,提升语音识别的精度。
  • 机器翻译: 可以通过考虑上下文,生成更加精准的翻译。

2.Python 实现双向 LSTM(BiLSTM)的实例

我们使用 PyTorch 来实现一个基于 BiLSTM 的文本分类模型。

2.1BiLSTM 实现及应用实例

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# 构造简单的示例数据集
# 假设有两个类别的句子,分别标注为 0 和 1
X = [
    [1, 2, 3, 4],     # "I love machine learning"
    [5, 6, 7, 8],     # "deep learning is great"
    [1, 9, 10, 11],   # "I hate spam emails"
    [12, 13, 14, 15]  # "phishing attacks are bad"
]
y = [0, 0, 1, 1]  # 标签

# 转换为 Tensor 格式
X = torch.tensor(X, dtype=torch.long)
y = torch.tensor(y, dtype=torch.long)

# 定义数据集和数据加载器
dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# 定义 BiLSTM 模型
class BiLSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super(BiLSTMModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.embedding = nn.Embedding(input_size, hidden_size)  # 嵌入层
        self.bilstm = nn.LSTM(hidden_size, hidden_size, num_layers, 
                              batch_first=True, bidirectional=True)  # 双向 LSTM
        self.fc = nn.Linear(hidden_size * 2, output_size)  # 全连接层, *2 是因为双向 LSTM
        
    def forward(self, x):
        # 初始化隐藏状态和单元状态
        h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)
        
        # 嵌入层
        out = self.embedding(x)
        
        # 通过双向 LSTM
        out, _ = self.bilstm(out, (h0, c0))
        
        # 取最后一个时间步的输出(双向,前向和后向拼接)
        out = out[:, -1, :]
        
        # 全连接层进行分类
        out = self.fc(out)
        return out

# 模型参数
input_size = 16  # 假设词汇表有 16 个词
hidden_size = 8  # 隐藏层维度
output_size = 2  # 输出为二分类
num_layers = 1   # LSTM 层数

# 创建模型
model = BiLSTMModel(input_size, hidden_size, output_size, num_layers)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# 训练模型
num_epochs = 20
for epoch in range(num_epochs):
    for data, labels in dataloader:
        # 前向传播
        outputs = model(data)
        loss = criterion(outputs, labels)
        
        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    if (epoch+1) % 5 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

# 测试模型
with torch.no_grad():
    test_sentence = torch.tensor([[1, 2, 3, 4]])  # 测试句子 "I love machine learning"
    prediction = model(test_sentence)
    predicted_class = torch.argmax(prediction, dim=1)
    print(f'Predicted class: {predicted_class.item()}')

2.2代码解释

1.定义 BiLSTM 模型:

  • self.embedding = nn.Embedding(input_size, hidden_size):将输入的单词索引转换为高维向量表示。
  • self.bilstm = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True, bidirectional=True):定义双向 LSTM 层,bidirectional=True 表示双向 LSTM。
  • self.fc = nn.Linear(hidden_size * 2, output_size):全连接层将双向 LSTM 的输出映射为分类输出,隐藏层维度乘以 2 是因为有前向和后向两个方向的输出。

2.BiLSTM 的前向传播:

  • h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device)c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(x.device):初始化 LSTM 的隐藏状态和细胞状态。由于是双向的,所以乘以 2。
  • out, _ = self.bilstm(out, (h0, c0)):通过双向 LSTM 层,out 是每个时间步的输出。
  • out = out[:, -1, :]:取最后一个时间步的输出作为最终输出。
  • out = self.fc(out):通过全连接层进行分类。

3.数据集与加载器:

  • 构建一个简单的二分类文本数据集,将其转换为 PyTorch 的 TensorDatasetDataLoader,方便训练模型。

4.训练与测试:

  • 使用 Adam 优化器和交叉熵损失函数训练模型,在每 5 个 epoch 打印一次损失。
  • 测试阶段通过输入测试句子,输出模型的预测分类结果。

3.总结

双向 RNN(BiRNN)以及其扩展的双向 LSTM(BiLSTM)和双向 GRU(BiGRU)在自然语言处理、语音识别和机器翻译等领域表现出色。通过捕捉序列的全局上下文信息,双向网络能够更好地理解数据序列。我们通过 Python 和 PyTorch 实现了一个 BiLSTM 模型,展示了它在文本分类中的应用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2169705.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用 Llama 3.1 和 Qdrant 构建多语言医疗保健聊天机器人的步骤

长话短说: 准备好深入研究: 矢量存储的复杂性以及如何利用 Qdrant 进行高效数据摄取。掌握 Qdrant 中的集合管理以获得最佳性能。释放上下文感知响应的相似性搜索的潜力。精心设计复杂的 LangChain 工作流程以增强聊天机器人的功能。将革命性的 Llama …

虚幻蓝图Ai随机点移动

主要函数: AI MoveTo 想要AI移动必须要有 导航网格体边界体积 (Nav Mesh Bounds Volume) , 放到地上放大 , 然后按P键 , 可以查看范围 然后创建一个character类 这样连上 AI就会随机运动了 为了AI移动更自然 , 取消使用控制器旋转Yaw 取消角色移动组件 的 使用控制器所需的…

风扇模块(直流5V STM32)

目录 一、介绍 二、传感器原理 1.原理图 2.引脚描述 三、程序设计 main.c文件 fan.h文件 fan.c文件 四、实验效果 五、资料获取 项目分享 一、介绍 直流风扇(Fan),具有高转速、大风量、低噪音、低能耗和低震动的特点,有DC5V和12V两种型号可供…

【HarmonyOS】Web组件同步与异步数据获取

Web组件交互同步与异步获取数据的方式示例 【html测试文件】src/main/resources/rawfile/Page04.html <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><script>let isEnvSupported CSS in window &&…

云上攻防 | AWS中的常见 Cognito 配置错误

引言 AWS Cognito 是由亚马逊网络服务&#xff08;AWS&#xff09;提供的全托管服务&#xff0c;旨在简化 Web 和移动应用程序的用户认证和授权过程。它提供了一整套功能来处理用户注册、登录和用户管理&#xff0c;免去了开发人员从头构建这些功能的需求。 尽管本文讨论的攻…

8.11 矢量图层线要素单一符号使用二(箭头)

8.11 矢量图层线要素单一符号使用二(箭头)_qgis箭头-CSDN博客 目录 前言 箭头&#xff08;Arrow&#xff09; QGis设置线符号为箭头(Arrow) 二次开发代码实现 总结 前言 本章介绍矢量图层线要素单一符号中箭头&#xff08;Arrow&#xff09;的使用说明&#xff1a;文章中…

等保2.0数据库测评之达梦数据库测评

一、达梦数据库介绍 达梦数据库管理系统属于新一代大型通用关系型数据库&#xff0c;全面支持 ANSI SQL 标准和主流编程语言接口/开发框架。行列融合存储技术&#xff0c;在兼顾 OLAP 和 OLTP 的同时&#xff0c;满足 HTAP 混合应用场景。 本次安装环境为Windows10专业版操作…

华夏ERP3.1权限绕过代码审计

POC: /jshERP-boot/user/getAllList;.ico 调试分析poc: 这是poc很明显就是绕过权限&#xff0c;我们分析filter里面的代码。 Overridepublic void doFilter(ServletRequest request, ServletResponse response,FilterChain chain) throws IOException, ServletException {Htt…

基于Spring Boot的校园管理系统

目录 前言 功能设计 系统实现 获取源码 博主主页&#xff1a;百成Java 往期系列&#xff1a;Spring Boot、SSM、JavaWeb、python、小程序 前言 随着科学技术的飞速发展&#xff0c;社会的方方面面、各行各业都在努力与现代的先进技术接轨&#xff0c;通过科技手段来提高自…

使用API有效率地管理Dynadot域名,设置域名服务器(NS)

前言 Dynadot是通过ICANN认证的域名注册商&#xff0c;自2002年成立以来&#xff0c;服务于全球108个国家和地区的客户&#xff0c;为数以万计的客户提供简洁&#xff0c;优惠&#xff0c;安全的域名注册以及管理服务。 Dynadot平台操作教程索引&#xff08;包括域名邮箱&…

SQL Server的文本和图像函数

新书速览|SQL Server 2022从入门到精通&#xff1a;视频教学超值版_sql server 2022 出版社-CSDN博客 《SQL Server 2022从入门到精通&#xff08;视频教学超值版&#xff09;&#xff08;数据库技术丛书&#xff09;》(王英英)【摘要 书评 试读】- 京东图书 (jd.com) SQL Se…

【Python】Ajenti:轻量级、强大的服务器管理面板

在现代服务器管理中&#xff0c;管理员们经常需要通过命令行执行各种任务&#xff0c;这不仅耗时&#xff0c;而且对不熟悉 Linux 系统的用户来说并不友好。为了更高效地管理服务器、网站和应用&#xff0c;借助一个功能强大的管理面板是非常有必要的。Ajenti 就是这样一款轻量…

MySql数据库---判断函数,和窗口结合的函数,窗口函数

思维导图 判断函数 if(expr,v1,v2): 表达式结果为true返回v1,否则返回v2 ifnull(列名,dv): 列值为null返回dv,否则返回列值. nullif(expr1,expr2): 表达式1表达式2返回null,不等于返回表达式1的值. 窗口函数 作用: 可以为表新增一列,新增的列是什么取决于over()函数前面的函…

Spring Boot入门到精通:网上购物商城系统

第3章 系统分析 3.1 可行性分析 在系统开发之初要进行系统可行分析&#xff0c;这样做的目的就是使用最小成本解决最大问题&#xff0c;一旦程序开发满足用户需要&#xff0c;带来的好处也是很多的。下面我们将从技术上、操作上、经济上等方面来考虑这个系统到底值不值得开发。…

Cisco Secure Firewall Management Center Virtual 7.6.0 发布下载,新增功能概览

Cisco Secure Firewall Management Center Virtual 7.6.0 - 思科 Firepower 管理中心软件 Firepower Management Center Software for ESXi & KVM 请访问原文链接&#xff1a;https://sysin.org/blog/cisco-fmc-7/&#xff0c;查看最新版。原创作品&#xff0c;转载请保留…

WPS中让两列数据合并的方法

有这样一个需求&#xff0c;就是把A列数据和B列数据进行合并&#xff08;空单元格略过&#xff09;具体实现效果如图下&#xff1a; 该如何操作呢&#xff1f; 首先在新的一列第一个单元格中输入公式"A1&B1" 然后回车&#xff0c;就出现了两列单元格数据合并的效…

人员个体检测、PID行人检测、行人检测算法样本

人员个体检测算法主要用于视频监控、安全防范、人流统计、行为分析等领域&#xff0c;通过图像识别技术来检测和识别视频或图像中的人员个体。这种技术可以帮助管理者实时监控人员活动&#xff0c;确保安全和秩序&#xff0c;提高管理效率。 一、技术实现 人员个体检测算法通常…

光耦——连接半导体创新的桥梁

半导体技术作为现代科技的重要支柱之一&#xff0c;在电子、通信、能源等领域都有着广泛的应用。而在半导体领域&#xff0c;光耦作为一种重要的光电器件&#xff0c;正以其独特的优势和广泛的应用领域&#xff0c;为半导体创新注入新的活力&#xff0c;成为连接半导体创新的桥…

IMX6UL开发板中断实验(三)

在上一节我们编写完成了中断驱动文件和中断驱动头文件&#xff0c;那么这一讲我们将继续中断实验 下面就是GPIO的中断设置&#xff0c;第一步要设置中断GPIO的触发方式&#xff0c;首先我们先看到寄存器&#xff0c;一共有GPIOx_ICR1和ICR2&#xff0c; 图如上&#xff0c;ICR1…

TortoiseGit 下载和安装

下载 1&#xff0c;下载路径 Download – TortoiseGit – Windows Shell Interface to Git 2&#xff0c;选择windows64的&#xff0c; 3&#xff0c;下载完成后 安装 1&#xff0c;双击运行&#xff0c;点击next 2&#xff0c;点击next 3&#xff0c;点击next 4&#xff0…