从代码学习深度学习 - GRU PyTorch版

news2025/4/5 10:30:08

文章目录

  • 前言
  • 一、GRU模型介绍
    • 1.1 GRU的核心机制
    • 1.2 GRU的优势
    • 1.3 PyTorch中的实现
  • 二、数据加载与预处理
    • 2.1 代码实现
    • 2.2 解析
  • 三、GRU模型定义
    • 3.1 代码实现
    • 3.2 实例化
    • 3.3 解析
  • 四、训练与预测
    • 4.1 代码实现(utils_for_train.py)
    • 4.2 在GRU.ipynb中的使用
    • 4.3 输出与可视化
    • 4.4 解析
  • 五、工具函数解析
    • 5.1 Timer
    • 5.2 Accumulator
    • 5.3 try_gpu
  • 六、可视化与绘图
    • 6.1 代码实现
    • 6.2 解析
  • 总结


前言

在深度学习领域,循环神经网络(RNN)及其变种如GRU(Gated Recurrent Unit,门控循环单元)在处理序列数据时表现出色。相比传统RNN,GRU通过更新门(Update Gate)和重置门(Reset Gate)简化了结构,同时保持了对长期依赖关系的建模能力。本篇博客将通过PyTorch实现一个基于GRU的文本生成模型,结合《The Time Machine》数据集,逐步解析代码实现的全过程。从数据预处理到模型训练,再到结果可视化,我们将深入探讨每个模块的功能,并展示完整的代码实现。


一、GRU模型介绍

GRU(Gated Recurrent Unit,门控循环单元)是循环神经网络(RNN)的一种改进变种,由Kyunghyun Cho等人在2014年提出。它旨在解决传统RNN在处理长序列时面临的梯度消失问题,同时通过更简洁的结构提升计算效率。相比LSTM(长短期记忆网络),GRU减少了一个门控单元,使用更新门(Update Gate)和重置门(Reset Gate)来控制信息的流动,从而在保持性能的同时降低参数量。

1.1 GRU的核心机制

在这里插入图片描述

GRU的工作原理基于两个关键的门控单元:

  1. 更新门(Update Gate, z t z_t zt
    更新门决定当前时间步的隐藏状态在多大程度上保留上一时间步的隐藏状态,以及接受多少新输入的信息。其计算公式为:
    z t = σ ( W z ⋅ [ h t − 1 , x t ] + b z ) z_t = \sigma(W_z \cdot [h_{t-1}, x_t] + b_z) zt=σ(Wz[ht1,xt]+bz)
    其中, σ \sigma σ是sigmoid激活函数, h t − 1 h_{t-1} ht1 是上一时间步的隐藏状态, x t x_t xt 是当前输入, W z W_z Wz b z b_z bz 是可训练的参数。

  2. 重置门(Reset Gate, r t r_t rt
    重置门控制前一时间步的隐藏状态在多大程度上影响当前候选隐藏状态的计算。其计算公式为:
    r t = σ ( W r ⋅ [ h t − 1 , x t ] + b r ) r_t = \sigma(W_r \cdot [h_{t-1}, x_t] + b_r) rt=σ(Wr[ht1,xt]+br)

基于这两个门,GRU计算候选隐藏状态和新隐藏状态:

  • 候选隐藏状态( h ~ t \tilde{h}_t h~t
    h ~ t = tanh ⁡ ( W h ⋅ [ r t ⊙ h t − 1 , x t ] + b h ) \tilde{h}_t = \tanh(W_h \cdot [r_t \odot h_{t-1}, x_t] + b_h) h~t=tanh(Wh[rt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2328528.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

二叉树 递归

本篇基于b站灵茶山艾府的课上例题与课后作业。 104. 二叉树的最大深度 给定一个二叉树 root ,返回其最大深度。 二叉树的 最大深度 是指从根节点到最远叶子节点的最长路径上的节点数。 示例 1: 输入:root [3,9,20,null,null,15,7] 输出&…

反常积分和定积分的应用 2

世界尚有同类 前言伽马函数的推论关于数学的思考平面图形的面积笛卡尔心形线伯努利双纽线回顾参数方程求面积星型线摆线 旋转体体积一般轴线旋转被积函数有负数部分曲线的弧长最后一个部分内容-旋转曲面侧表面积直角坐标系极坐标系参数方程 总结 前言 力大出奇迹。好好加油。 …

Element-plus弹出框popover,使用自定义的图标选择组件

自定义的图标选择组件是若依的项目的 1. 若依的图标选择组件 js文件,引入所有的svg图片 let icons [] // 注意这里的路径,一定要是自己svg图片的路径 const modules import.meta.glob(./../../assets/icons/svg/*.svg); for (const path in modules)…

思维链 Chain-of-Thought(COT)

思维链 Chain-of-Thought(COT):思维链的启蒙 3. 思维链 Chain-of-Thought(COT)存在问题?2. 思维链 Chain-of-Thought(COT)是思路是什么?1. 什么是 思维链 Chain-of-Thoug…

硬件电路(23)-输入隔离高低电平有效切换电路

一、概述 项目中为了防止信号干扰需要加一些隔离电路,而且有时传感器的信号是高有效有时是低有效,所以基于此背景,设计了一款方便实现高低电平有效检测切换电路。 二、应用电路

大模型学习二:DeepSeek R1+蒸馏模型组本地部署与调用

一、说明 DeepSeek R1蒸馏模型组是基于DeepSeek-R1模型体系,通过知识蒸馏技术优化形成的系列模型,旨在平衡性能与效率。 1、技术路径与核心能力 基础架构与训练方法‌ ‌DeepSeek-R1-Zero‌:通过强化学习(RL)训练&…

相机的曝光和增益

文章目录 曝光增益增益原理主要作用增益带来的影响增益设置与应用 曝光 参考:B站优致谱视觉 增益 相机增益是指相机在拍摄过程中对图像信号进行放大的一种操作,它在提高图像亮度和增强图像细节方面起着重要作用,以下从原理、作用、影响以…

Linux内核物理内存组织结构

一、系统调用sys_mmap 系统调用mmap用来创建内存映射,把创建内存映射主要的工作委托给do_mmap函数,内核源码文件处理:mm/mmap.c 二、系统调用sys_munmap 1、vma find_vma (mm, start); // 根据起始地址找到要删除的第一个虚拟内存区域 vma 2…

(多看) CExercise_05_1函数_1.2计算base的exponent次幂

题目: 键盘录入两个整数:底(base)和幂指数(exponent),计算base的exponent次幂,并打印输出对应的结果。(注意底和幂指数都可能是负数) 提示:求幂运算时,基础的思路就是先无脑把指数转…

Vuue2 element-admin管理后台,Crud.js封装表格参数修改

需求 表格数据调用列表接口,需要多传一个 Type字段,而Type字段的值 需要从跳转页面Url上面获取到,并赋值给Type,再传入列表接口中,最后拿到表格数据并展示 遇到的问题 需求很简单,但是因为表格使用的是统…

Tiktok矩阵运营中使用云手机的好处

Tiktok矩阵运营中使用云手机的好处 云手机在TikTok矩阵运营中能够大幅提高管理效率、降低封号风险,并节省成本,是非常实用的运营工具。TikTok矩阵运营使用云手机有很多优势,特别是对于需要批量管理账号、提高运营效率的团队来说。以下是几个…

Linux下调试器gdb_cgdb使用

文章目录 一、样例代码二、使用watchset var确定问题原因条件断点 一、样例代码 #include <stdio.h>int Sum(int s, int e) {int result 0;int i;for(i s; i < e; i){result i;}return result; }int main() {int start 1;int end 100;printf("I will begin…

Vite环境下解决跨域问题

在 Vite 开发环境中&#xff0c;可以通过配置代理来解决跨域问题。以下是具体步骤&#xff1a; 在项目根目录下找到 vite.config.js 文件&#xff1a;如果没有&#xff0c;则需要创建一个。配置代理&#xff1a;在 vite.config.js 文件中&#xff0c;使用 server.proxy 选项来…

超简单:Linux下opencv-gpu配置

1.下载opencv和opencv_contrib安装包 1&#xff09;使用命令下 git clone https://github.com/opencv/opencv.git -b 4.9.0 git clone https://github.com/opencv/opencv_contrib.git -b 4.9.02&#xff09;复制链接去GitHub下载然后上传到服务器 注意&#xff1a;看好版本&a…

泰博云平台solr接口存在SSRF漏洞

免责声明&#xff1a;本号提供的网络安全信息仅供参考&#xff0c;不构成专业建议。作者不对任何由于使用本文信息而导致的直接或间接损害承担责任。如涉及侵权&#xff0c;请及时与我联系&#xff0c;我将尽快处理并删除相关内容。 漏洞描述 SSRF漏洞是一种在未能获取服务器…

31天Python入门——第20天:魔法方法详解

你好&#xff0c;我是安然无虞。 文章目录 魔法方法1. __new__和__del__2. __repr__和__len__3. __enter__和__exit__4. 可迭代对象和迭代器5. 中括号[]数据操作6. __getattr__、__setattr__ 和 __delattr__7. 可调用的8. 运算符 魔法方法 魔法方法: Python中的魔法方法是一类…

ubantu22.04中搭建地图开发环境(qt5.15.2 + osg3.7.0 + osgearth3.7.1 + osgqt)

一、下载安装qt5.15.2 二、下载编译安装osg3.7.0 三、下载编译安装osgearth3.7.1 四、下载编译安装osgqt 五、二三维地图显示demo开发 六、成果展示&#xff1a; 已有功能&#xff1a;加载了dom影像、可以进行二三维地图切换显示、二维地图支持缩放和平移、三维地图支持旋转…

Bethune X 6发布:为小规模数据库环境打造轻量化智能监控巡检利器

3月31日&#xff0c;“奇点时刻・数智跃迁 -- 云和恩墨2025春季产品发布会”通过视频号直播的方式在线上举办。发布会上&#xff0c;云和恩墨副总经理熊军正式发布 zCloud 6.7和zData X 3.3两款产品新版本&#xff0c;同时也带来了 Bethune X 6——这款面向小规模数据库环境的智…

一文理解什么是中值模糊

目录 中值模糊的概念 中值模糊&#xff08;Median Blur&#xff09; 中值模糊的原理 示例&#xff1a;33 中值模糊 什么是椒盐噪声 椒盐噪声&#xff08;Salt-and-Pepper Noise&#xff09; 椒盐噪声的特点 OpenCV 中的 cv2.medianBlur() 函数格式 示例代码 中值模糊…

游戏引擎学习第192天

仓库:https://gitee.com/mrxiao_com/2d_game_4 回顾 我们现在正在编写一些界面代码&#xff0c;主要是用户界面&#xff08;UI&#xff09;&#xff0c;不过这里的UI并不是游戏内的用户界面&#xff0c;而是为开发者设计的用户界面。我们正在尝试做一些小的UI元素&#xff0c…