抑制过拟合——从梯度的角度看LayerNorm的作用

news2025/1/22 16:47:29

抑制过拟合——从梯度的角度看LayerNorm的作用

  • Normalization的目的
  • LayerNorm & BatchNorm
  • 可视化分析LayerNorm
    • 分析loss
    • 分析梯度

  在深入探索transformer模型时,一个不可忽视的组成部分便是LayerNorm,它在模型的优化过程中起着关键作用。相比之下,虽然BatchNorm也广泛应用于各种网络模型中,但在很多情况下LayerNorm表现出更优的效果。然而,对于为何LayerNorm优于BatchNorm,目前学界还没有形成统一的看法。

在这里插入图片描述

  本文的重点是探讨LayerNorm在模型训练过程中对梯度变化的影响。通过对这一作用的深入理解,我们可以更加有效地应用LayerNorm,从而提升模型的性能。

Normalization的目的

  在使用梯度下降法进行优化的过程中,特别是在深层网络中,输入数据的特征分布会随着网络深度的增加而发生变化。为了维持数据特征分布的稳定性,通常会引入Normalization。这不仅能够使得模型使用更大的学习率,加速模型收敛,同时也有助于防止过拟合,使训练过程更加平稳。

  简而言之,Normalization的主要作用是在特征输入激活函数之前进行标准化处理,将数据转换为均值为0、方差为1的分布。这一处理避免了数据落入激活函数的饱和区,从而降低了梯度消失问题的风险。

  从更深层次来看,Normalization 通过将数据拉回标准正态分布,提高了网络运算的稳定性。由于神经网络的大部分操作都是矩阵运算,未经处理的向量在经过多次运算后其值可能逐渐增大。因此,为了维持网络的稳定性,定期将数据值拉回到正态分布显得尤为重要。

LayerNorm & BatchNorm

  在理解LayerNormBatchNorm的不同之处时,一个直观的示意图可以帮助我们更清晰地认识两者的区别。

在这里插入图片描述

  假设输入数据的维度为[batch_size, seq_len, emb_dim]。在这种情况下,LayerNorm是针对batch中的单个数据点的[seq_len, emb_dim]维度进行normalization,而BatchNorm则是针对[batch_size, seq_len]维度进行normalization

  考虑到文本任务中文本长度和词嵌入的特性,LayerNorm在处理[seq_len, emb_dim]normalization时通常会比BatchNorm更有效。

  具体来说,BN(Batch Normalization)在保留不同样本之间的大小关系的同时,抹平了不同特征之间的差异。这在依赖于样本间关系的任务中特别有效,例如在计算机视觉领域中对不同图片样本进行分类时。

  而LN(Layer Normalization)则是在保留不同特征之间的大小关系的同时,抹平了不同样本之间的差异。这使得LN特别适用于自然语言处理领域的任务,其中一个样本的特征实际上是由不同的词嵌入组成。通过LN,可以有效地保留这些特征间的时序关系。

可视化分析LayerNorm

  为了更深入地理解LayerNorm的作用,本文设计了四个实验来观察其对梯度变化的影响。这四个实验的模型结构如下表所示:

模型名称描述
实验1Dropout
实验2Dropout + LayerNorm
实验3LayerNorm
实验4None

  实验1 Dropout结构如下:

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linears = nn.Sequential(
            nn.Linear(2, 20),
            nn.Linear(20, 20),
            nn.Dropout(0.1),
            nn.Linear(20, 20),
            nn.Linear(20, 20),
            nn.Linear(20, 1),
        )

    def forward(self, x):
        _ = self.linears(x)
        return _

  实验2 Dropout + LayerNorm结构如下:

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linears = nn.Sequential(
            nn.Linear(2, 20),
            nn.Linear(20, 20),
            nn.Dropout(0.1),
            nn.Linear(20, 20),
            nn.LayerNorm(20),
            nn.Linear(20, 20),
            nn.LayerNorm(20),
            nn.Linear(20, 1),
        )

    def forward(self, x):
        _ = self.linears(x)
        return _

  实验3 LayerNorm结构如下:

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linears = nn.Sequential(
            nn.Linear(2, 20),
            nn.Linear(20, 20),
            nn.Linear(20, 20),
            nn.LayerNorm(20),
            nn.Linear(20, 20),
            nn.LayerNorm(20),
            nn.Linear(20, 1),
        )

    def forward(self, x):
        _ = self.linears(x)
        return _

  实验4 None结构如下:

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linears = nn.Sequential(
            nn.Linear(2, 20),
            nn.Linear(20, 20),
            nn.Linear(20, 20),
            nn.Linear(20, 20),
            nn.Linear(20, 1),
        )

    def forward(self, x):
        _ = self.linears(x)
        return _

  训练代码如下:

import torch
from tensorboardX import SummaryWriter
from torch import optim, nn
import time


class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linears = nn.Sequential(
            nn.Linear(2, 20),

            nn.Linear(20, 20),
            # nn.Dropout(0.1),

            nn.Linear(20, 20),
            # nn.LayerNorm(20),

            nn.Linear(20, 20),
            # nn.LayerNorm(20),

            nn.Linear(20, 1),
        )

    def forward(self, x):
        _ = self.linears(x)
        return _

lr = 0.01
iteration = 1000


x1 = torch.arange(-10, 10).float()
x2 = torch.arange(0, 20).float()
x = torch.cat((x1.unsqueeze(1), x2.unsqueeze(1)), dim=1)
y = 2*x1 - x2**2 + 1

model = Model()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.01)
loss_function = torch.nn.MSELoss()

start_time = time.time()
writer = SummaryWriter(comment='_层归一化')

for iter in range(iteration):
    y_pred = model(x)
    loss = loss_function(y, y_pred.squeeze())
    loss.backward()

    for name, layer in model.named_parameters():
        writer.add_histogram(name + '_grad', layer.grad, iter)
        writer.add_histogram(name + '_data', layer, iter)
    writer.add_scalar('loss', loss, iter)

    optimizer.step()
    optimizer.zero_grad()

    if iter % 50 == 0:
        print("iter: ", iter)

print("Time: ", time.time() - start_time)


这里我们使用 TensorBoardX 进行结果的可视化展示。

分析loss

  实验结果如下所示:

在这里插入图片描述

  可以看到,加入了LayerNorm的实验2和3的最终loss明显低于没有使用LayerNorm的实验1和4。

我们认为LayerNorm能够帮助降低模型损失的原因主要有以下几点:
1、稳定化学习过程: 层归一化通过对每个样本的特征进行独立归一化,有效减少了不同层输出分布的变化(也称为内部协变量偏移),从而有助于稳定网络的学习过程,使网络更容易学习。
2、加速收敛: 层归一化通过减少不同训练批次间的系统差异,可以显著加速神经网络的收敛速度,使网络能够更快地达到较低的损失值。

  通过对实验1、2和实验3、4的比较,我们发现加入了Dropout的实验1和2其损失值波动较大,而没有使用Dropout的实验3和4则显示出更为平滑的收敛过程。

分析梯度

  实验结果如下所示:

在这里插入图片描述
  通过深入分析这些实验结果,我们可以更好地理解LayerNorm在控制梯度分布方面的作用。让我们逐一探讨每个实验的结果和它们所揭示的洞见。

  实验4(无任何正则化):这个实验没有应用DropoutLayerNorm。其结果显示,梯度在训练的初期阶段迅速稳定下来,但分布非常集中。这种过于集中的分布通常是过拟合的迹象。过拟合意味着模型可能在训练数据上表现出色,但在未见过的数据上则表现不佳。这种现象在深度学习中非常常见,尤其是在没有足够的正则化措施时。

  实验1(仅使用Dropout):相比于实验4,实验1引入了Dropout层。Dropout是一种有效的正则化技术,它通过在训练过程中随机“关闭”一部分神经元来减少模型对特定训练样本的依赖。这导致了更加“杂乱无章”的梯度分布,这实际上是好事,因为它表明模型正在学习更多样化的特征,而不是仅仅依赖于特定的模式或数据点。

  实验2(Dropout + LayerNorm):这个实验在Dropout的基础上增加了LayerNormLayerNorm通过独立地标准化每个样本的特征来减少不同层之间的输出分布的变化,有助于进一步稳定训练过程。从实验结果可以看出,实验2的梯度分布比实验1更加集中和均匀,这表明结合DropoutLayerNorm可以提供更好的正则化效果,使得模型能够更有效地学习并防止过拟合。

  实验3(仅使用LayerNorm):最后,实验3专注于单独使用LayerNorm。这个实验的结果介于实验4和实验1之间。它没有实验4那样的过拟合梯度分布,也没有实验1中Dropout导致的极端波动。这表明LayerNorm自身是一个有效的正则化方法,能够平衡模型的学习过程,即使在没有Dropout的情况下也能防止过拟合。

  LayerNorm在维持梯度分布稳定性和提高模型泛化能力方面的重要作用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1274072.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

智能优化算法应用:基于平衡优化器算法无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用:基于平衡优化器算法无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用:基于平衡优化器算法无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.平衡优化器算法4.实验参数设定5.算法结果…

Android Bitmap裁剪/压缩/缩放到限定的最大宽高值,Kotlin

Android Bitmap裁剪/压缩/缩放到限定的最大宽高值&#xff0c;Kotlin private fun cropImage(image: Bitmap): Bitmap {val maxWidth 1024 //假设宽度最大值1024val maxHeight 1024 //假设高度最大值1024val width image.widthval height image.heightif (width < maxWi…

opencv知识库:cv2.add()函数和“+”号运算符

需求场景 现有一灰度图像&#xff0c;需求是为该图像增加亮度。 原始灰度图像 预期目标图像 解决方案 不建议的方案——“”运算符 假设我们需要为原始灰度图像的亮度整体提升88&#xff0c;那么利用“”运算符的源码如下&#xff1a; import cv2img_path r"D:\pych…

git的版本控制流程

1、git是一款版本控制工具 例如我们常用的淘宝&#xff0c;每次升级&#xff0c;版本号就会加一。那么我们怎么控制版本号呢&#xff1f; --使用git。 2、最常使用的git指令 git add . 暂存 git commit -m"***" 提交到本地 git pull 将远程仓库代码下拉到本地 git …

基于GAN的多尺度门合并多模态MRI图像合成

Multi-Modal MRI Image Synthesis via GAN With Multi-Scale Gate Mergence 基于GAN的多尺度门合并多模态MRI图像合成背景贡献实验方法生成器gate mergence (GM) strategy&#xff08;门控融合策略&#xff09;判别器 损失函数Thinking 基于GAN的多尺度门合并多模态MRI图像合成…

从零开始部署一个网站详细图文教程——腾讯云的服务器、SSL证书,阿里云的域名,七牛云的对象存储、CDN等

文章目录 前期准备连接服务器配置Golang环境安装配置MySQL安装配置Redis安装配置Nginx安装Node域名解析SSL证书下载启动项目配置CDN加速总结 前期准备 云服务器&#xff08;必备&#xff09;、已经备案的域名&#xff08;必备&#xff09;&#xff0c;已签发的SSL证书&#xf…

plt创建指定色系

1、创建不连续色系 import matplotlib.pyplot as plt from matplotlib.colors import ListedColormap# 定义颜色的RGB值 colors [(0.2, 0.4, 0.6), # 蓝色(0.8, 0.1, 0.3), # 红色(0.5, 0.7, 0.2),(0.3,0.5,0.8)] # 绿色# 创建色系 cmap ListedColormap(colors)# 绘制…

STM32USART+DMA实现不定长数据接收/发送

STM32USARTDMA实现不定长数据接收 CubeMX配置代码分享实践结果 这一期的内容是一篇代码分享&#xff0c;CubeMX配置介绍&#xff0c;关于基础的内容可以往期内容 夜深人静学32系列11——串口通信夜深人静学32系列18——DMAADC单/多通道采集STM32串口重定向/实现不定长数据接收 …

3D点云目标检测:VoxelNex解读

VoxelNext 通用检测器 vs VoxelNext一、3D稀疏卷积模块1.1、额外的两次下采样消融实验结果代码 1.2、稀疏体素删减消融实验&#xff1a;代码 二、稀疏体素高度压缩代码 三、稀疏预测head 通用检测器 vs VoxelNext 一、3D稀疏卷积模块 1.1、额外的两次下采样 使用通用的3D spa…

2023年亚太杯数学建模C题新能源汽车(思路模型代码)

一、翻译 新能源汽车是指采用先进的技术原理、新技术和新结构&#xff0c;以非常规车用燃料&#xff08;非常规车用燃料是指汽油和柴油以外的燃料(非常规车用燃料是指汽油和柴油以外的燃料&#xff09;&#xff0c;并集成了汽车动力控制和驱动等先进技术的汽车。新能源汽车包括…

Gitee 之初体验(上)

我们在项目开发或者自己学习的时候&#xff0c;总会存在这样的问题&#xff1a; 在一台电脑上编写完代码&#xff0c;想要再另外一台电脑上再去写&#xff0c;再或者和其他人一起协作等等场合&#xff0c;代码传来传去很麻烦。 这个时候&#xff0c;我们就可以去使用代码管理工…

在java java.util.Date 已知逝去时间怎么求年月日 数学计算不用其他方法

在Java中&#xff0c;使用java.util.Date类已知逝去时间求年月日的方法如下&#xff1a; 首先&#xff0c;获取当前时间和逝去时间之间的毫秒数差值&#xff0c;可以使用Date类的getTime()方法获得时间戳。 将毫秒数转换为秒数&#xff0c;并计算出总共的天数。 根据总共的天…

计算机网络:应用层(上篇)

文章目录 前言一、应用层协议原理1.网络应用的体系结构2.进程通信 二、Web与HTTP1.HTTP概况2.HTTP连接3.HTTP请求报文4.用户-服务器状态&#xff1a;cookies5.Web缓存&#xff08;代理服务器&#xff09; 三、FTP&#xff1a;文件传输协议1.FTP&#xff1a;控制连接与数据连接分…

ClassNotFoundException: org.apache.hive.spark.client.Job

hive使用的是3.13版本&#xff0c;spark是3.3.3支持hadoop3.x hive将engine从mr改成spark&#xff0c;通过beeline执行insert、delete时一直报错&#xff0c;sparkTask rpc关闭&#xff0c; 查看yarn是出现ClassNotFoundException: org.apache.hive.spark.client.Job。 开始…

怎么一键批量转换PDF/图片为Excel、Word,从而提高工作效率?

在处理大量PDF、图片文件时&#xff0c;我们往往需要将这些文件转换成Word或Excel格式以方便编辑和统计分析。此时&#xff0c;金鸣表格文字识别大师这款工具可以发挥巨大作用。下面&#xff0c;我们就来探讨如何使用它进行批量转换&#xff0c;以实现高效处理。 一、准备工作…

linux服务器环境搭建(使用yum 安装mysql、jdk、redis)

一:yum的安装 1:下载yum安装包并解压 wget http://yum.baseurl.org/download/3.2/yum-3.2.28.tar.gz tar xvf yum-3.2.28.tar.gz 2.进入yum-3.2.28文件夹中进行安装,执行安装指令 cd yum-3.2.28 sudo apt install yum 3.更新版本 yum check-update yum update yum cle…

(一)C语言概述

文章目录 一、C语言1、计算机结构组成 二、第一个C语言程序&#xff1a;hello world1、编写C语言代码&#xff1a;hello.c2、通过gcc编译C代码&#xff08;1&#xff09;gcc编译器介绍&#xff08;2&#xff09;Window平台中gcc环境配置 3、代码分析&#xff08;1&#xff09;#…

基础课14——语音识别

ASR 是自动语音识别&#xff08;Automatic Speech Recognition&#xff09;的缩写&#xff0c;是一种将人类语音转换为文本的技术。ASR 系统可以处理实时音频流或已录制的音频文件&#xff0c;并将其转换为文本。它是一种自然语言处理技术&#xff0c;广泛应用于许多领域&#…

C++ :运算符重载

运算符重载&#xff1a; 运算符重载概念&#xff1a;对已有的运算符重新进行定义&#xff0c;赋予其另一种功能&#xff0c;以适应不同的数据类型 运算符的重载实际是一种特殊的函数重载&#xff0c;必须定义一个函数&#xff0c;并告诉C编译器&#xff0c;当遇到该重载的运算符…

每日一练2023.11.30——验证身份【PTA】

题目链接 &#xff1a;验证身份 题目要求&#xff1a; 一个合法的身份证号码由17位地区、日期编号和顺序编号加1位校验码组成。校验码的计算规则如下&#xff1a; 首先对前17位数字加权求和&#xff0c;权重分配为&#xff1a;{7&#xff0c;9&#xff0c;10&#xff0c;5&a…