[2022-11-26]神经网络与深度学习第5章 - 循环神经网络(part 2)

news2024/11/28 6:48:59

contents

  • 循环神经网络(part 2) - 梯度爆炸实验
    • 写在开头
    • 解决方式概览
    • 梯度爆炸实验
      • 梯度打印函数
      • 思考:什么是范数、L2范数、为什么要打印梯度范数
      • 复现梯度爆炸现象
      • 使用梯度截断解决梯度爆炸问题
      • 思考:梯度截断解决梯度爆炸问题的原理?
    • 写在最后

循环神经网络(part 2) - 梯度爆炸实验

写在开头

经过前面的实验我们不难发现,我们所构建的简单神经网络,对于较短的序列预测效果尚佳,但是应对起长的序列效果却一塌糊涂。对此,就不得不说简单循环神经网络对于长程依赖问题效果不好的两大原因:梯度爆炸和梯度消失。
在这里插入图片描述

解决方式概览

  • 对于梯度爆炸问题:我们通过权重衰减或者梯度截断能够较好地避免;
  • 对于梯度消失问题:由于模型存在的故有问题,我们可以通过改变模型,如使用长短期记忆网络LSTM来进行缓解。

梯度爆炸实验

本次实验,我们将复现梯度爆炸的问题,然后尝试使用梯度截断的方式进行解决。
采用长度为20的数据集进行实验,并在训练过程中输出 W , U , b W, U, b W,U,b的梯度向量的范数来衡量梯度变化情况。

梯度打印函数

在训练过程中打印梯度,分别定义W_listU_listb_list,用于存储前面所说的梯度向量的范数。代码如下:

W_list, U_list, b_list = [], [], []

def display_gvec(model):
    grad_w_l2, grad_u_l2, grad_b_l2 = 0., 0., 0.
    for name, param in model.named_parameters():
        if name == "rnn_model.W":
            grad_w_l2 = torch.norm(param.grad, p=2).numpy()
        if name == "rnn_model.U":
            grad_u_l2 = torch.norm(param.grad, p=2).numpy()
        if name == "rnn_model.b":
            grad_b_l2 = torch.norm(param.grad, p=2).numpy()
    print(f"grad_w_l2: {grad_w_l2:.5f}, grad_u_l2: {grad_u_l2:.5f}, grad_b_l2: {grad_b_l2:.5f} ")
    W_list.append(grad_w_l2)
    U_list.append(grad_u_l2)
    b_list.append(grad_b_l2)

思考:什么是范数、L2范数、为什么要打印梯度范数

L2和L-P范数定义如下:
L 2 = ∣ ∣ x ∣ ∣ 2 = ( ∑ i ∣ x i ∣ 2 ) 1 / 2 L P = ∣ ∣ x ∣ ∣ P = ( ∑ i ∣ x i ∣ P ) 1 / P L_2=||x||_2=(\sum_i|x_i|^2)^{1/2}\\ L_P=||x||_P=(\sum_i|x_i|^P)^{1/P} L2=x2=(ixi2)1/2LP=xP=(ixiP)1/P
在这里插入图片描述

这边推荐一个博客写的非常好:传送门:一文搞懂深度学习正则化的L2范数

至于为什么要打印L2范数,个人觉得是因为L2范数对于参数数值变化更为直观清晰。

复现梯度爆炸现象

为了更好地复现梯度爆炸问题,使用SGD优化器将批大小和学习率调大,设置学习率为0.2,同时计算交叉熵损失时reduction设置为sum,表示将损失进行累加。
获取训练过程中关于W,U和b参数梯度的L2范数,并将其绘制为图片。

  • 选取Tanh函数,因为其饱和区导数接近于0。由于梯度的急剧变化,参数数值变的较大或较小,容易落入梯度饱和区,导致梯度为0,模型很难继续训练。

代码如下:

num_epochs = 20
lr = 0.2
num_digits = 10
input_size = 32 # 将数字映射为向量的维度
hidden_size = 32 # 隐状态向量的维度
num_classes = 19
batch_size = 64
save_dir = "./checkpoints"

length = 20
print(f"\n====> Training SRN with data of length {length}.")

# 加载长度为length的数据
data_path = f"E:/nndl/misc/datasets/{length}"
train_examples, dev_examples, test_examples = load_data(data_path)
train_set, dev_set, test_set = DigitSumDataset(train_examples), DigitSumDataset(dev_examples),DigitSumDataset(test_examples)
train_loader = DataLoader(train_set, batch_size=batch_size)
dev_loader = DataLoader(dev_set, batch_size=batch_size)
test_loader = DataLoader(test_set, batch_size=batch_size)
# 实例化模型
base_model = SRN(input_size, hidden_size)
model = Model_RNN4SeqClass(base_model, num_digits, input_size, hidden_size, num_classes)
# 指定优化器
optimizer = torch.optim.SGD(model.parameters(),lr)
# 定义评价指标
metric = Accuracy()
# 定义损失函数
loss_fn = nn.CrossEntropyLoss(reduction="sum")

# 基于以上组件,实例化Runner
runner = RunnerV3(model, optimizer, loss_fn, metric)

# 进行模型训练
model_save_path = os.path.join(save_dir, f"srn_explosion_model_{length}.pdparams")
runner.train(train_loader, dev_loader, num_epochs=num_epochs, eval_steps=100, log_steps=1,
             save_path=model_save_path, additional={"grads l2 value ":display_gvec})

我们将训练时的梯度用matplotlib打印出来,结果如下:
在这里插入图片描述
在测试集上的结果如下:
在这里插入图片描述

使用梯度截断解决梯度爆炸问题

梯度截断是一种可以有效解决梯度爆炸问题的启发式方法,当梯度的模大于一定阈值时,将它截断为一个较小的数。一般有两种截断方式:按值截断和按模截断。本实验使用按模截断的方式解决梯度爆炸问题(Pascanu et al., 2013a)。公式如下:
在这里插入图片描述

其中 v 是范数上界,g 用来更新参数。因为所有参数(包括不同的参数组,如权重
和偏置)的梯度被单个缩放因子联合重整化,所以后一方法具有的优点是保证了每
个步骤仍然是在梯度方向上的,但实验表明两种形式类似。虽然参数更新与真实梯
度具有相同的方向梯度,经过梯度范数截断,参数更新的向量范数现在变得有界。这种有界梯度能避免执行梯度爆炸时的有害一步。
在torch中,我们使用torch.nn.utils.clip_grad_norm_函数即可进行梯度截断。其代码如下:

import warnings
import torch
from torch._six import inf
from typing import Union, Iterable
_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]
def clip_grad_norm_(
        parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0,
        error_if_nonfinite: bool = False) -> torch.Tensor:
    r"""Clips gradient norm of an iterable of parameters.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Gradients are modified in-place.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
        error_if_nonfinite (bool): if True, an error is thrown if the total
            norm of the gradients from :attr:`parameters` is ``nan``,
            ``inf``, or ``-inf``. Default: False (will switch to True in the future)

    Returns:
        Total norm of the parameter gradients (viewed as a single vector).
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if len(grads) == 0:
        return torch.tensor(0.)
    device = grads[0].device
    if norm_type == inf:
        norms = [g.detach().abs().max().to(device) for g in grads]
        total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
    else:
        total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type)
    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
        raise RuntimeError(
            f'The total norm of order {norm_type} for gradients from '
            '`parameters` is non-finite, so it cannot be clipped. To disable '
            'this error and scale the gradients by the non-finite norm anyway, '
            'set `error_if_nonfinite=False`')
    clip_coef = max_norm / (total_norm + 1e-6)
    # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
    # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
    # when the gradients do not reside in CPU memory.
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    for g in grads:
        g.detach().mul_(clip_coef_clamped.to(g.device))
    return total_norm

可以很清楚的看到如上述计算过程的代码用于实现梯度截断。

引入梯度截断后,我们重新对模型进行训练,结果如下:在这里插入图片描述

思考:梯度截断解决梯度爆炸问题的原理?

梯度截断可以使梯度下降在梯度爆炸的部分附近更合理地执行。将梯度看作地形图,那么梯度爆炸的地方就是一个悬崖,如果跳下去就会粉身碎骨(梯度爆炸导致的模型恶化),梯度截断则像传送门,将梯度传送到低的地方“安全着陆”。
又找到了一个描述:
通过梯度截断,可以较大程度的抑制梯度爆炸现象。如下图所示,图中曲面表
示的𝐽(𝑤, 𝑏)函数在不同网络参数 w,b 下的误差值𝐽,其中有一块区域𝐽(𝑤, 𝑏)函数的梯度变化
较大,一旦网络参数进入此区域,很容易出现梯度爆炸的现象,使得网络状态迅速恶化。
图右演示了添加梯度截断后的优化轨迹,由于对梯度进行了有效限制,使得每次更
新的步长得到有效控制,从而防止网络突然恶化。
在这里插入图片描述

写在最后

通过本次实验,我们了解到了在循环神经网络中一个非常需要重视的问题——梯度爆炸,在不改变模型本身大体结构的情况下,我们通过对梯度进行截断从而遏制梯度爆炸的问题。这个方法简单粗暴,但是非常好用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/41957.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

搭建MinIO容器

文章目录1 问题背景2 资源准备3 安装Docker服务4 关闭防火墙5 以Docker方式安装MinIO6 访问MinIO1 问题背景 玩一个前后端的项目&#xff0c;需要用到对象存储器&#xff0c;于是使用开源的MinIO。期间以Docker方式搭建遇到某些坑&#xff0c;此处仅以博客的方式记录下来 2 资源…

【通信原理课设--基于MATLAB/Simulink的2ASK数字带通传输系统建模与仿真】Simulink的使用介绍以及在本实验中的使用

目录 Simulink的简要介绍 Simulink的使用流程 进入Simulink 进入模型编辑窗口 ​ 建立一个新的文件 根据求需建立模型 对选择的模块进行参数设置 本次课程设计需要使用Simulink做ASK的仿真处理&#xff0c;那么下面就一起学习了解一下Simulink吧&#xff01; Simuli…

全球经济自由度1995-2021最新版绿色金融指数2001-2020

&#xff08;1&#xff09;全球经济自由度指数 1995-2021 1、数据来源&#xff1a;美国传统基金会 2、时间跨度&#xff1a;1995-2021 3、区域范围&#xff1a;全球 4、指标说明&#xff1a; 经济自由度指数&#xff0c;是由《华尔街日报》和美国传统基金会发布的年度报告…

爱站网关键词挖掘工具-长尾关键词挖掘站长工具

长尾词挖掘免费工具&#xff0c;为什么我们要使用长尾词挖掘免费工具&#xff0c;我们只要找准关键词就等于掌握了流量。 关键词可应用于任何平台&#xff1a;不管是网站、短视频、自媒体等&#xff01; 比如说用户A经常看体育领域的内容&#xff0c;平台就会给A打上体育领域标…

LeetCode刷题---19. 删除链表的倒数第 N 个结点(双指针-快慢指针)

文章目录一、编程题&#xff1a;142. 环形链表 II&#xff08;双指针-快慢指针&#xff09;1.题目描述2.示例1&#xff1a;3.示例2&#xff1a;4.示例3&#xff1a;5.提示&#xff1a;6.进阶&#xff1a;二、解题思路1.思路2.复杂度分析&#xff1a;3.算法图解三、代码实现总结…

[附源码]计算机毕业设计springboot班级事务管理论文2022

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

MQTT协议

1.MQTT基础知识学习 MQTT协议基础知识视频学习资源:太极创客 太极创客的MQTT基础篇博文目录链接:零基础入门学用物联网 – MQTT基础篇 – 目录 备注:建议先学习一下MQTT的基础知识后再学习下面章节部分的内容。 2.MQTT服务器的搭建 MQTT服务器的搭建步骤如下: 使用docker pull…

【vue + echarts】图表自适应缩放(跟随浏览器的窗口缩放,项目侧边栏折叠后的窗口缩放),图表重绘

效果图: 先清楚两个东西,浏览器窗口的缩放和项目侧边栏折叠后窗口的缩放,这两个是不一样的 第一种,浏览器窗口缩放后,当前窗口会放大了或者缩小了,这时会走浏览器缩放的代码部分, 第二种,而项目侧边栏折叠后窗口的缩放,虽然项目里面的窗口缩放了,但是,浏览器的窗口并没有发生…

tACS恢复老年人认知控制能力的EEG功能和DTI结构网络机制

认知控制能力是大多数日常任务中的关键能力&#xff0c;与年龄相关的认知控制能力下降威胁到个人的独立性。作者之前在老年人和年轻人中都发现&#xff0c;经颅交流电刺激&#xff08;tACS&#xff09;可以改善认知控制&#xff0c;在远离受刺激部位和频率之外的神经区域观察到…

使用python玩转二维码!速学速用!⛵

&#x1f4a1; 作者&#xff1a;韩信子ShowMeAI &#x1f4d8; Python3◉技能提升系列&#xff1a;https://www.showmeai.tech/tutorials/56 &#x1f4d8; 本文地址&#xff1a;https://showmeai.tech/article-detail/398 &#x1f4e2; 声明&#xff1a;版权所有&#xff0c;…

前端开发如何做新手引导

通常&#xff0c;在产品发布新版本或者有新功能上线时&#xff0c;都会开发一个新手引导功能来引导用户了解应用的功能。在前端开发中&#xff0c;如何快速地开发新手引导功能呢&#xff0c;下面介绍几个开箱即用的新手引导组件库。 1&#xff0c;Intro.js Intro.js是一个使用…

外汇天眼:外汇杠杆的“诱惑”到底有多大,为何做外汇的人都那么上瘾?

近些年随着外汇保证金在中国的持续发展&#xff0c;中国的外汇保证金交易就像当初的股票市场一样&#xff0c;从无到有&#xff0c;不断的发展壮大&#xff0c;再加上国内金融市场对外开放步伐加快&#xff0c;在中国国内参与外汇市场的投资者也是连年上升&#xff0c;那么这个…

【EI会议2023】12.20之后ddl

csdn 摘出来上文中的一些ddl ICET 2023(成都 5月12日-5月15日) http://www.icet.net/track9.html 截稿时间2022.12.20 通知录用:2023.1.20 SEGRE 2023(长沙 4月21日-4月23日) http://www.icsegre.org/ 截稿时间2023.2.26 通知录用:2023.4.3 ICIBA 2023(重庆 5月26日-…

全同态加密:GSW

参考文献&#xff1a; Micciancio D, Peikert C. Trapdoors for lattices: Simpler, tighter, faster, smaller[C]//Annual International Conference on the Theory and Applications of Cryptographic Techniques. Springer, Berlin, Heidelberg, 2012: 700-718.Gentry C, S…

Mysql-解决Truncated incorrect DOUBLE value xxx

问题 出现这种问题一般来说就是多表操作的时候, 使用的字段类型不一致导致的(查询除外),我们来看下真实案例 在hd_user表中parentId是binint类型 而在hd_user_increment_copy1表中parentId是varchar类型, 如果只是查询的话那么是不会报错的,我把查询sql提出了运行是可以的 …

[附源码]计算机毕业设计springboot保护濒危动物公益网站

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

Java 编程性能调优

把 Java 性能调优分成 5 个层级&#xff1a;Java 编程、多线程、JVM 性能检测、设计模式、数据库性能&#xff0c;每个层级下都覆盖了最常见的优化问题。下面分别给你梳理一下&#xff1a; 可参考地址&#xff1a;Java性能调优全攻略来了_着火点的博客-CSDN博客_java性能调优 …

软件设计与体系结构编程题汇总

现在需要开发一款游戏软件&#xff0c;请以单例模式来设计其中的 Boss 角色。角色的属性和动作可以任意设计。 要求&#xff1a;该 Boss 类可以在多线程中使用。&#xff08; 8 分&#xff09; Public class Boss{Private static Boss instance; //(2 分&#xff09;Private …

vivo大数据日志采集Agent设计实践

作者&#xff1a;vivo 互联网存储技术团队- Qiu Sidi 在企业大数据体系建设过程中&#xff0c;数据采集是其中的首要环节。然而&#xff0c;当前行业内的相关开源数据采集组件&#xff0c;并无法满足企业大规模数据采集的需求与有效的数据采集治理&#xff0c;所以大部分企业都…

车间工厂看板还搞不定,数据可视化包教包会

在智能工厂的建设过程中&#xff0c;为了让每条生产线的生产进度和状态更加清晰&#xff0c;经常需要将生产信息情况显示在电视看板上&#xff0c;称为智能工厂-车间数据可视化大屏方案。 根据工厂和车间的大小&#xff0c;可能会使用 10到100 台甚至更多的电视看板来显示数据…