【深度学习实验】循环神经网络(一):循环神经网络(RNN)模型的实现与梯度裁剪

news2024/9/21 12:43:37

目录

一、实验介绍

二、实验环境

1. 配置虚拟环境

2. 库版本介绍

三、实验内容

0. 导入必要的工具包

1. 数据处理

2. rnn

测试

3. grad_clipping

4. 代码整合


        经验是智慧之父,记忆是智慧之母。

——谚语

一、实验介绍

        本实验介绍了一个简单的循环神经网络(RNN)模型,并探讨了梯度裁剪在模型训练中的应用。

ChatGPT:

        循环神经网络(Recurrent Neural Network,RNN)是一种在序列数据上进行建模和处理的神经网络模型。与传统的前馈神经网络(Feedforward Neural Network)不同,RNN具有循环连接,使其能够捕捉到序列数据中的时间依赖关系。

        RNN的主要特点是它可以接受任意长度的输入序列,并且可以通过记忆先前的信息来影响后续的输出。这种记忆能力使得RNN在处理序列数据方面非常有效,如自然语言处理(NLP)、语音识别、机器翻译等任务。

        RNN模型中的基本组件是循环单元(Recurrent Unit),最常见的是长短期记忆网络(Long Short-Term Memory,LSTM)和门控循环单元(Gated Recurrent Unit,GRU)。这些循环单元通过将当前输入与前一个时间步的隐藏状态结合起来,产生当前时间步的输出和新的隐藏状态。这种循环结构允许信息在网络中传递和更新,并保持对序列中的长期依赖关系的记忆。

        在训练RNN模型时,通常使用反向传播算法和梯度下降优化算法来更新网络参数,以使模型能够更好地适应输入序列的特征。

        需要注意的是,虽然RNN在处理序列数据方面具有优势,但也存在一些问题,如长期依赖问题和梯度消失/爆炸问题。为了应对这些问题,研究人员提出了一些改进的RNN变体,如带有门控机制的LSTM和GRU,以及更先进的模型,如变换器(Transformer)模型。

二、实验环境

        本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

0. 导入必要的工具包

import torch

1. 数据处理

        与之前的模型有所不同,循环神经网络引入了隐藏状态时间步两个新概念。当前时间步的隐藏状态由当前时间的输入与上一个时间步的隐藏状态一起计算出。

         根据隐藏状态的计算公式,需要计算两次矩阵乘法和三次加法才能得到当前时刻的隐藏状态。这里通过代码说明: 该计算公式等价于将当前时刻的输入与上一个时间步的隐藏状态做拼接,将两个权重矩阵做拼接,然后对两个拼接后的结果做矩阵乘法。此处展示省略了偏置项。

# X为模拟的输入,H为模拟的隐藏状态,在实际情况时要更复杂一些
X, W_xh = torch.normal(0, 1, (3, 1)), torch.normal(0, 1, (1, 4))
H, W_hh = torch.normal(0, 1, (3, 4)), torch.normal(0, 1, (4, 4))
torch.matmul(X, W_xh) + torch.matmul(H, W_hh)

上面是按照公式计算得到的结果,下面是拼接后计算得到的结果,两个结果完全相同

torch.matmul(torch.cat((X, H), 1), torch.cat((W_xh, W_hh), 0))

  • X是一个形状为(3, 1)的张量,表示输入。
  • W_xh是一个形状为(1, 4)的张量,表示输入到隐藏状态的权重。
  • H是一个形状为(3, 4)的张量,表示隐藏状态。
  • W_hh是一个形状为(4, 4)的张量,表示隐藏状态到隐藏状态的权重。

2. rnn

        定义了一个名为rnn的函数,用于执行循环神经网络的前向传播,在函数内部,通过遍历输入序列的每个时间步,逐步计算隐藏状态和输出。

def rnn(inputs, state, params):
    # inputs的形状:(时间步数量,批量大小,词表大小)
    W_xh, W_hh, b_h, W_hq, b_q = params
    H = state
    outputs = []
    # X的形状:(批量大小,词表大小)
    for X in inputs:
        H = torch.tanh(torch.mm(X, W_xh) + torch.mm(H, W_hh) + b_h)
        Y = torch.mm(H, W_hq) + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H,)
  • 参数:
    • inputs是一个形状为(时间步数量,批量大小,词表大小)的张量,表示输入序列。
    • state是一个形状为(批量大小,隐藏状态大小)的张量,表示初始隐藏状态。
    • params是一个包含了模型的参数的列表,包括W_xhW_hhb_hW_hqb_q
  • 对于每个时间步,
    • 使用tanh激活函数来更新隐藏状态
    • 根据更新后的隐藏状态,计算输出Y
    • 将输出添加到outputs列表中
  • 使用torch.cat函数将输出列表合并成一个张量,返回合并后的张量和最后一个隐藏状态 (H,)

测试

    inputs=torch.rand(10,3,50)
    params=[torch.rand((50,50)),torch.rand((50,50)),torch.rand((3,50)),torch.rand((50,60)),torch.rand((3,60))]
    state=torch.rand((3,50))
    output=rnn(inputs,state,params)
    print(output)
  • inputs是一个形状为(10, 3, 50)的随机张量,表示模拟的输入序列
  • params是一个包含了随机参数的列表,与rnn函数中的参数对应
  • state是一个形状为(3, 50)的随机张量,表示初始隐藏状态
  • 调用rnn函数
  • 打印输出结果output

3. grad_clipping

        在循环神经网络的训练中,当时间步较大时,可能导致数值不稳定, 例如梯度爆炸或梯度消失,所以一个很重要的步骤是梯度裁剪。通过下面的函数,梯度范数永远不会超过给定的阈值, 并且更新后的梯度完全与的原始方向对齐。

def grad_clipping(net, theta):
    if isinstance(net, nn.Module):
        params = [p for p in net.parameters() if p.requires_grad]
    else:
        params = net.params
    norm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))
    if norm > theta:
        for param in params:
            param.grad[:] *= theta / norm

        函数接受两个参数:net和theta。该函数首先根据net的类型获取需要梯度更新的参数,然后计算所有参数梯度的平方和的平方根,并将其与阈值theta进行比较。如果超过阈值,则对参数梯度进行裁剪,使其不超过阈值。

4. 代码整合

# 导入必要的工具包
import torch

# # X为模拟的输入,H为模拟的隐藏状态,在实际情况时要更复杂一些
# X, W_xh = torch.normal(0, 1, (3, 1)), torch.normal(0, 1, (1, 4))
# H, W_hh = torch.normal(0, 1, (3, 4)), torch.normal(0, 1, (4, 4))
# # torch.matmul(X, W_xh) + torch.matmul(H, W_hh)
# #
# # torch.matmul(torch.cat((X, H), 1), torch.cat((W_xh, W_hh), 0))


def rnn(inputs, state, params):
    # inputs的形状:(时间步数量,批量大小,词表大小)
    W_xh, W_hh, b_h, W_hq, b_q = params
    H = state
    outputs = []
    # X的形状:(批量大小,词表大小)
    for X in inputs:
        H = torch.tanh(torch.mm(X, W_xh) + torch.mm(H, W_hh) + b_h)
        Y = torch.mm(H, W_hq) + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H,)


def grad_clipping(net, theta):
    if isinstance(net, nn.Module):
        params = [p for p in net.parameters() if p.requires_grad]
    else:
        params = net.params
    norm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))
    if norm > theta:
        for param in params:
            param.grad[:] *= theta / norm


if __name__ == '__main__':
    inputs=torch.rand(10,3,50)
    params=[torch.rand((50,50)),torch.rand((50,50)),torch.rand((3,50)),torch.rand((50,60)),torch.rand((3,60))]
    state=torch.rand((3,50))
    output=rnn(inputs,state,params)
    print(output)

        使用随机生成的输入数据和参数进行模型的测试。测试结果显示,RNN模型能够正确计算隐藏状态和输出结果,并且通过梯度裁剪可以有效控制梯度的大小,提高模型的稳定性和训练效果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1080597.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SpringBoot Redis 基础使用

redis是一个key-value。和Memcached类似,它支持存储的value类型相对更多,包括string(字符串)、list(链表)、set(集合)、zset(sorted set --有序集合)和hash(哈希类型)。 Redis能做什么: 1. 缓存,毫无疑问这…

【设计模式】使用建造者模式组装对象并加入自定义校验

文章目录 1.前言1.1.创建对象时的痛点 2.建造者模式2.1 被建造类准备2.2.建造者类实现2.3.构建对象测试2.4.使用lombok简化建造者2.5.lombok简化建造者的缺陷 3.总结 1.前言 在我刚入行不久的时候就听说过建造者模式这种设计模式,当时只知道是用来组装对象&#xf…

串的基本操作(数据结构)

串的基本操作 #include <stdlib.h> #include <iostream> #include <stdio.h> #define MaxSize 255typedef struct{char ch[MaxSize];int length; }SString;//初始化 SString InitStr(SString &S){S.length0;return S; } //为了方便计算&#xff0c;串的…

Java架构师部署架构设计

目录 1 导学2 部署架构设计和部署架构图3 实战整体部署架构设计4 节点部署说明列表5 总结1 导学 本章的主要内容是整体架构设计的核心之一,部署架构设计相关的一些知识落到项目上,就是系统系统的部署架构设计。在本章学习里面我们可以去去学习整部署架构构,主要是部署架构设…

大模型分布式训练并行技术(四)-张量并行

linkj 近年来&#xff0c;随着Transformer、MOE架构的提出&#xff0c;使得深度学习模型轻松突破上万亿规模参数&#xff0c;传统的单机单卡模式已经无法满足超大模型进行训练的要求。因此&#xff0c;我们需要基于单机多卡、甚至是多机多卡进行分布式大模型的训练。 而利用AI集…

基于 ACK Fluid 的混合云优化数据访问(二):搭建弹性计算实例与第三方存储的桥梁

作者&#xff1a;车漾 前文回顾&#xff1a; 本系列将介绍如何基于 ACK Fluid 支持和优化混合云的数据访问场景&#xff0c;相关文章请参考&#xff1a; 基于 ACK Fluid 的混合云优化数据访问&#xff08;一&#xff09;&#xff1a;场景与架构 在前文《场景与架构》中&…

Springboot使用RestTemplate调用第三方接口

配置 新建一个RestTemplate的配置类&#xff0c;如下&#xff1a; /*** RestTemplate配置项*/ Configuration public class RestTemplateConfig {Beanpublic RestTemplate restTemplate() {return new RestTemplate();}} 在controller中引入RestTemplate&#xff0c;如下&am…

PHP 车辆租赁系统mysql数据库web结构apache计算机软件工程网页wamp

一、源码特点 PHP 车辆租赁系统是一套完善的web设计系统&#xff0c;对理解php编程开发语言有帮助&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S模式开发。 php 车辆租赁管理系统1 代码 https://download.csdn.net/download/qq_41221322/88414512…

软信天成:AI驱动,化解企业数据的隐私之痛

近年来&#xff0c;在全面数字化的背景下&#xff0c;不断加剧的数据丢失和爆炸式的个人数据增长让数据隐私的需求变得空前迫切。为解决个人、企业等敏感数据被滥用的现状&#xff0c;各地监管机构开始逐渐完善区域内数据隐私法律法规。从《欧盟通用数据保护条例》&#xff08;…

档案宝档案管理系统在微信小程序上线了!

随着信息时代的到来&#xff0c;企业和组织面临着越来越多的信息和数据需要管理。而档案管理作为一项重要的任务&#xff0c;对于企业的运营和决策起着至关重要的作用。为了满足用户的需求&#xff0c;我们很高兴地宣布&#xff0c;档案宝档案管理系统已经在微信小程序上线了&a…

git clone项目报错文件名过长

1.最近从git上clone项目时&#xff0c;有时会遇到报错&#xff0c;报出Filename too long错误。 Filename too long fatal: unable to checkout working tree warning: Clone succeeded, but checkout failed. 2.原因 Windows支持最长260字符的文件名&#xff08;包括其路径在…

Linux-文件管理命令

绝对路径&#xff1a;从根目录开始描述的路径 pwd输入即为绝对路径&#xff0c; 开头一定是“/”&#xff0c;因为一定是从根目录开始走 相对路径&#xff1a;从当前路径开始描述的路径&#xff0c;开头不一定是“/”&#xff0c;因为不一定是从根目录开始走的 .:是当前目录 。…

无人值守配电室变电所运维解决方案

随着电力系统数字化、智能化的不断发展&#xff0c;无人值守配电室变电所已经成为一种趋势。为了确保变电所的安全稳定运行&#xff0c;本文提出了一种无人值守配电室变电所运维解决方案。 一、背景介绍 力安科技电易云无人值守配电室变电所是指通过远程监控和智能化电力数…

虚拟机软件Parallels Desktop 19 mac功能介绍

Parallels Desktop 19 mac是一款虚拟机软件&#xff0c;它允许用户在Mac电脑上同时运行Windows、Linux和其他操作系统。Parallels Desktop提供了直观易用的界面&#xff0c;使用户可以轻松创建、配置和管理虚拟机。 PD19虚拟机软件具有快速启动和关闭虚拟机的能力&#xff0c;让…

实盘股票杠杆平台排名报告:五大排名和评估一览

本报告旨在根据公开信息和用户反馈&#xff0c;对股票配资公司的五大排名进行深入分析。这些排名包括尚红网、倍悦网、兴盛网、诚利和、嘉正网、广瑞网、富灯网、天创网、恒正网、创通网。 一、尚红网 尚红网在股票配资公司中一直享有盛誉&#xff0c;其资金雄厚&#xff0c;…

Linux操作系统常见指令理解(1)

Linux是一种开源的操作系统&#xff0c;常用于服务器和嵌入式设备。以下是一些常见的Linux指令及其详解&#xff1a; ls&#xff1a;列出目录中的文件和子目录。 cd&#xff1a;改变当前工作目录。 pwd&#xff1a;显示当前工作目录的路径。 mkdir&#xff1a;创建一个新目录。…

告别繁琐开发,轻松玩转无代码!

导语&#xff1a;随着科技的快速发展&#xff0c;无代码开发已成为越来越多人的首选。告别繁琐的开发流程&#xff0c;无代码将让你轻松驾驭各种项目需求&#xff01;本文将带您了解无代码的魅力&#xff0c;以及如何玩转无代码&#xff0c;为您的项目插上腾飞的翅膀&#xff0…

前端-Vue-开发指南

VueJS 开源文档 拉入vscode安装node.js安装vue脚手架components : 组件router&#xff1a;路由创建新组建 &#xff1a;assets&#xff1a; 系统图片存放地址main.js&#xff1a; vue脚手架对象存放地 &#xff08;新的包要放在里面&#xff09;属性 computedslot 插槽error St…

基于Effect的组件设计 | 京东云技术团队

Effect的概念起源 从输入输出的角度理解Effect https://link.excalidraw.com/p/readonly/KXAy7d2DlnkM8X1yps6L 编程中的Effect起源于函数式编程中纯函数的概念 纯函数是指在相同的输入下&#xff0c;总是产生相同的输出&#xff0c;并且没有任何副作用(side effect)的函数。…

使用Perl脚本编写爬虫程序的一些技术问题解答

网络爬虫是一种强大的工具&#xff0c;用于从互联网上收集和提取数据。Perl 作为一种功能强大的脚本语言&#xff0c;提供了丰富的工具和库&#xff0c;使得编写的爬虫程序变得简单而灵活。在使用的过程中大家会遇到一些问题&#xff0c;本文将通过问答方式&#xff0c;解答一些…