LSTM的变体

news2024/10/9 11:33:25

一、GRU

1、什么是GRU

门控循环单元(GRU)是一种循环神经网络(RNN)的变体,它通过引入门控机制来控制信息的流动,从而有效地解决了传统RNN中的梯度消失问题。GRU由Cho等人在2014年提出,它简化了LSTM的结构,将遗忘门和输入门合并为一个更新门,并增加了一个重置门,同时合并了单元状态和隐藏状态,使得模型更加简洁,训练速度更快,且在性能上与LSTM相当。

2、GRU的核心

核心在于两个门:更新门(update gate)和重置门(reset gate)。更新门控制着从前一时刻的状态信息中保留多少到当前状态,而重置门决定着前一状态有多少信息被写入到当前的候选集中。这种结构使得GRU在处理长序列数据时能够更好地捕捉长期依赖关系,同时减少了模型参数,提高了计算效率。

3、GRU的应用

应用非常广泛,包括但不限于自然语言处理(NLP)、语音识别、图像处理等领域。在NLP领域,GRU可以用于语言建模、情感分析、机器翻译等任务;在语音识别领域,GRU可以用于语音信号的特征提取和识别;在图像处理领域,GRU可以用于图像分类、目标检测等任务。GRU的简洁性和效率使其在处理大规模序列数据时具有优势。

在选择GRU和LSTM时,通常考虑的因素包括任务的复杂性、数据集的大小以及训练资源。由于GRU参数更少,收敛速度更快,因此在需要快速迭代和实验时,GRU通常是首选。然而,在某些需要对复杂序列依赖关系进行建模的任务中,LSTM可能会表现得更好。

总的来说,GRU是一种强大的循环神经网络架构,它通过引入门控机制来控制信息流,有效地解决了传统RNN的梯度消失问题。GRU的简洁性和效率使其在多种序列建模任务中表现出色,成为了深度学习中处理时序数据的重要工具之一。

4、GRU的工作原理

5、手写代码实现

import numpy as np

class GRU():
    def __init__(self, input_size, hidden_size):
        self.input_size = input_size
        self.hidden_size = hidden_size

        # 初始化参数w和b
        self.W_z = np.random.randn(self.hidden_size, self.input_size + self.hidden_size)
        self.b_z = np.zeros(self.hidden_size)
        # 重置门
        self.W_r = np.random.randn(self.hidden_size, self.input_size + self.hidden_size)
        self.b_r = np.zeros(self.hidden_size)
        # 候选隐藏状态
        self.W_h = np.random.randn(self.hidden_size, self.input_size + self.hidden_size)
        self.b_h = np.zeros(self.hidden_size)

    def tanh(self, x):
        return np.tanh(x)

    def sigmoid(self, x):
        return 1 / (1 + np.exp(-x))

    def forward(self, x):
        h_prev = np.zeros((self.hidden_size,))
        concat_input = np.concatenate([x, h_prev], axis=0)

        z_t = self.sigmoid(np.dot(self.W_z, concat_input) + self.b_z)
        r_t = self.sigmoid(np.dot(self.W_r, concat_input) + self.b_r)

        concat_reset_input = np.concatenate([x, r_t * h_prev], axis=0)

        h_hat_t = self.tanh(np.dot(self.W_h, concat_reset_input) + self.b_h)

        h_t = (1 - z_t) * h_prev + z_t * h_hat_t
        return h_t

二、BiLSTM

1、什么是BiSTM

BiSTM,即双向门控循环单元(Bidirectional Gated Recurrent Unit),是一种循环神经网络(RNN)的变体。它结合了前向和后向的GRU,能够同时处理过去和未来的信息,从而更好地捕捉序列数据中的上下文关系。

在BiSTM中,数据通过两个GRU网络进行处理:一个从左到右(前向),另一个从右到左(后向)。这两个网络的输出然后被拼接或相加,形成最终的特征表示,这个特征表示包含了序列的双向信息。这种结构特别适合于需要理解序列中前后文信息的任务,如文本分类、语音识别、命名实体识别(NER)等。

2、BiSTM的关键特点包括:

  1. 双向信息捕捉:BiSTM能够同时考虑序列中每个元素之前的和之后的上下文信息,这使得它在处理像文本这样的序列数据时非常有效,因为文本中词汇的含义往往受到其前后词汇的影响。

  2. 门控机制:BiSTM继承了GRU的门控机制,包括更新门和重置门,这些门控单元可以控制信息的流动,从而减少无效或噪声信息的干扰,并增强模型对重要信息的记忆能力。

  3. 应用广泛:BiSTM因其强大的序列处理能力而被广泛应用于各种领域,包括自然语言处理(NLP)、语音识别、时间序列分析等。

  4. 模型性能:在某些任务中,BiSTM能够提供比单向GRU或LSTM更好的性能,尤其是在需要捕捉长期依赖关系的任务中。

  5. 模型复杂度:由于BiSTM包含两个GRU网络,其模型参数和计算复杂度相对于单向GRU或LSTM会有所增加,但在很多情况下,这种增加是值得的,因为它能带来更准确的预测结果

 3、手写BiLSTM代码

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class LSTM(nn.Module):
    def __init__(self, vocab_size, target_size, input_size=512, hidden_size=512):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, input_size)
        self.mlp = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size)
        )
        self.lstm = nn.LSTM(hidden_size, hidden_size * 2, num_layers=3, batch_first=True, dropout=0.5)
        self.avg_lstm = nn.AdaptiveAvgPool1d(1)
        self.avg_linear = nn.AdaptiveAvgPool1d(1)
        self.out_linear = nn.Sequential(
            nn.Linear(hidden_size * 2 + hidden_size, hidden_size),
            nn.GELU(),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, target_size)
        )
        self.norm = nn.LayerNorm(hidden_size * 2)

    def forward(self, x, lengths):
        x = self.embedding(x)
        mlp = self.mlp(x)

        pached_embed = pack_padded_sequence(mlp, lengths, batch_first=True, enforce_sorted=False)
        lstm_out, _ = self.lstm(pached_embed)
        lstm_out, _ = pad_packed_sequence(lstm_out, batch_first=True)

        lstm_out = self.norm(lstm_out)
        avg_lstm = self.avg_lstm(lstm_out.permute(0, 2, 1)).squeeze(-1)
        avg_linear = self.avg_linear(mlp.permute(0, 2, 1)).squeeze(-1)

        out = torch.cat([avg_lstm, avg_linear], dim=-1)
        return self.out_linear(out)

class BiLSTM(nn.Module):
    def __init__(self, input_size=512, hidden_size=512, output_size=512):
        super(BiLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.lstm_forward = nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True)
        self.lstm_backward = nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True)

    def forward(self, x):
        out_forward, _ = self.lstm_forward(x)
        out_backward, _ = self.lstm_backward(torch.flip(x, dims=[1]))
        out_backward = torch.flip(out_backward, dims=[1])
        combined_output = torch.cat([out_forward, out_backward], dim=-1)
        return combined_output

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2198902.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

判断回文 python

题目&#xff1a; 输入一个四位数&#xff0c;判断该数是否为回文数&#xff0c;回文数是指正序&#xff08;从左向右&#xff09;和倒序&#xff08;从右向左&#xff09;读都是一样的整数&#xff0c;比如1221。 代码法1&#xff1a; ninput() nint(n) if n<1000 or n&g…

微积分复习笔记 Calculus Volume 1 - 2.2 The Limit of a Function

2.2 The Limit of a Function - Calculus Volume 1 | OpenStax

中控自动化测试实战和实车智能驾驶业务解析

一.中控自动化测试流程及环境搭建 1.中控自动化测试流程 2.中控自动化测试环境的搭建 1.JDK环境配置 安装 Java安装包.生成java\bin jre\bin JAVA_HOME: java目录 c:\java path:%JAVA_HOME%\bin jre\bin 为了后面appium server GUI客户端中的环境配置 2.SDK 配置 pal…

怎么编辑图片?这5款工具教你快速编辑

怎么编辑图片&#xff1f;编辑图片是一项既具创意又实用的技能&#xff0c;它不仅能够提升图片的视觉效果&#xff0c;增强信息的传达力&#xff0c;还能激发无限的创作灵感。通过编辑图片&#xff0c;我们可以轻松调整色彩、添加文字、裁剪构图&#xff0c;甚至创造出令人惊叹…

Oxygen Forensic Detective 17.0 发布,新增功能概览

Oxygen Forensic Detective 17.0 发布&#xff0c;新增功能概览 Oxygen Forensic Detective Windows 17 Multilingual - 领先的一体化数字取证软件 digital forensic software 请访问原文链接&#xff1a;https://sysin.org/blog/oxygen-forensic-detective/&#xff0c;查看…

【学习笔记】SquareLine Studio安装教程(LVGL官方工具)

一.简介与导航&#xff1a; SquareLine Studio是由LVGL官方开发的一款UI设计工具&#xff0c;采用图形化进行界面UI设计&#xff0c;轻易上手。 SquareLine Studio官方网址&#xff1a;https://squareline.io/SquareLine Studio官方文档&#xff1a;https://docs.squareline.io…

车牌检测系统源码分享

车牌检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Computer Vision 研究…

3、Docker搭建MQTT及Spring Boot 3.x集成MQTT

一、前言 本篇主要是围绕着两个点&#xff0c;1、Docker 搭建单机版本 MQTT&#xff08;EMQX&#xff09;&#xff0c;2、Spring Boot 3.x 集成 MQTT&#xff08;EMQX&#xff09;&#xff1b; 而且这里的 MQTT&#xff08;EMQX&#xff09;的搭建也只是一个简单的过程&#x…

为什么现在的大学生很难真正学好LabVIEW编程?

学习LabVIEW编程对大学生来说可能存在以下挑战&#xff1a; 学习曲线陡峭&#xff1a;尽管LabVIEW提供直观的图形化编程环境&#xff0c;便于初学者入门&#xff0c;但要深入掌握其高级功能和复杂应用&#xff0c;仍需要投入大量时间和精力。随着学习的深入&#xff0c;概念和应…

CAN与CANFD的区别

CAN概念&#xff1a; CAN&#xff0c;全称为Controller Area Network&#xff0c;即控制器局域网络&#xff0c;是一种用于汽车电子系统中的串行通信协议。它由德国电气工程师协会&#xff08;Bosch&#xff09;在1983年开发&#xff0c;并在1986年正式推出。CAN协议主要用于汽…

牛客:Holding Two,Inverse Pair,Counting Triangles

Holding Two 题目描述 登录—专业IT笔试面试备考平台_牛客网 ​​运行代码 #include<bits/stdc.h> using namespace std; const int N3e45; string s1,s2; int main(){int n,m;cin>>n>>m;for(int i0;i<m;i){if(i&1){s10;s21;} else{s11;s20;} }fo…

架构师:Spring Cloud Gateway 的技术指南

1、简述 Spring Cloud Gateway 是 Spring Cloud 生态系统中的一个重要组件,作为微服务架构的 API 网关,它为路由、限流、安全、监控等功能提供了全面支持。相比传统的 Zuul 网关,Spring Cloud Gateway 使用了非阻塞的 WebFlux 框架,性能上有了显著提升,并且提供了更现代化…

BLE MESH学习2——自定义MESH网络架构思考

BLE MESH学习2——自定义MESH网络架构思考 基于对WCH CH582这款单片机的了解&#xff0c;其可以实现mesh配网、朋友节点、低功耗节点和中继节点的角色&#xff0c;基本功能无问题。在此基础上&#xff0c;考虑满足IoT需求的MESH架构设计&#xff0c;作为后续设计的“白皮书”。…

构建流媒体管道:利用 Docker 部署 Nginx-RTMP 从 FFmpeg RTMP 推流到 HLS 播放的完整流程

最近要实现一个类似导播台的功能&#xff0c;于是我先用 FFmpeg 实现一个参考对照的 Demo&#xff0c;我将其整理为一篇文章&#xff0c;方便后续大家或者和自己参考&#xff01; 1、软件工具介绍 本次部署相关软件 / 工具如下&#xff1a; FFmpeg&#xff1a;全称是 Fast Fo…

java脚手架系列1--模块化、多环境

之所以想写这一系列&#xff0c;是因为之前工作过程中有几次项目是从零开始搭建的&#xff0c;而且项目涉及的内容还不少。在这过程中&#xff0c;遇到了很多棘手的非业务问题&#xff0c;在不断实践过程中慢慢积累出一些基本的实践经验&#xff0c;认为这些与业务无关的基本的…

kkFileView 4.4.0最新版本发行版安装包部署及使用文档

kkFileView为文件文档在线预览解决方案&#xff0c;该项目使用流行的spring boot搭建&#xff0c;易上手和部署&#xff0c;基本支持主流办公文档的在线预览&#xff0c;如doc,docx,xls,xlsx,ppt,pptx,pdf,txt,zip,rar,图片,视频,音频等等 一. 下载&部署 下载最新发行包&…

网络层协议 --- IP

序言 在这篇文章中我们将介绍 IP协议&#xff0c;经过这篇文章的学习&#xff0c;我们就会了解运营商到底是如何为我们提供服务的以及平时我们所说的内网&#xff0c;公网到底又是什么&#xff0c;区别是什么&#xff1f; IP 地址的基本概念 1. IP 地址的定义 每一个设备接入…

Java中String类的常见操作Api

目录 String类的常见操作 1).int indexOf (char 字符) 2).int lastIndexOf(char 字符) 3).int indexOf(String 字符串) 4).int lastIndexOf(String 字符串) 5).char charAt(int 索引) 6).Boolean endWith(String 字符串) 7).int length() 8).boolean equals(T 比较对象) 9).b…

【2024】作为前端开发,必须掌握的 Vue3 的 5 个组合式 API 方法

引言&#xff1a;2024 你还不知道 Vue3 的 defineProps、defineEmits、defineExpose、defineOptions 和 defineSlots 吗&#xff1f;这几个 Vue3 组合式 API 方法不仅开发常用&#xff08;涉及组件通信、组件复用等&#xff09;&#xff0c;在面试中也是不可或缺的一部分&#…

SpringBoot实现电子文件签字+合同系统!

一、前言 二、项目源码及部署 1、项目结构及使用框架 2、项目下载及部署 三、功能展示 一、前言 今天公司领导提出一个功能,说实现一个文件的签字+盖章功能,然后自己进行了简单的学习,对文档进行数字签名与签署纸质文档的原因大致相同,数字签名通过使用计算机加密来验证 (…