Meta Llama 3 RMSNorm(Root Mean Square Layer Normalization)

news2024/11/25 9:36:22

Meta Llama 3 RMSNorm(Root Mean Square Layer Normalization)

flyfish

目录

  • Meta Llama 3 RMSNorm(Root Mean Square Layer Normalization)
    • 先看LayerNorm和BatchNorm
    • 举个例子计算 LayerNorm
    • RMSNorm 的整个计算过程
      • 实际代码实现
      • 结果

先看LayerNorm和BatchNorm

展示计算的方向
在这里插入图片描述

  • axis=0 代表第一个轴,逐列处理数据。
  • axis=1 代表第二个轴,逐行处理数据。在二维数组中,axis=-1 等同于 axis=1。
  • axis=-1 代表最后一个轴。在二维数组中,axis=-1 等同于 axis=1,即最后一个轴。

在二维的情况 下,BatchNorm是按列算,LayerNorm按行算

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

class CustomLayerNorm:
    def __init__(self, eps=1e-5):
        self.eps = eps

    def __call__(self, x):
        mean = np.mean(x, axis=-1, keepdims=True)
        std = np.std(x, axis=-1, keepdims=True)
        normalized = (x - mean) / (std + self.eps)
        return normalized

class CustomBatchNorm:
    def __init__(self, eps=1e-5):
        self.eps = eps

    def __call__(self, x):
        mean = np.mean(x, axis=0)
        std = np.std(x, axis=0)
        normalized = (x - mean) / (std + self.eps)
        return normalized

# Original Data
data = np.array([[1.0, 2.0, 3.0],
                 [4.0, 5.0, 6.0],
                 [7.0, 8.0, 9.0]])

# Apply Custom LayerNorm
custom_layer_norm = CustomLayerNorm()
custom_layer_norm_data = custom_layer_norm(data)

# Apply Custom BatchNorm
custom_batch_norm = CustomBatchNorm()
custom_batch_norm_data = custom_batch_norm(data)

# Apply PyTorch LayerNorm
data_tensor = torch.tensor(data, dtype=torch.float32)
layer_norm = nn.LayerNorm(data_tensor.size()[1:])
pytorch_layer_norm_data = layer_norm(data_tensor).detach().numpy()

# Compare Custom and PyTorch LayerNorm
print("Original Data:\n", data)
print("Custom LayerNorm Data:\n", custom_layer_norm_data)
print("PyTorch LayerNorm Data:\n", pytorch_layer_norm_data)
Original Data:
 [[1. 2. 3.]
 [4. 5. 6.]
 [7. 8. 9.]]
Custom LayerNorm Data:
 [[-1.22472987  0.          1.22472987]
 [-1.22472987  0.          1.22472987]
 [-1.22472987  0.          1.22472987]]
PyTorch LayerNorm Data:
 [[-1.2247356  0.         1.2247356]
 [-1.2247356  0.         1.2247356]
 [-1.2247356  0.         1.2247356]]

举个例子计算 LayerNorm

具体步骤如下:

  1. 计算每行的均值
  • 对每一行,计算其均值。
  • 第1行: mean = (1 + 2 + 3) / 3 = 2
  • 第2行: mean = (4 + 5 + 6) / 3 = 5
  • 第3行: mean = (7 + 8 + 9) / 3 = 8
  1. 计算每行的标准差
  • 对每一行,计算其标准差。
  • 第1行: s t d = s q r t ( ( ( 1 − 2 ) 2 + ( 2 − 2 ) 2 + ( 3 − 2 ) 2 ) / 3 ) = s q r t ( ( 1 + 0 + 1 ) / 3 ) = s q r t ( 2 / 3 ) ≈ 0.8165 std = sqrt(((1-2)^2 + (2-2)^2 + (3-2)^2) / 3) = sqrt((1 + 0 + 1) / 3) = sqrt(2 / 3) ≈ 0.8165 std=sqrt(((12)2+(22)2+(32)2)/3)=sqrt((1+0+1)/3)=sqrt(2/3)0.8165
  • 第2行: s t d = s q r t ( ( ( 4 − 5 ) 2 + ( 5 − 5 ) 2 + ( 6 − 5 ) 2 ) / 3 ) = s q r t ( ( 1 + 0 + 1 ) / 3 ) = s q r t ( 2 / 3 ) ≈ 0.8165 std = sqrt(((4-5)^2 + (5-5)^2 + (6-5)^2) / 3) = sqrt((1 + 0 + 1) / 3) = sqrt(2 / 3) ≈ 0.8165 std=sqrt(((45)2+(55)2+(65)2)/3)=sqrt((1+0+1)/3)=sqrt(2/3)0.8165
  • 第3行: s t d = s q r t ( ( ( 7 − 8 ) 2 + ( 8 − 8 ) 2 + ( 9 − 8 ) 2 ) / 3 ) = s q r t ( ( 1 + 0 + 1 ) / 3 ) = s q r t ( 2 / 3 ) ≈ 0.8165 std = sqrt(((7-8)^2 + (8-8)^2 + (9-8)^2) / 3) = sqrt((1 + 0 + 1) / 3) = sqrt(2 / 3) ≈ 0.8165 std=sqrt(((78)2+(88)2+(98)2)/3)=sqrt((1+0+1)/3)=sqrt(2/3)0.8165
  1. 标准化每一行
  • 对每一行,使用均值和标准差进行标准化。公式为: ( x − m e a n ) / ( s t d + e p s ) (x - mean) / (std + eps) (xmean)/(std+eps)。其中 eps 是一个小常数,防止除零,通常取值为 1e-5。
  • 计算结果如下:

标准化公式: n o r m a l i z e d = ( x − m e a n ) / ( s t d + e p s ) normalized = (x - mean) / (std + eps) normalized=(xmean)/(std+eps)

第1行: 
[(1-2)/(0.8165+1e-5), (2-2)/(0.8165+1e-5), (3-2)/(0.8165+1e-5)]
= [-1.2247, 0, 1.2247]

第2行: 
[(4-5)/(0.8165+1e-5), (5-5)/(0.8165+1e-5), (6-5)/(0.8165+1e-5)]
= [-1.2247, 0, 1.2247]

第3行: 
[(7-8)/(0.8165+1e-5), (8-8)/(0.8165+1e-5), (9-8)/(0.8165+1e-5)]
= [-1.2247, 0, 1.2247]

最终标准化结果矩阵为:

[[-1.2247, 0, 1.2247]
 [-1.2247, 0, 1.2247]
 [-1.2247, 0, 1.2247]]

RMSNorm 的整个计算过程

Meta Llama 3 使用了RMSNorm
假设我们有以下 2D 输入张量 X X X(为了简单起见,我们假设这个张量有 2 行 3 列):
[ 1 2 3 4 5 6 ] \begin{bmatrix}1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix} [142536]
RMSNorm 的计算过程如下:

  1. 计算每行的均方根 (RMS)
    首先,对于每一行,我们计算该行元素的平方和的均值,然后取其平方根。
    对于第 1 行:
    RMS row1 = 1 2 + 2 2 + 3 2 3 = 1 + 4 + 9 3 = 4.67 ≈ 2.16 \text{RMS}_{\text{row1}} = \sqrt{\frac{1^2 + 2^2 + 3^2}{3}} = \sqrt{\frac{1 + 4 + 9}{3}} = \sqrt{4.67} \approx 2.16 RMSrow1=312+22+32 =31+4+9 =4.67 2.16
    对于第 2 行:
    RMS row2 = 4 2 + 5 2 + 6 2 3 = 16 + 25 + 36 3 = 25.67 ≈ 5.07 \text{RMS}_{\text{row2}} = \sqrt{\frac{4^2 + 5^2 + 6^2}{3}} = \sqrt{\frac{16 + 25 + 36}{3}} = \sqrt{25.67} \approx 5.07 RMSrow2=342+52+62 =316+25+36 =25.67 5.07
  2. 使用均方根对输入进行归一化
    将每行的元素除以该行的 RMS 值。这里的 epsilon 用于防止除以零的问题,我们假设 ϵ = 1 e − 6 \epsilon = 1e-6 ϵ=1e6
    对于第 1 行: Normed row1 = [ 1 2.16 + ϵ 2 2.16 + ϵ 3 2.16 + ϵ ] ≈ [ 0.462 0.925 1.387 ] \text{Normed}_{\text{row1}} = \begin{bmatrix} \frac{1}{2.16 + \epsilon} & \frac{2}{2.16 + \epsilon} & \frac{3}{2.16 + \epsilon} \end{bmatrix} \approx \begin{bmatrix} 0.462 & 0.925 & 1.387 \end{bmatrix} Normedrow1=[2.16+ϵ12.16+ϵ22.16+ϵ3][0.4620.9251.387]
    对于第 2 行: Normed row2 = [ 4 5.07 + ϵ 5 5.07 + ϵ 6 5.07 + ϵ ] ≈ [ 0.789 0.986 1.183 ] \text{Normed}_{\text{row2}} = \begin{bmatrix} \frac{4}{5.07 + \epsilon} & \frac{5}{5.07 + \epsilon} & \frac{6}{5.07 + \epsilon} \end{bmatrix} \approx \begin{bmatrix} 0.789 & 0.986 & 1.183 \end{bmatrix} Normedrow2=[5.07+ϵ45.07+ϵ55.07+ϵ6][0.7890.9861.183]
  3. 应用可学习的缩放参数
    假设权重参数 weight \text{weight} weight 为一个向量 [ 1 , 1 , 1 ] [1, 1, 1] [1,1,1],表示每个元素的缩放因子。对于第 1 行: Output row1 = [ 0.462 ⋅ 1 0.925 ⋅ 1 1.387 ⋅ 1 ] = [ 0.462 0.925 1.387 ] \text{Output}_{\text{row1}} = \begin{bmatrix} 0.462 \cdot 1 & 0.925 \cdot 1 & 1.387 \cdot 1 \end{bmatrix} = \begin{bmatrix} 0.462 & 0.925 & 1.387 \end{bmatrix} Outputrow1=[0.46210.92511.3871]=[0.4620.9251.387]对于第 2 行: Output row2 = [ 0.789 ⋅ 1 0.986 ⋅ 1 1.183 ⋅ 1 ] = [ 0.789 0.986 1.183 ] \text{Output}_{\text{row2}} = \begin{bmatrix} 0.789 \cdot 1 & 0.986 \cdot 1 & 1.183 \cdot 1 \end{bmatrix} = \begin{bmatrix} 0.789 & 0.986 & 1.183 \end{bmatrix} Outputrow2=[0.78910.98611.1831]=[0.7890.9861.183]

实际代码实现

以下是使用 PyTorch 实现上述步骤的代码示例:

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

# 示例数据
data = torch.tensor([[1.0, 2.0, 3.0],
                     [4.0, 5.0, 6.0]])

# 实例化 RMSNorm 层
rms_norm = RMSNorm(dim=data.size(-1))

# 计算归一化后的输出
normalized_data = rms_norm(data)

print("Original Data:\n", data)
print("RMSNorm Normalized Data:\n", normalized_data)

结果

运行上述代码后,我们将得到归一化后的数据:

 tensor([[1., 2., 3.],
        [4., 5., 6.]])
RMSNorm Normalized Data:
 tensor([[0.4629, 0.9258, 1.3887],
        [0.7895, 0.9869, 1.1843]], grad_fn=<MulBackward0>)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1795886.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

webpack-性能优化-提取css

CDN 分发网络 Content Delivery Network 或 Content Distribution Network 的缩写 一般把静态资源或第三方资源放到CDN上。 可以在 output的publicPath配置cdn的地址&#xff0c;打包后所有的脚本的前缀都变为这个cdn地址了,一般不会这样使用 output: {filename: "[name…

Mysql学习(三)——SQL通用语法之DML

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 DML添加数据修改数据删除数据 总结 DML DML用来对数据库中表的数据记录进行增删改操作。 添加数据 -- 给指定字段添加数据 insert into 表名(字段1&#xff0c;字…

20240603每日通信--------springboot使用netty-socketio集成即时通信WebSocket

简单效果图 群聊&#xff0c;私聊&#xff0c;广播都可以支持。 基础概念&#xff1a; springbootnetty-socketioWebSocket POM文件&#xff1a; <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/…

英伟达市值超越苹果;ChatGPT、Perplexity、Claude 同时大崩溃丨 RTE 开发者日报 Vol.220

开发者朋友们大家好&#xff1a; 这里是 「RTE 开发者日报」 &#xff0c;每天和大家一起看新闻、聊八卦。我们的社区编辑团队会整理分享 RTE&#xff08;Real-Time Engagement&#xff09; 领域内「有话题的新闻」、「有态度的观点」、「有意思的数据」、「有思考的文章」、「…

大数据的数据采集

大数据采集是指从各种来源收集大量数据的过程&#xff0c;这些数据通常是结构化或非结构化的&#xff0c;并且可能来自不同的平台、设备或应用程序。大数据采集是大数据分析和处理的第一步&#xff0c;对于企业决策、市场分析、产品改进等方面具有重要意义。以下是大数据采集的…

关于python包导入问题的重思考

将顶层目录直接设置为一个包 像这样&#xff0c;每一个文件从顶层包开始导入 这样可以解决我的问题&#xff0c;但是要注意的时&#xff0c;要避免使用出现上下级出现同名包的情况&#xff0c;比如&#xff1a; AutoServer--AutoServer--__init__.py--__init__.py这种情况下…

MongoDB CRUD操作:地理位置查询

MongoDB CRUD操作&#xff1a;地理位置查询 文章目录 MongoDB CRUD操作&#xff1a;地理位置查询地理空间数据GeoJSON对象传统坐标对通过数组指定&#xff08;首选&#xff09;通过嵌入文档指定 地理空间索引2dsphere2d 地理空间查询地理空间查询运算符地理空间聚合阶段 地理空…

Kaggle——Deep Learning(使用 TensorFlow 和 Keras 为结构化数据构建和训练神经网络)

1.单个神经元 创建一个具有1个线性单元的网络 #线性单元 from tensorflow import keras from tensorflow.keras import layers #创建一个具有1个线性单元的网络 modelkeras.Sequential([layers.Dense(units1,input_shape[3]) ]) 2.深度神经网络 构建序列模型 #构建序列模型 …

【vue3|第6期】如何正确地更新和替换响应式对象reactive

日期&#xff1a;2024年6月5日 作者&#xff1a;Commas 签名&#xff1a;(ง •_•)ง 积跬步以致千里,积小流以成江海…… 注释&#xff1a;如果您觉得有所帮助&#xff0c;帮忙点个赞&#xff0c;也可以关注我&#xff0c;我们一起成长&#xff1b;如果有不对的地方&#xff…

【Linux取经路】守护进程

文章目录 一、前台进程和后台进程二、Linux 的进程间关系三、setsid——将当前进程设置为守护进程四、daemon——设置为守护进程五、结语 一、前台进程和后台进程 Linux 中每一次用户登录都是一个 session&#xff0c;一个 session 中只能有一个前台进程在运行&#xff0c;键盘…

AppInventor2有没有删除后的撤销功能?

问&#xff1a;不小心删除了组件&#xff0c;能撤回吗&#xff1f; 答&#xff1a;界面&#xff08;组件&#xff09;设计界面&#xff0c;没有撤销功能。代码&#xff08;逻辑&#xff09;设计视图&#xff0c;可以使用 CtrlZ 撤销&#xff0c;CtrlY 反撤销。 界面设计没有撤…

搜索与图论:树的重心

搜索与图论&#xff1a;树的重心 题目描述参考代码 题目描述 输入样例 9 1 2 1 7 1 4 2 8 2 5 4 3 3 9 4 6输出样例 4参考代码 #include <cstring> #include <iostream> #include <algorithm>using namespace std;const int N 100010, M N * 2;int n, m…

JavaWeb_SpringBootWeb案例

环境搭建&#xff1a; 开发规范 接口风格-Restful&#xff1a; 统一响应结果-Result&#xff1a; 开发流程&#xff1a; 第一步应该根据需求定义表结构和定义接口文档 注意&#xff1a; 本文代码从上往下一直添加功能&#xff0c;后面的模块下的代码包括前面的模块&#xff0c…

新能源管理系统主要包括哪些方面的功能?

随着全球对可持续发展和环境保护的日益重视&#xff0c;新能源管理系统已成为现代能源领域的核心组成部分。这一系统不仅涉及对新能源的收集、存储和管理&#xff0c;还包括对整个能源网络进行高效、智能的监控和控制。以下是新能源管理系统主要包含的几方面功能&#xff1a; 一…

ESP32 Error creating RestrictedPinnedToCore

随缘记&#xff0c;刚遇到&#xff0c;等以后就可能不想来写笔记了。 目前要使用到音频数据&#xff0c;所以去用ESP-ADF&#xff0c;但在使用例程上出现了这个API有问题&#xff0c;要去打补丁。 但是我打补丁的时候git bash里显示not apply&#xff0c;不能打上。 网上看到…

谷歌账号的注册到使用GitHub

一、浏览器扩展 浏览器扩展谷歌学术 二、注册谷歌邮箱 https://support.google.com/accounts/answer/27441?hlzh-hans 1.打开无痕模式&#xff08;ctrlshiftn&#xff09; 2.输入网址 3.选择个人账号 4.填写信息&#xff08;随便填就行&#xff09; &#xff08;以上步骤有时…

FTP

文章目录 概述主动模式和被动模式的工作过程注意事项 概述 文件传输协议 FTP&#xff08;File Transfer Protocol&#xff09;在 TCP/IP 协议族中属于应用层协议&#xff0c;是文件传输标准。主要功能是向用户提供本地和远程主机之间的文件传输&#xff0c;尤其在进行版本升级…

【YOLOV8】2.目标检测-训练自己的数据集

Yolo8出来一段时间了,包含了目标检测、实例分割、人体姿态预测、旋转目标检测、图像分类等功能,所以想花点时间总结记录一下这几个功能的使用方法和自定义数据集需要注意的一些问题,本篇是第二篇,目标检测功能,自定义数据集的训练。 YOLO(You Only Look Once)是一种流行的…

基于element ui 城市选择之间的级联选择

通过el-select实现城市的级联选择效果如图所示 代码实现 <template><div><el-form :model"ruleForminfo"><el-form-item label"居住地址" required><el-col :span"6"><el-form-item ><el-select v-mode…

tsconfig.json和tsconfig.app.json文件解析(vue3+ts+vite)

tsconfig.json {"files": [],"references": [{"path": "./tsconfig.node.json"},{"path": "./tsconfig.app.json"}] }https://www.typescriptlang.org/tsconfig/#files files: 在这个例子中&#xff0c;files 数…