详解三种常用标准化:Batch Norm Layer Norm RMSNorm

news2024/10/9 16:32:42

Normalization

  Normalization现在已经成了神经网络中不可缺少的一个重要模块了,并且存在多种不同版本的归一化方法,其本质都是减去均值除以方差,进行线性映射后,使得数据满足某个稳定分布,如下图所示:
在这里插入图片描述
在这里插入图片描述
  深度学习中,归一化是常用的稳定训练的手段,CV 中常用 Batch Norm; Transformer 类模型中常用 layer norm,而 RMSNorm 是近期很流行的 LaMMa 模型使用的标准化方法,它是 Layer Norm 的一个变体。值得注意的是,这里所谓的归一化严格讲应该称为 标准化Standardization ,有时也称为 白化whitening。它描述一种把样本调整到均值为 0,方差为 1 的缩放平移操作。使用这种方法可以消除输入数据的量纲,有利于随机初始化的网络训练。
  本文将对BatchNorm、LayerNorm、RMSNorm三种归一化进行介绍。详细讨论前,先粗略看一下 Batch Norm 和 Layer Norm 的区别
在这里插入图片描述

  1. BatchNorm是对整个 batch 样本内的每个特征做归一化,这消除了不同特征之间的大小关系,但是保留了不同样本间的大小关系。BatchNorm 适用于 CV 领域,这时输入尺寸为 b × c × h × w b\times c\times h\times wb×c×h×w (批量大小x通道x长x宽),图像的每个通道 c cc 看作一个特征,BN 可以把各通道特征图的数量级调整到差不多,同时保持不同图片相同通道特征图间的相对大小关系
  2. LayerNorm是对每个样本的所有特征做归一化,这消除了不同样本间的大小关系,但是保留了一个样本内不同特征之间的大小关系。LayerNorm 适用于 NLP 领域,这时输入尺寸为 b × l × d (批量大小x序列长度x嵌入维度),如下图所示:
    在这里插入图片描述
      注意这时长 l 的 token 序列中,每个 token 对应一个长为 d 的特征向量,LayerNorm 会对各个 token 执行 l 次归一化计算,保留每个 token d 维嵌入内部的相对大小关系,同时拉近了不同 token 对应特征向量间的距离。与之相比,BN 会消除 d 维特征向量各维度之间的大小关系,破坏了 token 的特征(以下第 2 节会进一步说明这一点)

1. Batch Normalization

  BN 对同一 batch 内同一通道的所有数据进行归一化,设输入的 batch data 为 x,BN 运算如下
在这里插入图片描述
在这里插入图片描述
  注意我们在方差估计值中添加一个小的常量 ϵ (防除零因子),以确保我们永远不会尝试除以零。
在这里插入图片描述
  BatchNorm是一种在深度学习训练中广泛使用的归一化技术,有很多好处,包括正则化效应、减少过拟合、减少对权重初始值的依赖、允许使用更高的学习率等
  示例代码参考自《动手学深度学习》7.5 节,适用于全连接层和卷积层,训练过程中使用滑动平均法计算 batch 数据的均值和方差;评估过程中使用最新的均值和方差结果

class BatchNorm(nn.Module):
    # num_features:完全连接层的输出数量或卷积层的输出通道数。
    def __init__(self, num_features, num_dims):
        super().__init__()
        if num_dims == 2: # 全连接层
            shape = (1, num_features)
        else:             # 卷积层
            shape = (1, num_features, 1, 1)
        
        # 参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        
        # 非模型参数的变量初始化为0和1
        self.moving_mean = torch.zeros(shape)
        self.moving_var = torch.ones(shape)

    def batch_norm(self, X, gamma, beta, moving_mean, moving_var, eps, momentum):
        if not torch.is_grad_enabled():
            # 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差
            X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
        else:
            assert len(X.shape) in (2, 4)
            if len(X.shape) == 2:
                # 使用全连接层的情况,计算特征维上的均值和方差
                mean = X.mean(dim=0)                                       # (num_features,)
                var = ((X - mean) ** 2).mean(dim=0)                        # (num_features,)
            else:
                # 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。
                mean = X.mean(dim=(0, 2, 3), keepdim=True)                # (1,num_features,1,1) 保持X的形状,以便后面可以做广播运算
                var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True) # (1,num_features,1,1)
                
            # 训练模式下,用当前的均值和方差做标准化
            X_hat = (X - mean) / torch.sqrt(var + eps)
            
            # 更新移动平均的均值和方差
            moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
            moving_var = momentum * moving_var + (1.0 - momentum) * var
            
        Y = gamma * X_hat + beta  # 缩放和移位
        return Y, moving_mean.data, moving_var.data

    def forward(self, X):
        # 如果X不在内存上,将moving_mean和moving_var,复制到X所在显存上
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
       
        # 保存更新过的moving_mean和moving_var
        Y, self.moving_mean, self.moving_var = self.batch_norm(
            X, self.gamma, self.beta, self.moving_mean,
            self.moving_var, eps=1e-5, momentum=0.9
        )
        return Y

2. Layer Normalization

  LN 主要用于 NLP 领域,它对每个 token 的特征向量进行归一化计算。LN 运算如下在这里插入图片描述
在这里插入图片描述
  给定一个长 l 的句子,LN 要进行 l 次归一化计算,之后对每个特征维度施加统一的拉伸和偏移,如下图所示:
在这里插入图片描述
  为什么 LN 比 BN 更适用于 Transformer 类模型呢,这是因为 transformer 模型是基于相似度的,把序列中的每个 token 的特征向量进行归一化有利于模型学习语义,第一步调整均值方差时,相当于对把各个 token 的特征向量缩放到统一的尺度,第二步施加 γ , β 时,相当于对所有 token 的特征向量进行了统一的 transfer,这不会破坏 token 特征向量间的相对角度,因此不会破坏学到的语义信息。与之相对的,BN 沿着特征维度进行归一化,这时对序列中各个 token 施加的 transfer 是不同的,破坏了 token 特征向量间的相对角度关系

3. RMSNorm

  RMSNorm 是 LayerNorm 的一个简单变体,来自 2019 年的论文 Root Mean Square Layer Normalization,被 T5 和当前流行 lamma 模型所使用。其提出的动机是 LayerNorm 运算量比较大,所提出的RMSNorm 性能和 LayerNorm 相当,但是可以节省7%到64%的运算
  RMSNorm和LayerNorm的主要区别在于RMSNorm不需要同时计算均值和方差两个统计量,而只需要计算均方根 Root Mean Square 这一个统计量,公式如下在这里插入图片描述
  论文 Do Transformer Modifications Transfer Across Implementations and Applications? 中做了比较充分的对比实验,显示出RMS Norm的优越性。一个直观的猜测是,计算均值所代表的 center 操作类似于全连接层的 bias 项,储存到的是关于预训练任务的一种先验分布信息,而把这种先验分布信息直接储存在模型中,反而可能会导致模型的迁移能力下降
  下面给出 Transformer Lamma 源码中实现的 RMSNorm

class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2195787.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

开放式耳机是什么意思?分享几款适合各类运动佩戴的蓝牙耳机

其实目前有很多热爱运动的人士常常会陷入一个纠结之中,那就是在进行爬山、骑行、步行、跑步或者健身等各类运动的时候,到底佩戴什么样的蓝牙耳机才最为合适呢?那就我个人而言,我觉得开放式耳机无疑会是运动人士的救星。因为作为一…

OJ在线评测系统 微服务高级 网关跨域权限校验 集中解决跨域问题 拓展 JWT校验和实现接口限流降级

微服务网关跨域权限校验 集中的去解决一下跨域 这段代码是用来配置跨源资源共享(CORS)过滤器的。它创建了一个 CorsConfiguration 实例,允许所有方法和头部,并支持凭证(如 Cookies)。setAllowedOriginPat…

基于SSM+小程序的教育培训管理系统(教育3)

👉文末查看项目功能视频演示获取源码sql脚本视频导入教程视频 基于SSM小程序的教育培训管理系统(教育3) 1、项目介绍 1、管理员权限操作的功能包括增删改查网课信息,教师信息,学生信息,试卷,…

多元化网络团队应对复杂威胁

GenAI、ML 和 IoT 等技术为威胁者提供了新的工具,使他们更容易针对消费者和组织发起攻击。 从诱骗受害者陷入投资骗局的Savvy Seahorse ,到使用 ChatGPT 之类的程序感染计算机并阅读电子邮件的自我复制 AI 蠕虫,新的网络威胁几乎每天都在出现…

【机器学习】探索机器学习在医疗影像分析中的应用

1. 🚀 引言1.1 🚀 医疗影像分析的现状与发展趋势1.2 📜 机器学习在医疗影像分析中的核心概念1.3 🏆 医疗影像分析在临床应用中的作用 2. 🔍 医疗影像分析的演变与创新2.1 🌟 医疗影像分析的发展历程2.2 &am…

SQl注入文件上传及sqli-labs第七关less-7

Sql注入文件上传 1、sql知识基础 secure_file_priv 参数 secure_file_priv 为 NULL 时,表示限制mysqld不允许导入或导出。 secure_file_priv 为 /tmp 时,表示限制mysqld只能在/tmp目录中执行导入导出,其他目录不能导出导入。 secure_fil…

linux信号 | 信号的补充知识

前言:本节内容主要是一些linux信号的周边知识或者补充知识。 对于信号的学习, 学习了信号概念, 产生, 保存与捕捉就已经算是认识我们的信号了。 如果想要知道更多关于信号的知识也可以看一下本篇文章。 ps:本篇内容适…

CSS——文字渐入效果

CSS——文字渐入效果 昨天制作了文字的打字机效果(CSS——文字打字机效果),然后我想到有些网页的文字效果是平滑渐入的,我就去思考这样的实现方式,其实就把之前的steps()函数去掉即可,但是我想换种实现方式…

电脑无法无线投屏的解决办法

在前司的时候经常遇到电脑无法使用无线投屏器的情况,今天就来聊聊如何解决。 1.不会连接。这种情况,经常发生在WIN10升级WIN11之后,一般是两种办法,一种是同时按键盘上的WINDOWS和K键,右下角就会出来连接的图标&#…

Day8:返回倒数第k个节点

题目: 实现一种算法,找出单向链表中倒数第k个节点。返回该结点的值。 示例: 输入:1->2->3->4->5和k2 输出:4 说明: 给定的k保证是有效的。 public int kthToLast(ListNode head,int k){…

《动手学深度学习》Pytorch 版学习笔记一:从预备知识到现代卷积神经网络

前言 笔者有一定的机器学习和深度学习理论基础,对 Pytorch 的实战还不够熟悉,打算入职前专项突击一下 本文内容为笔者学习《动手学深度学习》一书的学习笔记 主要记录了代码的实现和实现过程遇到的问题,不完全包括其理论知识 引用&#x…

GRASP七大基本原则+纯虚构防变异

问题引出 软件开发过程中,需要设计大量的类,使他们交互以实现特定的功能性需求。但是不同的设计方式,对程序的非功能性需求(可扩展性,稳定性,可维护性等)的实现程度则完全不同。 有没有一种统一…

动态规划算法——三步问题

1.题目解析 2.算法原理 本题可以近似看做泰波那契数列,即小孩到第一个台阶需要一步,到第二个台阶则是到第一个台阶的步数加上第一阶到第二阶的步数,同理第三阶就是第二阶的步数加上第二阶到第三阶的步数,由于小孩只能走三步&#…

基于STM32的智能垃圾桶控制系统设计

引言 本项目设计了一个基于STM32微控制器的智能垃圾桶控制系统,能够通过超声波传感器检测手部动作,自动打开或关闭垃圾桶盖,提升用户的便利性和卫生性。该项目展示了STM32微控制器在传感器检测、伺服电机控制和嵌入式智能控制中的应用。 环…

在不支持WSL2的Windows环境下安装Redis并添加环境变量的方法

如果系统版本支持 WSL 2 可跳过本教程。使用官网提供的教程即可 官网教程 查看是否支持 WSL 2 如果不支持或者觉得麻烦可以按照下面的方式安装 下载 点击打开下载地址 下载 zip 文件即可 安装 将下载的 zip 文件解压到自己想要解压的地方即可。(注意&#x…

毕业设计选题:基于ssm+vue+uniapp的模拟考试小程序

开发语言:Java框架:ssmuniappJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工具:Navicat11开发软件:eclipse/myeclipse/ideaMaven包:M…

2024最新 Navicat Premium 17 简体中文版安装图文详细教程

Navicat 17 引入了一系列新特性,旨在提升用户体验和工作效率。以下是一些值得关注的新功能: ‌模型工作区的全面重新设计‌:包含了增强的图表设计、更强大的同步工具、数据字典支持等多项功能。这有助于在一个工作区中创建多个模型&#xff0…

集合论基础 - 离散数学系列(一)

目录 1. 集合的基本概念 什么是集合? 集合的表示方法 常见的特殊集合 2. 子集与幂集 子集 幂集 3. 集合的运算 交集、并集与补集 集合运算规则 4. 笛卡尔积 5. 实际应用 6. 例题与练习 例题1 练习题 总结 引言 集合论是离散数学的基础之一&#xff…

HarmonyOS第一课 04 应用程序框架基础-习题分析

判断题 1.在基于Stage模型开发的应用项目中都存在一个app.json5配置文件、以及一个或多个module.json5配置文件。T 正确(True) 错误(False) 这个答案是T - AppScope > app.json5:app.json5配置文件,用于声明应用的全局配置信息,比如应用…

利用大规模语言模型提高生物医学 NER 性能的新方法

概述 论文地址:https://arxiv.org/pdf/2404.00152.pdf 大规模语言模型在零拍摄和四拍摄任务中表现出色,但在生物医学文本的独特表达识别(NER)方面仍有改进空间。例如,Gutirrez 等人(2022 年)的…