【标准化方法】(2) Layer Normalization 原理解析、代码复现,附Pytorch代码

news2024/9/24 17:18:01

大家好,今天和各位分享一下深度学习中常见的标准化方法,在 Transformer 模型中常用的 Layer Normalization,从数学公式的角度复现一下代码。

看本节前建议各位先看一下 Batch Normalization:https://blog.csdn.net/dgvv4/article/details/130567501

Layer Normalization 的论文地址如下:https://arxiv.org/pdf/1607.06450.pdf


1. 原理介绍

深层网络训练时,网络层数的增加会增加模型计算负担,同时也会导致模型变得难以训练。随着网络层数的增加,数据的分布方式也会随着层与层之间的变化而变化,这种现象被称为内部协变量偏移(Internal Convariate Shift, ICS)。这要求模型训练时必须使用较小的学习率,且需要慎重地选择权重初值ICS 导致训练速度减慢,同时也导致使用饱和的非线性激活函数(如sigmoid,正负两边都会饱和梯度为 0)时出现梯度消失问题。

为解决内部协方差变化(ICS),思路是固定每一层输出的均值和方差,即层归一化算法(Layer Normalization,LN),层归一化算法用每个样本的均值和方差对输入进行归一化LN 是在单个样本上操作,可以应用于小批次和 RNN。LN 和 BN 有相同的形式,只是不同的归一化方式。

层归一化与批归一化算法的区别只在于统计值的获取方式上,下式是层归一化算法中均值和方差的计算方式,H 表示层的隐藏单元数,a_i^l 代表第 l 层中的第 i 个神经元。

u^l=\frac{1}{H}\sum\limits_{i=1}^H a_i^l

\sigma^l=\sqrt{\frac{1}{H}\sum\limits_{i=1}^H\left(a_i^l-u^l\right)^2}

层归一化算法通过计算在一个训练样本上某一层所有的神经元的均值和方差来对输入进行归一化,像批归一化算法那样,同样也给每个神经元加入了增益 \gamma 和偏置 \beta 来实现线性变换,这在归一化后激活函数前使用。

层归一化 LN 和批归一化 BN 不同的是,层归一化在训练和测试时执行同样的计算,由于 LN 与批次大小没有关系,LN 能够在递归神经网络的每个时间步上分别计算归一化操作所需要的均值和方差的值。实验结果表明,层归一化技术相对批归一化技术训练时间更短

层归一化算法比较适合应用于全连接网络和递归神经网络,有学者尝试在卷积神经网络上采用层归一化算法,但是发现层归一化算法的效果没有批归一化算法好。这是因为对于全连接层,隐藏层中的全部单元对最终预测和重新定位做出相似的贡献,将所有输入缩放到一个图层效果很好。但是,类似贡献的假设在卷积神经网络不再适用,大量的隐藏单位的感受野位于图像边界附近很少被打开,因此来自同一层内其他隐藏单元的统计数据有很大不同。有学者认为认为需要进一步的研究使卷积网络中的层归一化工作取得好的效果。

总体来说,LN 较 BN 简单,它也是通过减少 ICS 来加速神经网络的训练。LN 在训练和测试时没有区别,只需要对当前隐藏层计算均值和方差而不需要保存每层的移动平均和方差用于测试且不受批次大小的限制,可以通过在线学习的方式一条一条的输入训练数据。

优点:批量较小时,效果好;适用于自然语言处理任务。

缺点:批量较大时,效果不如BN。


2. 代码展示

构造一个输入 shape=[B,C,H*W] 的张量,对每个样本在 [C, H*W] 这两个维度上做 LN

import torch
from torch import nn

class LN(nn.Module):
    # 初始化
    def __init__(self, normalized_shape,  # 在哪个维度上做LN
                 eps:float = 1e-5, # 防止分母为0
                 elementwise_affine:bool = True):  # 是否使用可学习的缩放因子和偏移因子
        super(LN, self).__init__()
        # 需要对哪个维度的特征做LN, torch.size查看维度
        self.normalized_shape = normalized_shape  # [c,w*h]
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        # 构造可训练的缩放因子和偏置
        if self.elementwise_affine:  
            self.gain = nn.Parameter(torch.ones(normalized_shape))  # [c,w*h]
            self.bias = nn.Parameter(torch.zeros(normalized_shape))  # [c,w*h]

    # 前向传播
    def forward(self, x: torch.Tensor): # [b,c,w*h]
        # 需要做LN的维度和输入特征图对应维度的shape相同
        assert self.normalized_shape == x.shape[-len(self.normalized_shape):]  # [-2:]
        # 需要做LN的维度索引
        dims = [-(i+1) for i in range(len(self.normalized_shape))]  # [b,c,w*h]维度上取[-1,-2]维度,即[c,w*h]
        # 计算特征图对应维度的均值和方差
        mean = x.mean(dim=dims, keepdims=True)  # [b,1,1]
        mean_x2 = (x**2).mean(dim=dims, keepdims=True)  # [b,1,1]
        var = mean_x2 - mean**2  # [b,c,1,1]
        x_norm = (x-mean) / torch.sqrt(var+self.eps)  # [b,c,w*h]
        # 线性变换
        if self.elementwise_affine:
            x_norm = self.gain * x_norm + self.bias  # [b,c,w*h]
        return x_norm

# ------------------------------- #
# 验证
# ------------------------------- #

if __name__ == '__main__':

    x = torch.linspace(0, 23, 24, dtype=torch.float32)  # 构造输入层
    x = x.reshape([2,3,2*2])  # [b,c,w*h]
    # 实例化
    ln = LN(x.shape[1:])
    # 前向传播
    x = ln(x)
    print(x.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/505311.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

用友携国资国企走进浙江龙游,共探区县国资智慧监管新样板

近日,由龙游县国有资产经营有限公司指导,用友网络科技股份有限公司(以下简称:用友网络)主办的“成为数智企业 迈向高质量发展——2023走进龙游数智化观摩研讨会”在浙江龙游成功举办!全国近百位国资国企负责…

Cocos Creator 3.x 热更新,使用chatgpt快速定位解决问题

为什么要使用app热更 使用 app 热更的主要原因是可以快速地向用户推送应用程序的更新版本,同时也可以减少应用程序更新时需要用户手动下载和安装的次数,从而提高用户体验和应用程序的可维护性。以下是一些使用 app 热更的好处: 快速发布更新…

react初始化配置rem,less,@,本地代理,通配符,视口单位等

初始化项目之后,项目配置中默认配置的是scss 想用less就需要单独配置了,在做一个完整的项目情况下create-react-app搭出来架子的配置往往是不够的至少需要简单配置以下信息 暴露webpack之后会增加很多文件和依赖配置,有些时候并不想把它暴露出…

阿里云镜像区别公共镜像、自定义、共享、云市场和社区镜像介绍

阿里云服务器镜像根据来源不同分为公共镜像、自定义镜像、共享镜像、云市场镜像和社区镜像,一般没有特殊情况选择公共镜像,公共镜像是阿里云官网提供的正版授权操作系统,云市场镜像是在纯净版操作系统的基础上预装了相关软件及运行环境&#…

自动修改文章的软件-文章原创软件

免费版自动修改文章的软件 免费版自动修改文章的软件是一种又快速、易用且免费的文章修改软件,可以帮助用户批量修改文章和图文,并为用户提供高质量的修改服务。用户仅需上传待修改的文章文件,软件就能自动检测出文章中的语法、拼写错误和表…

开发人员如何理解《辟邪剑谱》的“前8个字”

辟邪剑谱可以说是武林至宝,人人都想得到,让自己冲破三流侠客的行列。得到的人,心里激动不已,得等到四下无人的时候才敢偷偷去练。但奈何最前面有8个字被折叠起来了,很多人也曾得到过,但一直没看到这前8个字…

【shell函数】

目录 一、shell函数1、shell函数的定义 二、函数传参三、阶乘四、函数实验题目 一、shell函数 使用函数可以避免代码重复 使用函数可以将大的工程分割为若干小的功能模块,代码的可读性更强 1、shell函数的定义 函数返回值: return表示退出函数并返回一个退出值&…

Linux | 学习笔记(适合小白)

操作系统概述: 计算机是由硬件和软件这两个主要部分组成的操作系统是软件的一类,主要作用是协助用户调度硬件工作,充当用户和计算机硬件之间的桥梁常见的操作系统:PC端:Windows,Linux,MacOS&…

ShardingSphere系列一(MySQL主从架构及读写分离实战(搭建主从集群、MySQL高可用方案MHA、分库分表概念))

文章目录 1. 搭建主从集群1.1 概念1.2 同步的原理1.3 搭建主从同步实战1.3.1 配置master主库1.3.2 配置slave从库1.3.3 主从集群测试 1.4 主从同步扩展1.4.1 主库同步与部分同步(同步范围限制)1.4.2 读写分离配置1.4.3 其他集群方式 1.5 GTID同步集群1.6…

Linux的这七大认识误区,你千万别有!

导读本文罗列了大家对Linux的七大认识误区,看看其中那个是你也出现过的。千万别让这些先入为主的观点断送了你体验新事物的机会。 Linux的受众群体并不大。对还是错? 错!大错而特错。 我承认,Linux的实际用户数量很难统计,因为…

【1.JS基础-JavaScript的基本语法和数据类型】

1.JavaScript的编写方式 2 JavaScript的交互方式 3 Chrome的调试工具 4 变化数据的记录 – 变量 如果一个变量有声明,但是没有赋值,那么默认值是undefined 5 JavaScript的数据类型 typeof操作符 6 Number类型 number 类型代表整数和浮点数。 ◼ isNaN…

举一反三学python(12)—制作简易计算器

下图为简易、实用的计算器的效果图,今天展示用百行代码完成。 一、导入模块 import tkinter as tk 二、整体布局 win tk.Tk() # 实例化一个窗体对象 win.title(简易计算器) # 窗口标题 win.geometry(295x280) # 窗…

ESP8266图形播放器 + 天气时钟显示项目更新

<fontcolor=green>ESP8266图形播放器 + 天气时钟显示项目更新 🎞原项目播放效果演示:https://www.ixigua.com/6968269356820070912?logTag=f37e7f1f5cefa9876746✨由于有些库的更新以及API调用接口的失效,特此更新,内容上做了精简和优化。⚡由于所调用的库比较多,…

IPWorks VoIP 2022.0.8505 C++ Edition

IPWorks VoIP IPWorks VoIP 2022 C Edition 支持常见 SIP 和 IVR 操作的简单 VoIP 库。 网络语音组件 IPWorks VoIP 提供 SIP 和 IVR 组件&#xff0c;旨在促进 CTI 应用中的常见 VoIP 操作。快速集成功能&#xff0c;以根据您的自定义 IVR 菜单建立拨出呼叫、接听来电和路由呼…

kt:reified和sam转换(Single Abstract Method Conversions)

什么是refied关键字 ​由于我们都知道Kotlin和Java一样都存在着泛型擦除问题&#xff0c;而Kotlin它知道Java所带来的这个问题&#xff0c;所以对此Kotlin留了一个后门&#xff0c;就是通过inline函数保证使得泛型类的类型实参在运行时能够保留&#xff0c;这样的操作 Kotlin 中…

使用OpenCV部署全景驾驶感知网络YOLOP

开源项目 MCnet 是一个神经网络模型&#xff0c;用于实现车辆视觉感知的任务&#xff0c;比如车道线检测、行驶区域分割和物体检测等。MCnet 的全称是 Multitask CNN&#xff0c;它在单个神经网络模型中集成了多个任务的网络结构&#xff0c;可以同时对输入图像进行多个任务的…

Day964.从持续构建到持续集成 -遗留系统现代化实战

从持续构建到持续集成 Hi&#xff0c;我是阿昌&#xff0c;今天学习记录的是关于从持续构建到持续集成的内容。 如何修改后的代码可以“火速”部署到生产环境里&#xff0c;这样才能提高整个端到端的交付效率&#xff0c;让每次改动工作都能及时得到反馈&#xff0c;尽快验证…

看火山引擎DataLeap如何做好电商治理(二):案例分析与解决方案

接上篇&#xff0c;以短视频优质项目为例&#xff0c;火山引擎DataLeap平台治理团队会去对每天发布的这种挂购物车车短视频打上标签&#xff0c;识别这些短视频它是优质的还是低质的&#xff0c;以及具体原因。一个视频经过这个模型识别之后&#xff0c;会给到奖惩中心去做相应…

聊一聊 用 dotnet-trace 调查 lock锁竞争

一&#xff1a;背景 1. 讲故事 最近在分析一个 linux 上的 dump&#xff0c;最后的诱因是大量的lock锁诱发的高频上下文切换&#xff0c;虽然问题告一段落&#xff0c;但我还想知道一点信息&#xff0c;所谓的高频到底有多高频&#xff1f;锁竞争到底是一个怎样的锁竞争&…

将训练好的模型保存在服务端的三种办法

刚刚在完善我书中第七章案例的文档时&#xff0c;需要将训练好的模型存储在服务端&#xff0c;方便小伙伴们来使用该模型&#xff0c;这里我提供三种办法&#xff1a; 直接从我的个人网站中加载&#xff1b;通过python启动一个文件下载服务器&#xff1b;使用微信小程序云存储…