【llm对话系统】大模型 Llama 源码分析之归一化方法 RMS Norm

news2025/2/3 19:10:53

1. 引言

在深度学习中,归一化 (Normalization) 是一种常用的技术,它可以加速模型的训练并提高模型的性能。常见的归一化方法包括 Batch Normalization (BatchNorm)、Layer Normalization (LayerNorm) 等。Llama 模型采用了一种称为 RMS Norm 的归一化方法,它是一种对 LayerNorm 的简化和改进。

本文将深入 Llama 源码,分析 RMS Norm 的实现逻辑,并探讨其相比于其他归一化方法的优势。

2. 归一化方法回顾

2.1 Batch Normalization (BatchNorm)

BatchNorm 对每个 mini-batch 的数据进行归一化,使其均值为 0,方差为 1。它引入了两个可学习的参数:缩放因子 (scale) 和偏移因子 (shift)。

公式:

y = (x - mean(x)) / sqrt(variance(x) + epsilon) * scale + shift

优点:

  • 加速训练。
  • 具有一定的正则化效果。

缺点:

  • 依赖于 batch size,当 batch size 较小时,效果较差。
  • 不适用于 RNN 等序列模型。

2.2 Layer Normalization (LayerNorm)

LayerNorm 对每个样本的特征进行归一化,使其均值为 0,方差为 1。它也引入了两个可学习的参数:缩放因子 (scale) 和偏移因子 (shift)。

公式:

y = (x - mean(x)) / sqrt(variance(x) + epsilon) * scale + shift

优点:

  • 不依赖于 batch size。
  • 适用于 RNN 等序列模型。

缺点:

  • 计算量比 BatchNorm 略大。

3. RMS Norm 原理

RMS Norm (Root Mean Square Normalization) 可以看作是 LayerNorm 的一个特例。它只对输入进行 均方根 (Root Mean Square) 归一化,并保留了可学习的缩放因子,但 去除了偏移因子

公式:

y = x / sqrt(mean(x^2) + epsilon) * scale

其中:

  • x 是输入向量。
  • mean(x^2)x 各元素的平方的平均值。
  • epsilon 是一个很小的常数,用于防止除零错误。
  • scale 是可学习的缩放因子,通常初始化为 1。

与 LayerNorm 的比较:

  • RMS Norm 没有减去均值 (即没有中心化)。
  • RMS Norm 没有偏移因子。

4. Llama 中 RMS Norm 的实现

Llama 源码中 RMS Norm 的实现位于 llama/model.py 文件中,定义在 RMSNorm 类中:

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        """
        初始化 RMSNorm.

        Args:
            dim: 输入的维度
            eps: 用于数值稳定的小常数
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """
        执行 RMS 归一化.

        Args:
            x: 输入张量 (..., dim)

        Returns:
            归一化后的张量 (..., dim)
        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        前向传播.

        Args:
            x: 输入张量 (..., dim)

        Returns:
            归一化并缩放后的张量 (..., dim)
        """
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

代码解释:

  1. __init__ 函数:

    • dim:输入的维度。
    • eps:用于数值稳定的小常数,默认为 1e-6
    • weight:可学习的缩放因子,初始化为全 1 的张量。
  2. _norm 函数:

    • 计算输入 x 的均方根的倒数:torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
      • x.pow(2):计算 x 每个元素的平方。
      • .mean(-1, keepdim=True):沿着最后一个维度计算平均值,并保持维度不变。
      • torch.rsqrt():计算平方根的倒数。
    • x 与均方根的倒数相乘,实现归一化。
  3. forward 函数:

    • 调用 _norm 函数进行归一化。
    • 将归一化后的结果与可学习的 weight 相乘,进行缩放。
    • .type_as(x):将结果转换为与输入 x 相同的类型。

使用示例:

# 假设输入维度为 512
dim = 512
rms_norm = RMSNorm(dim)

# 模拟一个输入张量
x = torch.randn(1, 10, dim)

# 进行 RMS Norm 归一化
y = rms_norm(x)

print(y.shape)  # 输出: torch.Size([1, 10, 512])

5. RMS Norm 的优势

  • 计算效率高:RMS Norm 比 LayerNorm 少了均值计算和偏移操作,计算速度更快。
  • 性能相当:实验表明,RMS Norm 的性能与 LayerNorm 相当,甚至在某些任务上略有提升。
  • 更稳定:RMS Norm 对输入的缩放更加鲁棒,因为它只依赖于输入的平方的平均值,而不依赖于输入的均值。

为什么 RMS Norm 可以去掉偏移因子?

在 Transformer 架构中,通常在 RMS Norm 之后会跟一个线性层 (例如,多头注意力机制中的 Q, K, V 投影)。这个线性层可以学习到偏移的效果。因此,RMS Norm 中的偏移因子就显得多余了。

6. 总结

RMS Norm 是一种高效且有效的归一化方法,它通过对 LayerNorm 进行简化,去除了均值计算和偏移因子,提高了计算效率并保持了良好的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2291404.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

洛谷 P8724 [蓝桥杯 2020 省 AB3] 限高杆

洛谷题目传送门 题目描述 某市有 n 个路口,有 m 段道路连接这些路口,组成了该市的公路系统。其中一段道路两端一定连接两个不同的路口。道路中间不会穿过路口。 由于各种原因,在一部分道路的中间设置了一些限高杆,有限高杆的路…

虚幻UE5手机安卓Android Studio开发设置2025

一、下载Android Studio历史版本 步骤1:虚幻4.27、5.0、5.1、5.2官方要求Andrd Studio 4.0版本; 5.3、5.4、5.5官方要求的版本为Android Studio Flamingo | 2022.2.1 Patch 2 May 24, 2023 虚幻官网查看对应Andrd Studiob下载版本: https:/…

JavaWeb入门-请求响应(Day3)

(一)请求响应概述 请求(HttpServletRequest):获取请求数据 响应(HttpServletResponse):设置响应数据 BS架构:Browser/Server,浏览器/服务器架构模式。客户端只需要浏览器就可访问,应用程序的逻辑和数据都存储在服务端(维护方便,响应速度一般) CS架构:Client/ser…

【Rust】18.2. 可辩驳性:模式是否会无法匹配

喜欢的话别忘了点赞、收藏加关注哦(加关注即可阅读全文),对接下来的教程有兴趣的可以关注专栏。谢谢喵!(・ω・) 18.2.1. 模式的两种形式 模式有两种形式: 可辩驳的(可失败的&…

【SLAM】于AutoDL云上GPU运行GCNv2_SLAM的记录

配置GCNv2_SLAM所需环境并实现AutoDL云端运行项目的全过程记录。 1. 引子 前几天写了一篇在本地虚拟机里面CPU运行GCNv2_SLAM项目的博客:链接,关于GCNv2_SLAM项目相关的介绍请移步此文章,本文不再重复说明。 GCNv2: Efficient Corresponde…

【自然语言处理(NLP)】基于Transformer架构的预训练语言模型:BERT 训练之数据集处理、训练代码实现

文章目录 介绍BERT 训练之数据集处理BERT 原理及模型代码实现数据集处理导包加载数据生成下一句预测任务的数据从段落中获取nsp数据生成遮蔽语言模型任务的数据从token中获取mlm数据将文本转换为预训练数据集创建Dataset加载WikiText-2数据集 BERT 训练代码实现导包加载数据构建…

41【文件名的编码规则】

我们在学习的过程中,写出数据或读取数据时需要考虑编码类型 火山采用:UTF-16 易语言采用:GBK php采用:UTF-8 那么我们写出的文件名应该是何种编码的?比如火山程序向本地写出一个“测试.txt”,理论上这个“测…

使用MATLAB进行雷达数据采集可视化

本文使用轮趣科技N10雷达,需要源码可在后台私信或者资源自取 1. 项目概述 本项目旨在通过 MATLAB 读取 N10 激光雷达 的数据,并进行 实时 3D 点云可视化。数据通过 串口 传输,并经过解析后转换为 三维坐标点,最终使用 pcplayer 进…

沙皮狗为什么禁养?

各位铲屎官们,今天咱们来聊聊一个比较敏感的话题:沙皮狗为什么会被禁养?很多人对沙皮狗情有独钟,但有些地方却明确禁止饲养这种犬种,这背后到底是什么原因呢?别急,今天就来给大家好好揭秘&#…

Dest1ny漏洞库:用友 U8 Cloud ReleaseRepMngAction SQL 注入漏洞(CNVD-2024-33023)

大家好,今天是Dest1ny漏洞库的专题!! 会时不时发送新的漏洞资讯!! 大家多多关注,多多点赞!!! 0x01 产品简介 用友U8 Cloud是用友推出的新一代云ERP,主要聚…

DeepSeek-R1模型1.5b、7b、8b、14b、32b、70b和671b有啥区别?

deepseek-r1的1.5b、7b、8b、14b、32b、70b和671b有啥区别?码笔记mabiji.com分享:1.5B、7B、8B、14B、32B、70B是蒸馏后的小模型,671B是基础大模型,它们的区别主要体现在参数规模、模型容量、性能表现、准确性、训练成本、推理成本…

#define,源文件与头文件,赋值表达式

1.#define 1.1定义 #define 是一个预处理指令,用于定义宏 宏,是预处理阶段(在编译之前)由预处理器处理的代码片段 1.2使用 1.2.1 #define 可以定义常量 #define PI 3.14159 1.2.2 #define 可以定义宏函数 #define SQUARE(x) ((…

5分钟在本地PC上使用VLLM快速启动DeepSeek-R1-Distill-Qwen-32B

5分钟在本地PC上使用VLLM快速启动DeepSeek-R1-Distill-Qwen-32B 前言环境准备所需工具创建虚拟环境安装VLLM及依赖库 模型下载安装Hugging Face CLI下载DeepSeek-R1-Distill-Qwen-32B 模型启动启动命令启动确认 模型验证发送API请求示例输出 注意事项参考链接 前言 VLLM 是一个…

【Numpy核心编程攻略:Python数据处理、分析详解与科学计算】1.13 降维打击:扁平化操作的六种武器

1.13 降维打击:扁平化操作的六种武器 目录 #mermaid-svg-bbLxDryjxBbXe3tu {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-bbLxDryjxBbXe3tu .error-icon{fill:#552222;}#mermaid-svg-bbLxDryjxBbXe3tu…

Oracle Primavera P6 最新版 v24.12 更新 2/2

目录 一. 引言 二. P6 EPPM 更新内容 1. 用户管理改进 2. 更轻松地标准化用户设置 3. 摘要栏标签汇总数据字段 4. 将里程碑和剩余最早开始日期拖到甘特图上 5. 轻松访问审计数据 6. 粘贴数据时排除安全代码 7. 改进了状态更新卡片视图中的筛选功能 8. 直接从活动电子…

AI-on-the-edge-device - 将“旧”设备接入智能世界

人工智能无处不在,从语音到图像识别。虽然大多数 AI 系统都依赖于强大的处理器或云计算,但**边缘计算**通过利用现代处理器的功能,使 AI 更接近最终用户。 本项目演示了使用 **ESP32**(一种低成本、支持 AI 的设备)进行…

Openfga 授权模型搭建

1.根据项目去启动 配置一个 openfga 服务器 先创建一个 config.yaml文件 cd /opt/openFGA/conf touch ./config.yaml 怎么配置? 根据官网来看 openfga/.config-schema.json at main openfga/openfga GitHub 这里讲述详细的每一个配置每一个类型 这些配置有…

C++模板编程——可变参函数模板之折叠表达式

目录 1. 什么是折叠表达式 2. 一元左折 3. 一元右折 4. 二元左折 5. 二元右折 6. 后记 上一节主要讲解了可变参函数模板和参数包展开,这一节主要讲一下折叠表达式。 1. 什么是折叠表达式 折叠表达式是C17中引入的概念,引入折叠表达式的目的是为了…

ArkTS渲染控制

文章目录 if/else:条件渲染ArkUI通过自定义组件的build()函数和@Builder装饰器中的声明式UI描述语句构建相应的UI。在声明式描述语句中开发者除了使用系统组件外,还可以使用渲染控制语句来辅助UI的构建,这些渲染控制语句包括控制组件是否显示的条件渲染语句,基于数组数据快…

UbuntuWindows双系统安装

做系统盘: Ubuntu20.04双系统安装详解(内容详细,一文通关!)_ubuntu 20.04-CSDN博客 ubuntu系统调整大小: 调整指南: 虚拟机中的Ubuntu扩容及重新分区方法_ubuntu重新分配磁盘空间-CSDN博客 …