Pytorch中layernorm实现详解

news2025/3/20 8:28:14

      平时我们在编写神经网络时,经常会用到layernorm这个函数来加快网络的收敛速度。那layernorm到底在哪个维度上进行归一化的呢? 

一、问题描述

    首先借用知乎上的一张图,原文写的也非常好,大家有空可以去阅读一下,链接放在参考文献里了。如左图所示,假设现在输入的维度是(bs,seq_len, embedding),其中bs代表batch_size, seq_len代表序列长度 ,embedding表示嵌入大小。

    那在layernorm时,我们是对(seq_len, embedding)这个矩阵取均值和方差(上图);还是只对embedding这个维度取均值和方差呢(下图)?前者会得到bs个均值和方差,而后者会得到bs * seq_len 个均值和方差。下面我们进行编程验证。

二、编程实现

import torch

batch_size, seq_size, dim = 2, 3, 4
embedding = torch.randn(batch_size, seq_size, dim)

layer_norm = torch.nn.LayerNorm(dim, elementwise_affine = False)
print("用pytorch的layer_norm所得结果\n", layer_norm(embedding))

print("自己编写layer_norm所得结果")
eps: float = 0.00001
mean = torch.mean(embedding[:, :, :], dim=(-1), keepdim=True)
var = torch.square(embedding[:, :, :] - mean).mean(dim=(-1), keepdim=True)

print("mean: ", mean.shape)
print("y_custom: ", (embedding[:, :, :] - mean) / torch.sqrt(var + eps))

结果:

用pytorch的layer_norm所得结果
 tensor([[[ 0.7475, -1.7061,  0.6676,  0.2910],
         [ 0.1144, -0.6476,  1.5753, -1.0421],
         [-1.0278, -0.7498,  0.2559,  1.5218]],

        [[-1.0527, -0.8723,  1.3354,  0.5895],
         [-0.6403, -1.1399,  1.4842,  0.2961],
         [ 0.7352, -0.8236, -1.1342,  1.2226]]])
自己编写layer_norm所得结果
mean:  torch.Size([2, 3, 1])
y_custom:  tensor([[[ 0.7475, -1.7061,  0.6676,  0.2910],
         [ 0.1144, -0.6476,  1.5753, -1.0421],
         [-1.0278, -0.7498,  0.2559,  1.5218]],

        [[-1.0527, -0.8723,  1.3354,  0.5895],
         [-0.6403, -1.1399,  1.4842,  0.2961],
         [ 0.7352, -0.8236, -1.1342,  1.2226]]])

结果的相等的。可以看到,我们在取均值和方差时,是对最后一个维度取的。所以我们会得到 (N,C)个均值与方差。假设二是正确的。 

而实际上这种实现方法和Instance Norm是相同的

from torch.nn import InstanceNorm2d
instance_norm = InstanceNorm2d(3, affine=False)
x = torch.randn(2, 3, 4)
output = instance_norm(embedding.reshape(2,3,4,1)) #InstanceNorm2D需要(N,C,H,W)的shape作为输入
print(output.reshape(2,3,4))

layer_norm = torch.nn.LayerNorm(4, elementwise_affine = False)
print(layer_norm(x))

结果:

tensor([[[ 0.7475, -1.7061,  0.6676,  0.2910],
         [ 0.1144, -0.6476,  1.5753, -1.0421],
         [-1.0278, -0.7498,  0.2559,  1.5218]],

        [[-1.0527, -0.8723,  1.3354,  0.5895],
         [-0.6403, -1.1399,  1.4842,  0.2961],
         [ 0.7352, -0.8236, -1.1342,  1.2226]]])
tensor([[[ 0.1293, -1.0034,  1.5760, -0.7018],
         [-1.3981, -0.4828,  1.0876,  0.7933],
         [-1.7034,  0.8545,  0.4876,  0.3612]],

        [[-1.4750,  1.2212, -0.2607,  0.5144],
         [ 0.7017, -0.8350,  1.2502, -1.1169],
         [-1.7273,  0.6965,  0.5147,  0.5161]]])

三、参考文献

(45 封私信 / 80 条消息) 为什么Transformer要用LayerNorm? - 知乎 (zhihu.com)https://www.zhihu.com/question/487766088/answer/2644783144

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2318251.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于java的ssm+JSP+MYSQL的高校四六级报名管理系统(含LW+PPT+源码+系统演示视频+安装说明)

作者:计算机搬砖家 开发技术:SpringBoot、php、Python、小程序、SSM、Vue、MySQL、JSP、ElementUI等,“文末源码”。 专栏推荐:SpringBoot项目源码、Vue项目源码、SSM项目源码、微信小程序源码 精品专栏:Java精选实战项…

ns3使用入门_基于ns3.44_Part2_配置模块参数的Configuration 和Attributes

前言 事实上ns3的官方手册很全,相关书籍也是有的,官网先贴在这里: ns-3 | a discrete-event network simulator for internet systemsa discrete-event network simulator for internet systemshttps://www.nsnam.org/相关的脚本介绍也都有一些: ns-3.35_wifi-he-networ…

性能测试过程实时监控分析

性能监控 前言一、查看性能测试结果的3大方式1、GUI界面报告插件2、命令行运行 html报告3、后端监听器接入仪表盘 二、influxDB grafana jmeter测试监控大屏1、原理:2、linux环境中influxDB 安装和配置3、jmerer后端监听器连接influxDB4、linux环境总grafana环境搭…

C程序设计(第五版)及其参考解答,附pdf

通过网盘分享的文件:谭浩强C语言设计 链接: https://pan.baidu.com/s/1U927Col0XtWlF9TsFviApg?pwdeddw 提取码: eddw 谭浩强教授的《C程序设计》是C语言学习领域的经典教材,其内容深入浅出,适合不同层次的学习者。 一、教材版本与特点 最…

杰理科技JL703N双模蓝牙芯片—云信

杰理科技JL703N芯片运算能力、接收灵敏度、发射功率、音频性能等指标均处于行业一流水平,能满足多场景的应用需求,具有以下明显优势: 一、高性能双核浮点CPU,算力十足 JL703N芯片搭载了32位高性能双核CPU,主频高达32…

Rust + 时序数据库 TDengine:打造高性能时序数据处理利器

引言:为什么选择 TDengine 与 Rust? TDengine 是一款专为物联网、车联网、工业互联网等时序数据场景优化设计的开源时序数据库,支持高并发写入、高效查询及流式计算,通过“一个数据采集点一张表”与“超级表”的概念显著提升性能…

Nvidia 官方CUDA课程学习笔记

之前心血来潮学习了一下Nvidia CUDA,外行,文章有理解不当的地方,望指正。 主要根据以下Nvidia官方课程学习: https://www.bilibili.com/video/BV1JJ4m1P7xW/?spm_id_from333.337.search-card.all.click&vd_sourcec256dbf86b…

【AI News | 20250319】每日AI进展

AI Repos 1、XianyuAutoAgent 实现了 24 小时自动化值守的 AI 智能客服系统,支持多专家协同决策、智能议价和上下文感知对话,让我们店铺管理更轻松。主要功能: 智能对话引擎,支持上下文感知和专家路由阶梯降价策略,自…

一种基于大规模语言模型LLM的数据分析洞察生成方法

从复杂数据库中提取洞察对数据驱动决策至关重要,但传统手动生成洞察的方式耗时耗力,现有自动化数据分析方法生成的洞察不如人工生成的有洞察力,且存在适用场景受限等问题。下文将介绍一种新的方法,通过生成高层次问题和子问题,并使用SQL查询和LLM总结生成多表数据库中的见…

【npm ERR! code ERESOLVE npm ERR! ERESOLVE unable to resolve dependency tree】

npm ERR! code ERESOLVE npm ERR! ERESOLVE unable to resolve dependency tree npm ERR! code ERESOLVE npm ERR! ERESOLVE unable to resolve dependency tree 当我们拿到一个前端项目的时候,想要把它运行起来,首先是要给它安装依赖,即cd到…

用 pytorch 从零开始创建大语言模型(四):从零开始实现一个用于生成文本的GPT模型

从零开始创建大语言模型(Python/pytorch )(四):从零开始实现一个用于生成文本的GPT模型 4 从零开始实现一个用于生成文本的GPT模型4.1 编写 L L M LLM LLM架构4.2 使用层归一化对激活值进行标准化4.3 使用GELU激活函数…

【新能源汽车“心脏”赋能:三电系统研发、测试与应用匹配的恒压恒流源技术秘籍】

新能源汽车“心脏”赋能:三电系统研发、测试与应用匹配的恒压恒流源技术秘籍 在新能源汽车蓬勃发展的浪潮中,三电系统(电池、电机、电控)无疑是其核心驱动力。而恒压源与恒流源,作为电源管理的关键要素,在…

目标检测20年(一)

今天看的文献是《Object Detection in 20 Years: A Survey》,非常经典的一篇目标检测文献,希望通过这篇文章学习到目标检测的基础方法并提供一些创新思想。 论文链接:1905.05055 一、摘要 1.1 原文 Object detection, as of one the most…

【MySQL数据库】存储过程与自定义函数(含: SQL变量、分支语句、循环语句 和 游标、异常处理 等内容)

存储过程:一组预编译的SQL语句和流程控制语句,被命名并存储在数据库中。存储过程可以用来封装复杂的数据库操作逻辑,并在需要时进行调用。 类似的操作还有:自定义函数、.sql文件导入。 我们先从熟悉的函数开始说起: …

WEB攻防-PHP反序列化-字符串逃逸

目录 前置知识 字符串逃逸-减少 字符串逃逸-增多 前置知识 1.PHP 在反序列化时,语法是以 ; 作为字段的分隔,以 } 作为结尾,在结束符}之后的任何内容不会影响反序列化的后的结果 class people{ public $namelili; public $age20; } var_du…

英伟达GTC 2025大会产品全景剖析与未来路线深度洞察分析

【完整版】3月19日,黄仁勋Nvidia GTC 2025 主题演讲|英伟达 英伟达GTC 2025大会产品全景剖析与未来路线深度洞察分析 一、引言 1.1 分析内容 本研究主要采用了文献研究法、数据分析以及专家观点引用相结合的方法。在文献研究方面,广泛收集了…

基于java的ssm+JSP+MYSQL的九宫格日志网站(含LW+PPT+源码+系统演示视频+安装说明)

系统功能 管理员功能模块: 个人中心 用户管理 日记信息管理 美食信息管理 景点信息管理 新闻推荐管理 日志展示管理 论坛管理 我的收藏管理 管理员管理 留言板管理 系统管理 用户功能模块: 个人中心 日记信息管理 美食信息管理 景点信息…

【Java】Mybatis学习笔记

目录 一.搭建Mybatis 二.Mybatis核心配置文件解析 1.environment标签 2.typeAliases 3.mappers 三.Mybatis获取参数值 四.Mybatis查询功能 五.特殊的SQL执行 1.模糊查询 2.批量删除 3.动态设置表名 4.添加功能获取自增的主键 六.自定义映射ResultMap 1.配置文件处…

遗传算法+四模型+双向网络!GA-CNN-BiLSTM-Attention系列四模型多变量时序预测

遗传算法四模型双向网络!GA-CNN-BiLSTM-Attention系列四模型多变量时序预测 目录 遗传算法四模型双向网络!GA-CNN-BiLSTM-Attention系列四模型多变量时序预测预测效果基本介绍程序设计参考资料 预测效果 基本介绍 基于GA-CNN-BiLSTM-Attention、CNN-BiL…

中兴B860AV3.2-T/B860AV3.1-T2_S905L3-B_2+8G_安卓9.0_先线刷+后卡刷固件-完美修复反复重启瑕疵

中兴电信B860AV3.2-T/B860AV3.1-T2_晶晨S905L3-B芯片_28G_安卓9.0_先线刷后卡刷-刷机固件包,完美修复刷机后盒子反复重启的瑕疵。 这两款盒子是可以通刷的,最早这个固件之前论坛本人以及其他水友都有分享交流过不少的固件,大概都…