llama 2 改进之 RMSNorm

news2024/12/30 3:05:30

RMSNorm
在这里插入图片描述

论文:https://openreview.net/pdf?id=SygkZ3MTJE
Github:https://github.com/bzhangGo/rmsnorm?tab=readme-ov-file
在这里插入图片描述
论文假设LayerNorm中的重新居中不变性是可有可无的,并提出了均方根层归一化(RMSNorm)。RMSNorm根据均方根(RMS)将一层神经元的总和输入正则化,得到模型重新缩放不变性特性和隐式学习率适应能力

LayerNorm 公式

深度学习当中,没有线性激活函数的预测公式

a i = ∑ j = 1 m w i j x j , y i = f ( a i + b i ) , \begin{aligned}a_i=\sum_{j=1}^mw_{ij}x_j,\quad y_i=f\left(a_i+b_i\right),\end{aligned} ai=j=1mwijxj,yi=f(ai+bi),

通过激活函数后,其中,随着前一层的更新,层的输入分布会发生变化。这可能会对参数梯度的稳定性产生负面影响,延迟模型收敛。为了减少这种转变,LayerNorm 对求和的输入进行归一化,以固定它们的均值和方差,如下所示:

a ˉ i = a i − μ σ g i , y i = f ( a ˉ i + b i ) , \begin{aligned}\bar{a}_i=\frac{a_i-\mu}{\sigma}g_i,\quad y_i=f\left(\bar{a}_i+b_i\right),\end{aligned} aˉi=σaiμgi,yi=f(aˉi+bi),

其中 a ˉ i \bar{a}_i aˉi是向量 a ˉ ∈ R n \bar{a}\in\mathbb{R}^n aˉRn的第 i i i个值,作为 α i \alpha_i αi的归一化替代值用于层激活。 g ∈ R n \mathbf{g}\in\mathbb{R}^n gRn是增益参数,用于重新调整标准化求和输入的大小,一开始设置为 1。 μ \mu μ σ 2 \sigma^2 σ2 分别是根据原始求和输入估计的均值和方差统计量。

μ = 1 n ∑ i = 1 n a i , σ = 1 n ∑ i = 1 n ( a i − μ ) 2 . \begin{aligned}\mu=\frac{1}{n}\sum_{i=1}^na_i,\quad\sigma=\sqrt{\frac{1}{n}\sum_{i=1}^n(a_i-\mu)^2}.\end{aligned} μ=n1i=1nai,σ=n1i=1n(aiμ)2 .

在本文中,假设重新缩放不变性是LayerNorm成功的原因,而不是重新定中心不变性。我们提出了RMSNorm,它只关注重新缩放不变性,并简单地根据均方根(RMS)统计对求和输入进行正则化:
a ˉ i = a i RMS ( a ) g i , where RMS ( a ) = 1 n ∑ i = 1 n a i 2 . \begin{aligned}\bar{a}_i=\frac{a_i}{\text{RMS}(\mathbf{a})}g_i,\quad\text{where RMS}(\mathbf{a})=\sqrt{\frac{1}{n}\sum_{i=1}^na_i^2}.\end{aligned} aˉi=RMS(a)aigi,where RMS(a)=n1i=1nai2 .

python实现

# root mean square layer normalization
def rln(x, s):
    _eps = 1e-5
    output = x / tensor.sqrt((x * x).mean(1)[:,None] + _eps)
    output = s[None, :] * output
    return output

# layer normalization
def ln(x, b, s):
    _eps = 1e-5
    output = (x - x.mean(1)[:,None]) / tensor.sqrt((x.var(1)[:,None] + _eps))
    output = s[None, :] * output + b[None,:]
    return output

使用pytorch来写RMSNorm的函数

import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    def __init__(self, d, p=-1., eps=1e-8, bias=False):
        """
            Root Mean Square Layer Normalization
        :param d: model size
        :param p: partial RMSNorm, valid value [0, 1], default -1.0 (disabled)
        :param eps:  epsilon value, default 1e-8
        :param bias: whether use bias term for RMSNorm, disabled by
            default because RMSNorm doesn't enforce re-centering invariance.
        """
        super(RMSNorm, self).__init__()

        self.eps = eps
        self.d = d
        self.p = p
        self.bias = bias

        self.scale = nn.Parameter(torch.ones(d))
        self.register_parameter("scale", self.scale)

        if self.bias:
            self.offset = nn.Parameter(torch.zeros(d))
            self.register_parameter("offset", self.offset)

    def forward(self, x):
        if self.p < 0. or self.p > 1.:
            norm_x = x.norm(2, dim=-1, keepdim=True)
            d_x = self.d
        else:
            partial_size = int(self.d * self.p)
            partial_x, _ = torch.split(x, [partial_size, self.d - partial_size], dim=-1)

            norm_x = partial_x.norm(2, dim=-1, keepdim=True)
            d_x = partial_size

        rms_x = norm_x * d_x ** (-1. / 2)
        x_normed = x / (rms_x + self.eps)

        if self.bias:
            return self.scale * x_normed + self.offset

        return self.scale * x_normed

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1942092.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

DolphinScheduler安装教程

DolphinScheduler安装教程 前期准备工作 jdk 1.8mysql 5zookeeper 3.4.6hadoop 2.6psmisc yum -y install psmisc 解压安装包 # 将安装包apache-dolphinscheduler-2.0.8-bin.tar.gz放置/opt/download目录下 # 解压缩 tar -zxvf apache-dolphinscheduler-2.0.8-bin.tar.gz -C …

MYSQL 第二次作业

要求&#xff1a; mysgl>create database mydb8 worker; mysq> use mydb8 worker; mysgl>create table t worker( department id int(11)not null comment部门号, worker id int(11)primary key not null comment职工号, worker date date not null comment工作时间,…

使用wireshark第一次捕获数据包

打开wireshark&#xff1a; 点击捕获&#xff0c;选项。 这里我选择以太网&#xff0c;然后点开始&#xff1a; 然后就成这样了&#xff1a; 点击左上角那个红色的按钮&#xff0c;可以暂停捕获&#xff0c;就变成了下面的样子&#xff1a; 这三个框有自己的名字&…

Nginx 怎样处理请求的重试机制?

&#x1f345;关注博主&#x1f397;️ 带你畅游技术世界&#xff0c;不错过每一次成长机会&#xff01; 文章目录 Nginx 怎样处理请求的重试机制&#xff1f;一、为何需要重试机制&#xff1f;二、Nginx 中的重试机制原理三、Nginx 重试机制的配置参数四、Nginx 重试机制的实际…

【MySQL进阶之路 | 高级篇】范式概述与第一范式

1. 范式简介 在关系型数据库中&#xff0c;关于数据表的设计的基本原则&#xff0c;规则就称为范式。可以理解为&#xff0c;一张数据表的设计结果需要满足的某种设计标准的级别。要想设计一个结构合理的关系型数据库&#xff0c;必须满足一定的范式。 范式的英文名是Normal …

MongoDB教程(十八):MongoDB MapReduce

&#x1f49d;&#x1f49d;&#x1f49d;首先&#xff0c;欢迎各位来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里不仅可以有所收获&#xff0c;同时也能感受到一份轻松欢乐的氛围&#xff0c;祝你生活愉快&#xff01; 文章目录 引言一、MapRed…

阿里云ubuntu宝塔面板部署uni-app-flask-websocket前后端项目

1.下载宝塔面板 wget -O install.sh https://download.bt.cn/install/install-ubuntu_6.0.sh && sudo bash install.sh ed8484bec 然后去安全组开放对应的端口 面板账户登录信息 【云服务器】请在安全组放行 29725 端口 进入控制面板后修改默认用户名和密码 2. …

Flask 框架 redirect() url_for()

url_for url_for 函数根据传入的端点名称&#xff08;即路由函数名&#xff09;生成对应的 URL。 1. url_for() url_for 函数根据传入的端点名称&#xff08;即路由函数名&#xff09;生成对应的 URL。 它接受一个或多个参数&#xff0c;其中第一个参数是路由的名称&#x…

antdesgin table 组件下载成excel

文章目录 发现宝藏一、需求二、报错 发现宝藏 前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。【宝藏入口】。 一、需求 原组件如下&#xff0c;需要添加下载功能 import React, { useState } from rea…

学习测试10-3自动化 web自动化

web自动化 chrome驱动下载地址&#xff1a; https://registry.npmmirror.com/binary.html?pathchromedriver/ https://googlechromelabs.github.io/chrome-for-testing/#stable观察Google版本&#xff0c;下相应的驱动 运行代码试试&#xff0c;成功Google就会弹出 from se…

记录|C#+winform创建扁平化风格界面

本项目的C#内容是自己跟做的&#xff0c;自己做的内容已经打包&#xff0c;可以通过自己跟做写的Dashboard界面&#xff0c;C#下的winform模式下载获得&#xff0c;但是需要花费3个积分 目录 前言一、左边设置和步骤界面步骤Step1.Step2.Step3.Step4Step5 二、右边属性和步骤属…

【PyTorch】基于LSTM网络的气温预测模型实现

假设CSV文件名为temperature_data.csv&#xff0c;其前五行和标题如下&#xff1a; 这里&#xff0c;我们只使用Temperature列进行单步预测。以下是整合的代码示例&#xff1a; import pandas as pd import numpy as np import torch import torch.nn as nn import torch.op…

【深度学习】yolov8-seg分割训练,拼接图的分割复原

文章目录 项目背景造数据训练 项目背景 在日常开发中&#xff0c;经常会遇到一些图片是由多个图片拼接来的&#xff0c;如下图就是三个图片横向拼接来的。是否可以利用yolov8-seg模型来识别出这张图片的三张子图区域呢&#xff0c;这是文本要做的事情。 造数据 假设拼接方式有…

Qt+OpenCascade开发笔记(一):occ的windows开发环境搭建(一):OpenCascade介绍、下载和安装过程

若该文为原创文章&#xff0c;转载请注明原文出处 本文章博客地址&#xff1a;https://hpzwl.blog.csdn.net/article/details/140604141 长沙红胖子Qt&#xff08;长沙创微智科&#xff09;博文大全&#xff1a;开发技术集合&#xff08;包含Qt实用技术、树莓派、三维、OpenCV…

OpenStack Yoga版安装笔记(八)glance练习补充2

1、openstack image list数据流回顾 OpenStack Yoga版安装笔记&#xff08;七&#xff09;通过Wireshark抓包、Mermaid绘图&#xff0c;解析了执行openstack image list的数据流&#xff0c;图示如下&#xff1a; 数据流1-4&#xff1a;user admin认证&#xff0c;并获得admin…

ros2--中间件--rmw

rmw robot middleware 什么是中间件 一套位于操作系统之上&#xff0c;引用程序之下的软件。 在ros2中理解就是&#xff1a;中间件就是介于某两个或者多个节点中间的组件 中间件的作用 就是提供多个节点中间通信用的。 教程 ROS2中间件DDS架构 ros2从入门到精通

使用puma部署ruby on rails的记录

之前写过一篇《记录一下我的Ruby On Rails的systemd服务脚本》的记录&#xff0c;现在补上一个比较政治正确的Ruby On Rails的生产环境部署记录。使用Puma部署项目。 创建文件 /usr/lib/systemd/system/puma.service [Unit] DescriptionPuma HTTP Server DocumentationRuby O…

在Linux、Windows和macOS上释放IP地址并重新获取新IP地址的方法

文章目录 LinuxWindowsmacOS 在Linux、Windows和macOS上释放IP地址并重新获取新IP地址的方法各有不同。以下是针对每种操作系统的详细步骤&#xff1a; Linux 使用DHCP客户端&#xff1a;大多数Linux发行版都使用DHCP&#xff08;动态主机配置协议&#xff09;来自动获取IP地址…

RT-Thread全球嵌入式电子设计大赛入选名单发布!

目录 概述 ​1 瑞萨 RA8D1 Vision Board 2 英飞凌 Psoc6-EvaluationKit-062S2 WIFI模块 3 恩智浦 FRDM-MCXN947 4 STM32 星火一号 STM32F407 5 先楫 HPM5300EVK (RISC-V) 6 自带开发板 概述 RT-Thread全球嵌入式电子设计大赛入选名单发布啦&#xff0c;如下名单的小…

数学建模学习(3)——模拟退火算法

一、模拟退火算法解TSP问题 import random import numpy as np from math import e, exp import matplotlib.pyplot as plt# 31个城市的坐标 city_loc [(1304, 2312), (3639, 1315), (4177, 2244), (3712, 1399), (3488, 1535),(3326, 1556), (3238, 1229), (4196, 1004), (4…