Seq2Seq - GRU补充讲解

news2025/4/16 1:15:02

nn.GRU 是 PyTorch 中实现门控循环单元(Gated Recurrent Unit, GRU)的模块。GRU 是一种循环神经网络(RNN)的变体,用于处理序列数据,能够更好地捕捉长距离依赖关系。

⭐重点掌握输入输出部分输入张量:input、初始隐藏状态:h_0、输出张量:output、最终隐藏状态:h_n

nn.GRU 的参数

nn.GRU 的完整定义如下:

torch.nn.GRU(
    input_size,
    hidden_size,
    num_layers=1,
    bias=True,
    batch_first=False,
    dropout=0.0,
    bidirectional=False
)
1. input_size
  • 类型int

  • 含义:输入特征的维度。

  • 解释:假设输入序列的形状为 [batch_size, seq_len, input_size],其中:

    • batch_size 是批量大小。

    • seq_len 是序列的长度。

    • input_size 是每个时间步输入特征的维度。

  • 示例:如果输入是一个单词序列,且每个单词通过嵌入层映射为 128 维的向量,则 input_size=128

2. hidden_size
  • 类型int

  • 含义:隐藏状态的维度。

  • 解释:GRU 的隐藏状态维度决定了模型内部状态的大小。输出的隐藏状态形状为 [batch_size, seq_len, hidden_size]

  • 示例:如果 hidden_size=256,则每个时间步的隐藏状态是一个 256 维的向量。

3. num_layers
  • 类型int

  • 默认值1

  • 含义:GRU 的层数。

  • 解释:可以堆叠多个 GRU 层,每一层的输出作为下一层的输入。增加层数可以增强模型的表达能力,但也会增加计算复杂度。

  • 示例:如果 num_layers=2,则有两层 GRU,第一层的输出会传递给第二层。

4. bias
  • 类型bool

  • 默认值True

  • 含义:是否在 GRU 的权重矩阵中添加偏置项。

  • 解释:如果设置为 False,则在计算过程中不会使用偏置项,这可以减少模型的参数数量,但可能会影响模型的性能。

5. batch_first
  • 类型bool

  • 默认值False

  • 含义:输入和输出张量的第一个维度是否是批量大小。

  • 解释

    • 如果 batch_first=True,输入和输出的形状为 [batch_size, seq_len, input_size]

    • 如果 batch_first=False,输入和输出的形状为 [seq_len, batch_size, input_size]

  • 示例:在大多数实际应用中,为了方便处理批量数据,通常设置 batch_first=True

6. dropout
  • 类型float

  • 默认值0.0

  • 含义:在 GRU 的每一层之间应用的 dropout 概率。

  • 解释dropout 用于防止过拟合,通过在训练过程中随机丢弃一些神经元的输出来增强模型的泛化能力。该参数仅在 num_layers > 1 时有效。

  • 示例:如果 dropout=0.5,则在每一层之间有 50% 的概率丢弃神经元的输出。

7. bidirectional
  • 类型bool

  • 默认值False

  • 含义:是否使用双向 GRU。

  • 解释

    • 如果 bidirectional=True,则 GRU 会同时处理序列的正向和反向信息,输出的隐藏状态维度会加倍(2 * hidden_size)。

    • 如果 bidirectional=False,则 GRU 只处理序列的正向信息。

  • 示例:在一些任务中(如文本分类、机器翻译等),双向 GRU 可以更好地捕捉上下文信息。

输入和输出

输入
  • 输入张量input

    • 形状[batch_size, seq_len, input_size](如果 batch_first=True)或 [seq_len, batch_size, input_size](如果 batch_first=False)。

    • 含义:输入序列,每个时间步的特征维度为 input_size

  • 初始隐藏状态h_0

    • 形状[num_layers * num_directions, batch_size, hidden_size]

    • 含义:初始隐藏状态,num_directions 是方向的数量(单向为 1,双向为 2)。

    • 默认值:如果未提供,则默认为全零张量。

输出
  • 输出张量output

    • 形状[batch_size, seq_len, num_directions * hidden_size](如果 batch_first=True)或 [seq_len, batch_size, num_directions * hidden_size](如果 batch_first=False)。

    • 含义:每个时间步的隐藏状态。

  • 最终隐藏状态h_n

    • 形状[num_layers * num_directions, batch_size, hidden_size]

    • 含义:序列处理结束后的最终隐藏状态。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2334774.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

从零开始学Python游戏编程19-游戏循环模式1

在《从零开始学Python游戏编程18-函数3》中提到,可以对游戏代码进行重构,把某些代码写入函数中,主程序再调用这些函数,这样使得代码程序更容易理解和维护。游戏循环模式实际上也是把代码写入到若干个函数中,通过循环的…

Java获取终端设备信息工具类

在很多场景中需要获取到终端设备的一些硬件信息等,获取的字段如下: 返回参数 参数含义备注systemName系统名称remoteIp公网iplocalIp本地ip取IPV4macmac地址去掉地址中的"-“或”:"进行记录cpuSerialcpu序列号hardSerial硬盘序列号drive盘符…

【Linux网络与网络编程】08.传输层协议 UDP

传输层协议负责将数据从发送端传输到接收端。 一、再谈端口号 端口号标识了一个主机上进行通信的不同的应用程序。在 TCP/IP 协议中,用 "源IP","源端口号","目的 IP","目的端口号"&…

没音响没耳机,把台式电脑声音播放到手机上

第一步,电脑端下载安装e2eSoft VSC虚拟声卡(安装完成后关闭,不要点击和设置) 第二步,电脑端下载安装(SoundWire Server)(安装完成后不要关闭,保持默认配置) 第…

XDocument和XmlDocument的区别及用法

因为这几天用到了不熟悉的xml统计数据,啃了网上的资料解决了问题,故总结下xml知识。 1.什么是XML?2.XDocument和XmlDocument的区别3.XDocument示例1示例2:示例3: 4.XmlDocument5.LINQ to XML6.XML序列化(Serialize)与反序列化(De…

Blender安装基础使用教程

本博客记录安装Blender和基础使用,可以按如下操作来绘制标靶场景、道路标识牌等。 目录 1.安装Blender 2.创建面板资源 步骤 1: 设置 Blender 场景 步骤 2: 创建一个平面 步骤 3: 将 PDF 转换为图像 步骤 4-方法1: 添加材质并贴图 步骤4-方法2:创…

【Git】从零开始使用git --- git 的基本使用

哪怕是野火焚烧,哪怕是冰霜覆盖, 依然是志向不改,依然是信念不衰。 --- 《悟空传》--- 从零开始使用git 了解 Gitgit创建本地仓库初步理解git结构版本回退 了解 Git 开发场景中,文档可能会经历若干版本的迭代。假如我们不进行…

Android 中支持旧版 API 的方法(API 30)

Android 中最新依赖库的版本支持 API 31 及以上版本,若要支持 API30,则对应的依赖库的版本就需要使用旧版本。 可通过修改模块级 build.gradle 文件来进行适配。 1、android 标签的 targetSdk 和 compileSdk 版本号 根据实际目标设备的 android 版本来…

[特殊字符] Hyperlane:Rust 高性能 HTTP 服务器库,开启 Web 服务新纪元!

🚀 Hyperlane:Rust 高性能 HTTP 服务器库,开启 Web 服务新纪元! 🌟 什么是 Hyperlane? Hyperlane 是一个基于 Rust 语言开发的轻量级、高性能 HTTP 服务器库,专为简化网络服务开发而设计。它支…

RIP V2路由协议配置实验CISCO

1.RIP V2简介: RIP V2(Routing Information Protocol Version 2)是 RIP 路由协议的第二版,属于距离矢量路由协议,主要用于中小型网络环境。相较于 RIP V1,RIP V2 在功能和性能上进行了多项改进&#xff0c…

《LNMP架构+Nextcloud私有云超维部署:量子级安全与跨域穿透实战》

项目实战-使用LNMP搭建私有云存储 准备工作 恢复快照,关闭安全软件 [rootserver ~]# setenforce 0[rootserver ~]# systemctl stop firewalld搭建LNMP环境 [rootserver ~]# yum install nginx mariadb-server php* -y# 并开启nginx服务并设置开机自启 [r…

3DMAX笔记-UV知识点和烘焙步骤

1. 在展UV时,如何点击模型,就能选中所有这个模型的uv 2. 分多张UV时,不同的UV的可以设置为不同的颜色,然后可以通过颜色进行筛选。 3. 烘焙步骤 摆放完UV后,要另存为一份文件,留作备份 将模型部件全部分成…

【新人系列】Golang 入门(十三):结构体 - 下

✍ 个人博客:https://blog.csdn.net/Newin2020?typeblog 📝 专栏地址:https://blog.csdn.net/newin2020/category_12898955.html 📣 专栏定位:为 0 基础刚入门 Golang 的小伙伴提供详细的讲解,也欢迎大佬们…

Spring Boot 自定义商标(Logo)的完整示例及配置说明( banner.txt 文件和配置文件属性信息)

Spring Boot 自定义商标(Logo)的完整示例及配置说明 1. Spring Boot 商标(Banner)功能概述 Spring Boot 在启动时会显示一个 ASCII 艺术的商标 LOGO(默认为 Spring 的标志)。开发者可通过以下方式自定义&a…

Ubuntu虚拟机Linux系统入门

目录 一、安装 Ubuntu Linux 20.04系统 1.1 安装前准备工作 1.1.1 镜像下载 1.1.2 创建新的虚拟机 二、编译内核源码 2.1 下载源码 2.2 指定编译工具 2.3 将根文件系统放到源码根目录 2.4 配置生成.config 2.5 编译 三、安装aarch64交叉编译工具 四、安装QEMU 五、…

【蓝桥杯】2025省赛PythonB组复盘

前言 昨天蓝桥杯python省赛B组比完,今天在洛谷上估了下分,省一没有意外的话应该是稳了。这篇博文是对省赛试题的复盘,所给代码是省赛提交的代码。PB省赛洛谷题单 试题 A: 攻击次数 思路 这题目前有歧义,一个回合到底是只有一个…

【数据结构_4下篇】链表

一、链表的概念 链表,不要求在连续的内存空间,链表是一个离散的结构。 链表的元素和元素之间,内存是不连续的,而且这些元素的空间之间也没有什么规律: 1.顺序上没有规律 2.内存空间上也没有规律 *如何知道链表中包…

音视频 五 看书的笔记 MediaCodec

MediaCodec 用于访问底层媒体编解码器框架,编解码组件。通常与MediaExtractor(解封装,例如Mp4文件分解成 video和audio)、MediaSync、MediaMuxer(封装 例如音视频合成Mp4文件)、MediaCrypto、Image(cameraX 回调的ImageReader对象可以获取到Image帧图像,可转换成YU…

ubuntu 系统安装Mysql

安装 mysql sudo apt update sudo apt install mysql-server 启动服务 sudo systemctl start mysql 设置为开机自启 sudo systemctl enable mysql 查看服务状态 (看到类似“active (running)”的状态信息代表成功) sudo systemctl status mysql …

selenium快速入门

一、操作浏览器 from selenium import webdriver from selenium.webdriver.chrome.options import Options from selenium.webdriver.chrome.service import Service from selenium.webdriver.common.by import By# 设置选项 q1 Options() q1.add_argument("--no-sandbo…