LLaMA详细解读

news2025/1/11 11:40:35

LLaMA 是目前为止,效果最好的开源 LLM 之一。精读 LLaMA 的论文及代码,可以很好的了解 LLM 的内部原理。本文对 LLaMA 论文进行了介绍,同时附上了关键部分的代码,并对代码做了注释。

摘要

LLaMA是一个系列模型,模型参数量从7B到65B。在大部分的任务上,LLaMA-13B强于GPT-3(175B)。LLaMA-65B的性能,可以和最好的LM相媲美,如Chinchilla-70B 和 PaLM-540B。

一、引言

一般而言,模型越大,效果越好。然而有文献指出[1],当给定计算量的预算之后,最好的performance,并不是最大的模型,而是在一个小模型上用更多的数据进行训练。针对给定的计算量预算,scaling laws可以计算如何选择数据量的大小和模型的大小。然而这忽略了inference的预算,而这一点在模型推理时非常关键。当给定一个模型performance目标之后,最好的模型不是训练最快的模型,而是推理最快的模型。尽管在这种情况下,训练一个更大的模型成本会更低。

文献[2]中推荐,训练一个 10B 的模型,需要 200B 的 tokens,而本文的实验发现,一个7B的模型,经过 1T tokens 训练之后,performance 仍然在增加。本文的目标在于,通过在超大规模的数据上训练,给出一系列可能最好 performance 的 LLM。

二、预训练数据

2.1 数据集

一共有1.4T的tokens,大部分的训练数据都只用了一次,除了Wikipedia 和 Books 使用了大概2个epochs。

Pre-training data

2.2 tokenizer

使用byte pair encoding (BPE) 算法,使用的是Sentence-Piece的实现。所有数字被拆分为单独的digit,所有未知的UTF-8 字符,回退到字节来进行分解。因此,LLaMA 可以通过byte 的方式,构造出很多不在 vocab 中的字符,从而也具有较好的多语言能力。

三、网络结构改进

使用了基于transformer的架构,并做了如下3点改进:

3.1 Pre-normalization

为了提高训练的稳定性,对每个transformer层的输入进行归一化,而不是输出进行归一化。

同时,使用 RMS Norm 归一化函数。RMS Norm 的全称为 Root Mean Square layer normalization。与 layer Norm 相比,RMS Norm的主要区别在于去掉了减去均值的部分,计算公式为:

RMS Norm 的作者认为这种模式在简化了Layer Norm 的计算,可以在减少约 7%∼64% 的计算时间[3]。

class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        return (self.weight * hidden_states).to(input_dtype)

3.2 SwiGLU

使用SwiGLU替代了ReLU作为激活函数。和PaLM中不同,维度采用而不是 4𝑑 。

SwiGLU 在论文[4] 中提出,相比于其他的激活函数变体,可以取得 log-perplexity 的最优值(和 GEGLU 并列)。

GLU Variants Improve Transformer

SwiGLU 及几种类似变体的计算公式如下:

其中,。代码如下:

class LlamaMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        # config 中 hidden_act = 'silu'
        # 'silu' 和 'swish' 对应的激活函数均为:SiLUActivation 
        # https://github.com/huggingface/transformers/blob/717dadc6f36be9f50abc66adfd918f9b0e6e3502/src/transformers/activations.py#L229
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        # 对应上述公式的 SwiGLU
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

从代码可以看到 LlamaMLP 中一共有 3 个 Linear 层,原因就在于 SwiGLU 激活函数比类似 ReLU 的激活函数,需要多一个 Linear 层进行门控。

3.3 RoPE

RoPE 的核心思想是“通过绝对位置编码的方式实现相对位置编码”,可以说是具备了绝对位置编码的方便性,同时可以表示不同 token 之间的相对位置关系。[5] 不同于原始 Transformers 论文中,将 pos embedding 和 token embedding 进行相加,RoPE 是将位置编码和 query (或者 key) 进行相乘。具体如下:

Rotary Position Embedding

其中,左侧的矩阵 𝑅𝑚 表示位置第 𝑚 个位置的位置编码,右侧的向量 𝑞𝑖 表示对应位置的 query 向量。两者相乘,即可得到增加了位置信息的 query (或者 key)。由于 𝑅𝑚 的稀疏性,上述矩阵乘法可以等价于:

Rotary Position Embedding 的简化实现

其中 ⊗ 是逐位对应相乘,

RoPE的代码实现如下[6]:

# 代码增加了注释,可以看到和原始公式的对应关系。
class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        # 此处 inv_freq 对应公式中的 theta
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        # 此处 freqs 对应公式中的 m * theta, t 对应公式中的 m,表示位置
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        # 此处和原始公式不同,theta_0 和 theta_0 不再相邻
        # 而是分在向量的前半部分和后半部分
        emb = torch.cat((freqs, freqs), dim=-1)
        dtype = torch.get_default_dtype()
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
        # 大部分情况下,直接从这里返回
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    # 此次和原始推导中不同,正负号不是间隔的,而是分前半部分和后半部分。但对于结果没有影响
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    # 对应上图中 RoPE 的简化计算
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

四、高效实现

加速训练:

  • 使用了xformers库。
  • 减少了activation checkpointing 中,重新计算 activation 的计算量。手动实现 transformer 层的反向传递函数,保存了计算成本高的 activations,例如线性层的输出。
  • 通过使用 model parallelism 和 sequence parallelism 来减少显存的使用量。
  • 尽可能地将 activations 的计算和GPU之间的通讯进行并行。

加速效果:

  • 65B的模型,在2048个80G的A100 GPU上,可以达到380 tokens/sec/GPU的速度。训练1.4T tokens需要21天。

五、主要结果与结论

Massive Multitask LanguageUnderstanding

LLaMA-13B 优于 GPT-3,尽管只有1/10大小。 LLaMA-65B 是可以与 Chinchilla-70B 和 PaLM-540B 这种最佳的LLM相竞争的模型。经过微调之后,LLaMA的效果有显著的提升。

未来打算发布在更大的语料上预训练上的更大的模型,因为随着数据和模型的增大,可以看到 performance 的稳定提升。

优化器

LLaMA使用了AdamW优化器进行训练,优化器的超参数为 =0.9, =0.95

(关于AdamW这个大模型训练的优化器,可参考当前训练神经网络最快的方式:AdamW优化算法+超级收敛 | 机器之心[6])

下表为LLaMA不同参数大小模型的具体设置:

表2: LLaMA不同参数大小模型的具体设置

参数维度(dim)head个数layer层数学习率batch sizetoken数量
6.7B409632323.0e−44M1.0T
13.0B512040403.0e−44M1.0T
32.5B665652601.5e−44M1.4T
65.2B819264801.5e−44M1.4T

训练结果

如下图所示,7B、13B、33B和65模型的训练损失均呈下降趋势,且在所有token上训练完后,loss仍没有收敛的趋势。因此,在此时,增加训练的token数量,仍然可以使模型继续学习。

(LLaMA2就是在此结论的基础上,使用了更多的token进行训练)

020f808566e73586ea9239922bce9824.png

高效部署

研究团队做了一些优化来提高模型的训练速度:

  1. 因果多头注意的有效实现:使用因果多头注意的有效实现来减少内存使用和运行时间。该实现可在xformers库中获得,其灵感来自于固定激活值显存优化和FlashAttention。这是通过不存储注意力权重和不计算由于语言建模任务的因果性质而被掩盖的key/query分数来实现的。

  2. 激活重计算:为了进一步提高训练效率,通过检查点减少了在向后传递过程中重新计算的激活量。更准确地说,节省了计算成本高的激活,比如线性层的输出。这是通过手动实现transformer层的backward函数来实现的,而不是依赖于PyTorch的autograd。

  3. 模型并行和序列并行:为了从这种优化中充分受益,需要通过使用模型和序列并行来减少模型的内存使用。此外,还尽可能地重叠激活的计算和gpu之间通过网络的通信。

笔者NOTE:LLM的高效训练是LLM工程实现的基础,对于这部分,各位小伙伴还是需要深入地了解一下各种并行策略、因果多头注意的有效实现、 激活重计算、混合精度训练。

参考

  1. ^Training Compute-Optimal Large Language Models https://arxiv.org/abs/2203.15556
  2. ^Training Compute-Optimal Large Language Models https://arxiv.org/abs/2203.15556
  3. ^Root Mean Square Layer Normalization https://arxiv.org/pdf/1910.07467.pdf
  4. ^GLU Variants Improve Transformer https://arxiv.org/pdf/2002.05202.pdf
  5. ^Transformer升级之路:2、博采众长的旋转式位置编码 Transformer升级之路:2、博采众长的旋转式位置编码 - 科学空间|Scientific Spaces
  6. ^transformers/src/transformers/models/llama/modeling_llama.py at main · huggingface/transformers · GitHub

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1641971.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

u盘格式化后电脑读不出来怎么办?u盘格式化的东西还能恢复吗

随着科技的快速发展,U盘已成为我们日常生活和工作中不可或缺的数据存储工具。然而,有时我们可能会遇到U盘格式化后电脑无法读取的情况,或是误格式化导致重要数据丢失。面对这些问题,我们该如何应对?本文将为您详细解答…

python邮件发送

第一种方式 一:发送的邮件要设置授权码,通过邮箱邮箱授权码去验证,让邮件服务器帮我们去转发邮件到要接收的邮件,代码中的授权码,是需要登录126邮箱(我这里是以126邮件发送的,具体的以自己为准…

概念解析 | 互补学习系统

注1:本文系"概念解析"系列之一,致力于简洁清晰地解释、辨析复杂而专业的概念。本次辨析的概念是:互补学习系统(Complementary Learning Systems) 概念解析:互补学习系统 Paper Summary - “Complementary Learning Systems Theory Updated” | Rylan Schaeffer…

数据库MySQL的基本操作

在Linux里面,我们要对数据库MySQL进行操作时(例如修改MySQL的密码),不是直接在我们的终端上进行操作,而是通过终端连接进入到MySQL里面去,在进行操作,写SQL语句。 而安装C等的开发库sudo命令&a…

Crocoddyl 使用教程(二)

系列文章目录 前言 小车摆杆是另一个经典的控制实例。在这个系统中,一根欠驱动的杆子被固定在一辆一维驱动的小车顶部。游戏的目的是将杆子升到站立位置。 模型如下: https://en.wikipedia.org/wiki/Inverted_pendulum 我们用 表示小车质量、 表示摆杆质…

Visual studio调试技巧

Visual studio调试技巧 bug是什么?Debug和ReleaseDebugRelease 如何调试VS调试快捷键调试过程中查看程序信息查看临时变量的值查看内存信息查看调用堆栈查看汇编信息查看寄存器信息 编译常见错误编译型错误链接型错误运行时错误 bug是什么? bug的英文释…

机器学习笔记-22

终章 至此吴恩达老师的机器学习课程已经完成啦,总结一下: 1.监督学习的算法:线性回归、逻辑回归、神经网络和向量机 2.无监督学习的算法:K-Means、PCA、异常检测 3.推荐系统、大规模数据处理、正则化、如何评估算法 4.上限分析、…

Servlet_JSP

1.一些回顾 对于Tomcat部署中 我们有一些补充的点需要在此说明一下 1.如果我们想要查询MINEType的话 可以到TOMCAT_HOME/conf/web.xml中进行查询 里面记录了不同类型对应的MINEType 2.我们客户端发送请求数据给服务器之后 服务器会调用父类中的service方法 然后在内部决定调用…

用Jenkins Gerrit-Trigger插件实现提交gerrit后自动启动编译验证-解决编译依赖问题

用Jenkins Gerrit-Trigger插件实现提交gerrit后自动启动编译验证-CSDN博客讨论了如何利用插件在提交gerrit的时候自动出发一个jenkins job编译固件,但是没有解决编译依赖问题。本文提出一种解决方案 首先在git commit -m ""的时候在commit message中设置Depend-On:…

ControlNet官方资源链接【ControlNet论文原文】【持续更新中~】

ControlNet官方资源链接 ControlNet论文原文:https://arxiv.org/abs/2302.05543ControlNet官方GitHub:https://github.com/lllyasviel/ControlNetControlNet 1.1官方GitHub:https://github.com/lllyasviel/ControlNet-v1-1-nightlyControlNe…

深度学习之基于Vgg16卷积神经网络印度交警手势识别系统

欢迎大家点赞、收藏、关注、评论啦 ,由于篇幅有限,只展示了部分核心代码。 文章目录 一项目简介 二、功能三、系统四. 总结 一项目简介 一、项目背景 随着智能交通系统的不断发展,手势识别技术在其中扮演着越来越重要的角色。特别是在印度等…

CVE-2017-11882分析和白象样本分析

CVE-2017-11882分析和白象样本分析 CVE-2017-11882是微软公布的一个远程代码执行漏洞,漏洞是由模块EQNEDT32.EXE公式编辑器引起,该模块在Office的安装过程中被默认安装,该模块以OLE技术(Object Linking and Embedding&#xff0c…

《网络安全---frida应用实践---某付费视频应用一举拿下》

文章目录 目标应用环境:步骤1、查壳2、定位付费界面布局3、找到可疑方法4、那就看下请求信息吧,看下有没有思路5、其他请求(列表,视频信息,获取播放url)6、请求参数加密算法7、图片信息解密8、数据请求关键点9、以上都是废话10、直接找关键hook点总结相关源码1、文章仅供…

2.初探MPI——点对点通信(阻塞)

系列文章目录 初探MPI——MPI简介初探MPI——(阻塞)点对点通信初探MPI——(非阻塞)点对点通信初探MPI——集体通信 文章目录 系列文章目录前言一、Sending & Receiving message1.1 简介1.2 发送消息1.3 接收消息1.4 MPI 发送…

AI智能名片商城小程序构建企业级私域的IMC模型:IP、MarTech与Content的深度融合

在数字化营销的新时代,为企业定制开发的AI智能名片B2B2C商城小程序,结合我们丰富的私域运营实践,我们深刻领悟到构建企业级私域的三大核心要素:IP(企业人设)、MarTech(营销技术)和Co…

【自动化测试】使用MeterSphere进行接口测试

一、接口介绍二、接口测试的过程三、接口自动化测试执行自动化流程 四、接口之间的协议HTTP协议 五、 接口测试用例设计接口文档 六、使用MeterSphere创建接口测试创建接口定义设计接口测试用例 一、接口介绍 自动化测试按对象分为:单元测试、接口测试、UI测试等。…

C语言/数据结构——每日一题(移除链表元素)

一.前言 今天在leetcode刷到了一道关于单链表的题。想着和大家分享一下。废话不多说,让我们开始今天的知识分享吧。 二.正文 1.1题目要求 1.2思路剖析 我们可以创建一个新的单链表,然后通过对原单链表的遍历,将数据不等于val的节点移到新…

【Java从入门到精通】Java 流(Stream)、文件(File)和IO

Java.io 包几乎包含了所有操作输入、输出需要的类。所有这些流类代表了输入源和输出目标。 Java.io 包中的流支持很多种格式,比如:基本类型、对象、本地化字符集等等。 一个流可以理解为一个数据的序列。输入流表示从一个源读取数据,输出流…

获取淘宝商品销量数据接口

淘宝爬虫商品销量数据采集通常涉及以下几个步骤: 1、确定采集目标:需要明确要采集的商品类别、筛选条件(如天猫、价格区间)、销量和金额等数据。例如,如果您想了解“小鱼零食”的销量和金额,您需要设定好价…

扫雷实现详解【递归展开+首次必展开+标记雷+取消标记雷】

扫雷 一.扫雷设计思路二.扫雷代码逐步实现1.创建游戏菜单2.初始化棋盘3.打印棋盘4.随机布置雷5.统计周围雷的个数6.递归展开棋盘7.标记雷8.删除雷的标记9.保证第一次排雷的安全性棋盘必定展开10.排查雷11.判断输赢 三.扫雷总代码四.截图 一.扫雷设计思路 1.创建游戏菜单。  2.…