机器学习入门--门控循环单元(GRU)原理与实践

news2024/11/27 7:39:41

GRU模型

随着深度学习领域的快速发展,循环神经网络(RNN)已成为自然语言处理(NLP)等领域中常用的模型之一。但是,在RNN中,如果时间步数较大,会导致梯度消失或爆炸的问题,这影响了模型的训练效果。为了解决这个问题,研究人员提出了新的模型,其中GRU是其中的一种。

本文将介绍GRU的数学原理、代码实现,并通过pytorch和sklearn的数据集进行试验,最后对该模型进行总结。

数学原理

GRU是一种门控循环单元(Gated Recurrent Unit)模型。与传统的RNN相比,它具有更强的建模能力和更好的性能。

重置门和更新门

在GRU中,每个时间步有两个状态:隐藏状态 h t h_t ht和更新门 r t r_t rt。。更新门控制如何从先前的状态中获得信息,而隐藏状态捕捉序列中的长期依赖关系。

GRU的核心思想是使用“门”来控制信息的流动。这些门是由sigmoid激活函数控制的,它们决定了哪些信息被保留和传递。
在每个时间步 t t t,GRU模型执行以下操作:

1.计算重置门
r t = σ ( W r [ x t , h t − 1 ] ) r_t = \sigma(W_r[x_t, h_{t-1}]) rt=σ(Wr[xt,ht1])
其中, W r W_r Wr是权重矩阵, σ \sigma σ表示sigmoid函数。重置门 r t r_t rt告诉模型是否要忽略先前的隐藏状态 h t − 1 h_{t-1} ht1,并只依赖于当前输入
x t x_t xt

2.计算更新门
z t = σ ( W z [ x t , h t − 1 ] ) z_t = \sigma(W_z[x_t, h_{t-1}]) zt=σ(Wz[xt,ht1])
其中,更新门 z t z_t zt告诉模型新的隐藏状态 h t h_t ht在多大程度上应该使用先前的状态 h t − 1 h_{t-1} ht1

候选隐藏状态和隐藏状态

在计算完重置门和更新门之后,我们可以计算候选隐藏状态 h ~ t \tilde{h}_{t} h~t和隐藏状态 h t h_t ht

1.计算候选隐藏状态
h ~ t = tanh ⁡ ( W [ x t , r t ∗ h t − 1 ] ) \tilde{h}_{t} = \tanh(W[x_t, r_t * h_{t-1}]) h~t=tanh(W[xt,rtht1])
其中, W W W是权重矩阵。候选隐藏状态 h ~ t \tilde{h}_{t} h~t利用当前输入 x t x_t xt和重置门 r t r_t rt来估计下一个可能的隐藏状态。

2.计算隐藏状态
h t = ( 1 − z t ) ∗ h t − 1 + z t ∗ h ~ t h_{t} = (1 - z_t) * h_{t-1} + z_t * \tilde{h}_{t} ht=(1zt)ht1+zth~t
这是GRU的最终隐藏状态公式。它在候选隐藏状态 h ~ t \tilde{h}_{t} h~t和先前的隐藏状态 h t h_t ht之间进行加权,其中权重由更新门 z t z_t zt控制。

代码实现

下面是使用pytorch和sklearn的房价数据集实现GRU的示例代码:

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_boston
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt

# 加载数据集并进行标准化
data = load_boston()
X = data.data
y = data.target
scaler = StandardScaler()
X = scaler.fit_transform(X)
y = y.reshape(-1, 1)

# 转换为张量
X = torch.tensor(X, dtype=torch.float32).unsqueeze(1)
y = torch.tensor(y, dtype=torch.float32)

# 定义GRU模型
class GRUNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(GRUNet, self).__init__()
        self.hidden_size = hidden_size
        self.gru = nn.GRU(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out, _ = self.gru(x)
        out = self.fc(out[:, -1, :])
        return out

input_size = X.shape[2]
hidden_size = 32
output_size = 1
model = GRUNet(input_size, hidden_size, output_size)

# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
num_epochs = 10000
loss_list = []
for epoch in range(num_epochs):
    optimizer.zero_grad()
    outputs = model(X)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()

    if (epoch+1) % 100 == 0:
        loss_list.append(loss.item())
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')

# 可视化损失曲线
plt.plot(range(100), loss_list)
plt.xlabel('num_epochs')
plt.ylabel('loss of GRU Training')
plt.show()

# 预测新数据
new_data_point = X[0].reshape(1, 1, -1)
prediction = model(new_data_point)
print(f'Predicted value: {prediction.item()}')

上述代码首先加载并标准化房价数据集,然后定义了一个包含GRU层和全连接层的GRUNet模型,并使用均方误差作为损失函数和Adam优化器进行训练。训练完成后,使用matplotlib库绘制损失曲线(如下图所示),并使用训练好的模型对新的数据点进行预测。
GRU 损失曲线

总结

GRU是一种门控循环单元模型,它通过更新门和重置门,有效地解决了梯度消失或爆炸的问题。在本文中,我们介绍了GRU的数学原理、代码实现和代码解释,并通过pytorch和sklearn的房价数据集进行了试验。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1450919.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

超详细的介绍Python语句

一、 常用命令 在介绍Python语句之前,先介绍一下几个有用的Python命令。 dir(模块名或类名或变量名或表达式名):获得当前模块、变量对应类型、表达式计算值对应类的属性列表 type(变量名或表达式名):获取变量或表达式计算值的对…

[嵌入式系统-14]:常见实时嵌入式操作系统比较:RT-Thread、uC/OS-II和FreeRTOS、Linux

目录 一、实时嵌入式操作系统 1.1 概述 1.2 什么“实时” 1.3 什么是硬实时和软实时 1.4 什么是嵌入式 1.5 什么操作系统 二、常见重量级操作系统 三、常见轻量级嵌入式操作系统 3.1 概述 3.2 FreeRTOS 3.3 uC/OS-II 3.4 RT-Thread 3.5 RT-Thread、uC/OS-II、Free…

第5讲前端静态登录页面实现

前端静态登录页面实现 引入全局样式: main.js导入样式文件: import /assets/styles/border.css import /assets/styles/reset.css加路由: const routes [{path: /login,name: login,component: () > import(../views/Login.vue)} ]App…

pytorch tensor维度变换

目录 1. view/reshape2. squeeze/unsqueeze3. expand 扩展4. repeat5 .t转置6. transpose7. permute 1. view/reshape view(*shape) → Tensor 作用:类似于reshape,将tensor转换为指定的shape,原始的data不改变。返回的tensor与原始的tensor…

LD-802D-X6

LD-802D-X6足浴按摩器,买个给老人家,解决泡脚越泡越冷,调节温度和定式问题, 按摩功能老人体验说太痒,转太快了,哈哈 下面是安装步骤使用说明 其实这包零件就是安装底部4个轮子,4个轮子的中间滚…

单片机学习笔记---LCD1602

LCD1602介绍 LCD1602(Liquid Crystal Display)液晶显示屏是一种字符型液晶显示模块,可以显示ASCII码的标准字符和其它的一些内置特殊字符(比如日文的片假名),还可以有8个自定义字符 显示容量:…

基于GPT一键完成数据分析全流程的AI Agent: Streamline Analyst

大型语言模型(LLM)的兴起不仅为获取知识和解决问题开辟了新的可能性,而且催生了一些新型智能系统,例如旨在辅助用户完成特定任务的AI Copilot以及旨在自动化和自主执行复杂任务的AI Agent,使得编程、创作等任务变得高效…

gem5 garnet 合成流量: packet注入流程

代码流程 下图就是全部. 剩下文字部分是细节补充,但是内容不变: bash调用python,用python配置好configuration, 一个cpu每个tick运行一次,requestport发出pkt. bash 启动 python文件并配置 ./build/NULL/gem5.debug configs/example/garnet_synth_traffic.py \--num-cpus…

计算机网络——12DNS

DNS DNS的必要性 IP地址标识主机、路由器但IP地址不好记忆,不便于人类用使用(没有意义)人类一般倾向于使用一些有意义的字符串来标识Internet上的设备存在着“字符串”——IP地址的转换的必要性人类用户提供要访问机器的“字符串”名称由DN…

解线性方程组(二)——Jacobi迭代法求解(C++)

迭代法 相比于直接法求解,迭代法使用多次迭代来逐渐逼近解,其精度比不上直接法,但是其速度会比直接法快很多,计算精度可控,特别适用于求解系数矩阵为大型稀疏矩阵的方程组。 Jacobi迭代法 假设有方程组如下&#xf…

【C++】实现Date类的各种运算符重载

上一篇文章只实现了operator操作符重载&#xff0c;由于运算符较多&#xff0c;该篇文章单独实现剩余所有的运算符重载。继续以Date类为例&#xff0c;实现运算符重载&#xff1a; 1.Date.h #pragma once#include <iostream> #include <assert.h>using namespace …

WebSocket | 基于TCP的全双工通信网络协议

文章目录 1、介绍2、示例2.1、分析2.2、代码开发2.3、功能测试 ​&#x1f343;作者介绍&#xff1a;双非本科大三网络工程专业在读&#xff0c;阿里云专家博主&#xff0c;专注于Java领域学习&#xff0c;擅长web应用开发、数据结构和算法&#xff0c;初步涉猎Python人工智能开…

qt-C++笔记之打印所有发生的事件

qt-C笔记之打印所有发生的事件 code review! 文章目录 qt-C笔记之打印所有发生的事件1.ChatGPT问答使用 QApplication 的 notify 方法使用 QObject 的 event 方法 2.使用 QObject 的 event 方法3.使用 QApplication 的 notify 方法 1.ChatGPT问答 在Qt C中&#xff0c;若要打…

小米米家智能摄像头mp4多碎片手工恢复案例

小米米家智能摄像头mp4多碎片手工恢复案例 智能摄像头品牌中小米算是绝对的大厂&#xff0c;其采用的方案也是比较成熟比较典型的&#xff1a;日志截图1分钟1个文件。小米米家的智能摄像头之前处理过很多&#xff0c;这次来讲一个比较特殊的案例。 故障存储: 32G TF卡 fat…

HiveSQL——统计当前时间段的有客人在住的房间数量

注&#xff1a;参考文章&#xff1a; HiveSQL一天一个小技巧&#xff1a;如何统计当前时间点状态情况【辅助变量累计变换思路】_sql查询统计某状态出现的次数及累计时间-CSDN博客文章浏览阅读2k次&#xff0c;点赞6次&#xff0c;收藏8次。本文总结了一种当前时间点状态统计的…

Vue 进阶系列丨实现简易VueRouter

‍‍Vue 进阶系列教程将在本号持续发布&#xff0c;一起查漏补缺学个痛快&#xff01;若您有遇到其它相关问题&#xff0c;非常欢迎在评论中留言讨论&#xff0c;达到帮助更多人的目的。若感本文对您有所帮助请点个赞吧&#xff01; 2013年7月28日&#xff0c;尤雨溪第一次在 G…

springboot集成elk实现日志采集可视化

一、安装ELK 安装ELK组件请参考我这篇博客&#xff1a;windows下安装ELK(踩坑记录)_windows上安装elk教程-CSDN博客 这里不再重复赘述。 二、编写logstash配置 ELK组件均安装好并成功启动&#xff0c;进入到logstash组件下的config文件夹&#xff0c;创建logstash.conf配置…

Three.JS教程5 threejs中的材质

Three.JS教程5 threejs中的材质 一、什么是Three.js材质&#xff1f;二、Three.js的材质类型1. 材质类型2. 材质的共用属性&#xff08;1&#xff09;.alphaHash : Boolean&#xff08;2&#xff09;.alphaTest : Float&#xff08;3&#xff09;.alphaToCoverage : Boolean&am…

使用 Mermaid 创建流程图,序列图,甘特图

使用 Mermaid 创建流程图和图表 Mermaid 是一个流行的 JavaScript 库&#xff0c;用于创建流程图、序列图、甘特图和其他各种图表。它的简洁语法使得创建图表变得非常简单&#xff0c;无需复杂的绘图工具或专业的编程技能。在本文中&#xff0c;我们将讲解如何使用 Mermaid 来创…

卷积神经网络的基本结构

卷积神经网络的基本结构 与传统的全连接神经网络一样&#xff0c;卷积神经网络依然是一个层级网络&#xff0c;只不过层的功能和形式发生了变化。 典型的CNN结构包括&#xff1a; 数据输入层&#xff08;Input Layer&#xff09;卷积层&#xff08;Convolutional Layer&#x…