【LLM系列之LLaMA】LLaMA: Open and Efficient Foundation Language Models

news2024/11/28 7:38:26
论文题目:《LLaMA: Open and Efficient Foundation Language Models》
论文链接:https://arxiv.org/pdf/2302.13971.pdf
github链接:https://github.com/facebookresearch/llama/tree/main
huggingface链接:https://huggingface.co/decapoda-research/llama-7b-hf

1 模型简介

LLaMA 是 Meta AI 发布的包含 7B、13B、33B 和 65B 四种参数规模的基础语言模型集合,LLaMA-13B 仅以 1/10 规模的参数在多数的 benchmarks 上性能优于 GPT-3(175B),LLaMA-65B 与业内最好的模型 Chinchilla-70B 和 PaLM-540B 比较也具有竞争力。

主要贡献:

  • 开源一系列语言模型,可以与SOTA模型竞争
  • LLaMA-13B比GPT-3的性能更好,但是模型大小却是十分之一
  • LLaMA-65B与Chinchilla-70B和PaLM-540B的实力相当
  • 使用公开数据集即可部分复现最先进的性能(86%左右的效果)

2 研究背景

在给定预算的条件下,最好的模型并不一定是最大的模型,在更多的数据上训练的较小的模型反而会达到更好的性能。Hoffmann工作的目的是决定如何确定数据集和模型大小的规模,但是他忽略了推理的成本。所以在这篇文章中,给定一个目标的性能等级,更推荐的模型不是最快训练的,但是是最快推理的。产生的模型称为LLaMA,参数范围从7B到65B,与现在最好的LLM相当。

LLaMA-13B比GPT-3在大多数benchmarks上性能都要好,但是模型大小缩减到十分之一。Meta团队相信这个模型有助于LLM的使用和研究的大众化,因为可以在单个GPU上运行。在更高的规模量上,65B参数量模型与当前最好的LLM(比如Chinchila或PaLM-540B)相比更具有竞争力。LLaMA的另一个优势是它是使用公开数据集进行训练。

3 训练方法

这项工作的训练方法相似于Brown的工作,并且受到Hoffmann(Chinchilla scaling laws)的启发。模型使用标准优化器进行优化。后面会单独解读下《 Scaling Laws for Neural Language Models》,该文主要建模了模型性能与非embedding参数 N,数据集大小 D 与计算量 C之间的关系。最主要的发现:

  • 性能主要与模型大小相关,而与模型结构弱相关
  • 性能与上面三个因素有比较贴合的power-law关系

从实验来看,模型越大越好,小模型确实达不到大模型大力出奇迹的效果,而模型结构也并没有那么重要(虽然有很多工作是在改进模型结构本身)。结论部分更强调了大模型比大数据更重要

3.1 预训练数据

我们的训练数据集是多个来源的混合,如表 1 所示,涵盖了不同的领域。 在大多数情况下,我们重复使用已用于培训其他 LLM 的数据源,但仅限于使用公开可用且与开源兼容的数据。 这导致以下混合数据及其在训练集中所占的百分比:

3.2 模型结构

整体架构仍然是Transformer的解码器模块,该模块参考论文Attention is all you need。下面是在Transformer架构上的进一步的3个改进。

  • [GPT3] 使用RMSNorm(即Root Mean square Layer Normalization)对输入数据进行标准化,RMSNorm可以参考论文:Root mean square layer normalization。
    a ˉ i = a i RMS ( a ) g i , where  RMS ( a ) = 1 n ∑ i = 1 n a i 2 . \begin{align} \begin{split} & \bar{a}i = \frac{a_i}{\text{RMS}(\mathbf{a})} g_i, \quad \text{where}~~ \text{RMS}(\mathbf{a}) = \sqrt{\frac{1}{n} \sum{i=1}^{n} a_i^2}. \end{split}\nonumber \end{align} aˉi=RMS(a)aigi,where  RMS(a)=n1i=1nai2 .

为了提高训练的稳定性,在每个transformer子层的input处进行正则化,而不是在output处,使用的正则化方法是RMSNorm。
LLaMA源码中实现方式为:

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

  • [PaLM]使用激活函数SwiGLU, 该函数可以参考PALM论文:Glu variants improve transformer。
class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )
        self.w2 = RowParallelLinear(
            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
        )
        self.w3 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
  • [GPTNeo]使用Rotary Embeddings进行位置编码,该编码可以参考论文 Roformer: Enhanced transformer with rotary position embedding。

3.3 优化器

使用了AdamW优化器,并使用cosine learning rate schedule,使得最终学习率等于最大学习率的10%,设置0.1的权重衰减和1.0的梯度裁剪。warmup的step为2000,并根据模型的大小改变学习率和批处理大小(详见表2)。

3.4 高效实现

  • 作者做了一些优化来提高模型的训练速度。首先,使用因果多头注意的有效实现来减少内存使用和运行时间。该实现可在xformers库中获得。

https://github.com/facebookresearch/xformers

  • 为了进一步提高训练效率,通过检查点减少了在向后传递过程中重新计算的激活量。更准确地说,节省了计算成本高的激活,比如线性层的输出。这是通过手动实现transformer层的backward函数来实现的,而不是依赖于PyTorch的autograd。

这个指的是gradient checkpointing,这个策略是用时间(重新计算这些值两次的时间成本)来换空间(提前存储这些值的内存成本)。

  • 此外,还尽可能地覆盖激活的计算和gpu之间通过网络的通信(由于all_reduce操作)。

  • 训练65b参数模型时,代码在2048 A100 GPU和80GB RAM上处理大约380个token/秒/GPU。这意味着在包含1.4T token的数据集上进行训练大约需要21天。

4 实验结果

作者主要对比了在Zero-shot、Few-shot上的结果。

4.1 常识推理(Common Sense Reasoning)


可以观察到13B和GPT-3 175B的结果实际上是非常相近的。

4.2 闭卷问答(Closed-book QA)



LLaMA模型要优于PaLM 540B模型

4.3 阅读理解(Reading Comprehension)


可以看到LLaMA对标到540B的PaLM。

4.4 数学推理(Mathematical reasoning)

4.5 代码生成(Code generation)

4.6 大规模多任务语言理解(Massive Multitask Language Understanding)


可以观察到LLaMA-65B在大多数领域平均落后于Chinchilla70B和PaLM-540B几个百分点。一种可能的解释是,预训练数据中使用了有限数量的书籍和学术论文,即ArXiv, Gutenberg和book3,总计只有177GB,而这些模型在高达2TB的书籍上进行了训练。Gopher、Chinchilla和PaLM使用的大量书籍可能也解释了为什么Gopher在这个基准测试中表现优于GPT-3,而在其他基准测试中却不相上下

4.7 训练过程中的性能演变(Evolution of performance during training)


在训练期间,我们在一些问题回答和常识基准上跟踪了模型的性能,并在图2中报告了它们。在大多数基准测试中,性能稳步提高,并且与模型的训练困惑度相关(见图1)。例外是SIQA和WinoGrande。最值得注意的是,在SIQA上,我们观察到性能上有很多差异,这可能表明这个基准测试不可靠。在WinoGrande上,性能与训练困惑度不相关:LLaMA-33B和LLaMA-65B在训练过程中表现相似。

5 指令调优


指令模型LLaMA-I在MMLU上的结果,并与现有中等规模的指令微调模型,进行了比较。尽管这里使用的指令调优方法很简单,但在MMLU上达到了68.9%。LLaMA-I (65B)在现有中等规模的指令微调模型上的表现优于MMLU,但仍远未达到最先进的水平,在MMLU上的GPT code-davincii-002为77.4。

6 模型代码

https://github.com/facebookresearch/llama/blob/main/llama/model.py

class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask)
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out

作者对每个Transformer子层的输入进行归一化,而不是对输出进行归一化。注意看Transformer中黄色方块(Add & Norm)部分,都是在输出部分的,现在把这个操作调整到前面对输入进行Norm操作。

class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = ParallelEmbedding(
            params.vocab_size, params.dim, init_method=lambda x: x
        )

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = ColumnParallelLinear(
            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
        )

        self.freqs_cis = precompute_freqs_cis(
            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
        )

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int):
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        mask = None
        if seqlen > 1:
            mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)
            mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)
        output = self.output(h[:, -1, :])  # only compute last logits
        return output.float()

7 论文结论

本文中提出了一系列公开发布的语言模型,并实现与最先进的基础模型相竞争的结果。最值得注意的是,LLaMA-13B的性能优于GPT-3,但体积比GPT-3小10倍以上,LLaMA-65B与Chinchilla-70B和PaLM-540B竞争。

与之前的研究不同,论文的研究表明,不使用专有数据集,而只使用公开可用的数据集进行训练,可以达到最先进的性能。作者希望向研究界发布这些模型将加速大型语言模型的发展,并有助于提高它们的鲁棒性,减轻已知的问题,如毒性和偏见。

此外,作者像Chung等人一样观察到,根据指令对这些模型进行微调会产生有希望的结果计划在未来的工作中进一步研究这一点。

最后,作者计划在未来发布在更大的预训练语料库上训练的更大的模型,因为作者在扩展语料时已经看到了性能的不断提高

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/528732.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[离散数学] 函数

文章目录 函数判断函数的条件复合函数复合函数的性质 逆函数 函数 判断函数的条件 dom F A ⇔ \Leftrightarrow ⇔所有x 都有 F&#xff08;x&#xff09;与之对应 有唯一的与其对应 < x , y > ∈ f ∧ < y , z > ∈ f ⇒ y z <x,y>\in f \land <y,z…

【C++】2. 进入面向对象 - 类和对象的初步认识

专栏导读 &#x1f341;作者简介&#xff1a;余悸&#xff0c;在读本科生一枚&#xff0c;致力于 C 方向学习。 &#x1f341;收录于 C专栏&#xff0c;本专栏主要内容为 C 初阶、C 进阶、STL 详解等&#xff0c;持续更新中&#xff01; &#x1f341;相关专栏推荐&#xff1a;…

TikTok新手做什么账号,Tiktok类目怎么选

有很多刚入驻TikTok的小白不知道要选择什么类目才比较容易起量。看完这篇后&#xff0c;相信你们的疑惑就会烟消云散。选择对了类目对以后的产品带货有很大的促进作用&#xff0c;今天我给你们分享6种适合TikTok小店运营的账号类型&#xff0c;以及一些比较推荐的类目。 TikTok…

测试和调试之Python高级篇

测试和调试 在软件开发过程中&#xff0c;测试和调试是非常重要的环节。测试用于验证代码的正确性和可靠性&#xff0c;而调试则是为了找到并解决代码中存在的问题。下面将会详细介绍单元测试、集成测试、断言、测试框架、调试工具和技巧。 单元测试 单元测试是指对软件中的…

linux环境安装使用redis详解

Redis 1. NoSQL的引言 NoSQL ( Not Only SQL )&#xff0c;意即 不仅仅是SQL , 泛指非关系型的数据库。Nosql这个技术门类,早期就有人提出,发展至2009年趋势越发高涨。 2. 为什么是NoSQL 随着互联网网站的兴起&#xff0c;传统的关系数据库在应付动态网站&#xff0c;特别是超大…

OpenPCDet系列 | 5.1 PointPillars算法——PillarVFE特征构建与编码模块

文章目录 PillarVFE模块1. PillarVFE初始化2. PillarVFE数据处理2.1 特征构造2.2 掩码构造2.3 特征编码 OpenPCDet的整个结构图&#xff1a; PillarVFE模块属于VFE结构的其中一种&#xff0c;所以可以在PCDet中的backbone_3d目录下&#xff0c;可以找到vfe目录结构。在OpenPCDe…

【JOSEF约瑟 JL-8GA/12端子排电流继电器 整定范围宽、功耗低】

JL-8GA/12端子排电流继电器名称:端子排电流继电器型号:JL-8GA/12品牌:JOSEF约瑟功率消耗:≤5W触点容量:250V5A额定电压:58,100,110,220V 系列型号&#xff1a; JL-8GA/11端子排电流继电器&#xff1b; JL-8GA/12端子排电流继电器&#xff1b; JL-8GA/13端子排电流继电器&am…

MySQL(1) ---- 数据库介绍与MySQL概述

介绍 1、什么是数据库&#xff1f; 数据库&#xff1a;DateBase&#xff08;DB&#xff09;&#xff0c;是存储和管理数据的仓库。数据库管理系统&#xff1a;DataBase Management System&#xff08;DBMS&#xff09;&#xff0c;操纵和管理数据库的大型软件。SQL&#xff1…

【C语言】手把手教你文件操作

文章目录 一、前言二、文件的打开和关闭1. fopen函数2. fclose函数 三、文件的顺序读写四、文件的随机读写1. fseek函数2. ftell函数3. fwind函数 一、前言 程序运行时&#xff0c;数据存放在内存中&#xff0c;而当程序退出后&#xff0c;数据也就不复存在。 想做到数据持久化…

数据库管理-第七十五期 手把手教你搭19c RAC(20230516)

数据库管理 2023-05-16 第七十五期 手把手教你搭19c RAC1 基础环境2 操作系统配置2.1 /etc/hosts2.2 配置系统挂载2.3 配置本地yum源2.4 操作系统配置2.5 安装预安装RPM包并配置&#xff1a;2.6 创建对应目录2.7 配置时间同步 3 存储挂载3.1 存储环境3.2 存储识别3.3 多路径聚合…

生成一个手绘图为底图的导游图

1 前言 上一篇演示了制作一个简版导游图。简版导游图的优点是制作简单、快速&#xff0c;不需要第三方软件&#xff0c;缺点是略显简陋、不够专业。 本编介绍制作专业导游图的步骤&#xff0c;用手绘图为地图&#xff0c;用图形展现景区信息&#xff0c;能表现出丰富的景区细…

ChatGPT:使用Edge浏览器获取ChatGPT以及如何使用ChatGPT帮你制作PPT

一&#xff1a;前言 ChatGPT&#xff1a;智能AI助你畅聊天地 在现代人日益忙碌的生活中&#xff0c;难免需要一些轻松愉快的聊天来放松身心。而现在&#xff0c;有了 ChatGPT&#xff0c;轻松愉快的聊天变得更加智能、有趣且不受时间、地点限制&#xff01; 什么是 ChatGPT&…

NSSCTF-[深育杯 2021]Press

下载链接&#xff1a;下载 载入IDA&#xff0c;查看内容 首先进入一个函数进行初始化&#xff0c;进入查看 unsigned __int64 sub_4007B6() {int v1; // [rsp8h] [rbp-48h]int i; // [rspCh] [rbp-44h]char src[56]; // [rsp10h] [rbp-40h] BYREFunsigned __int64 v4; // [r…

【可乐荐书】有趣的矩阵:看得懂又好看的线性代数

本栏目将推荐一些经典的、有趣的、有启发性的书籍&#xff0c;这些书籍涵盖了各个领域&#xff0c;包括文学、历史、哲学、科学、技术等等。相信这些书籍不仅可以让你获得知识&#xff0c;还可以让你感受到阅读的乐趣和魅力。 今天给大家推荐的书籍是&#xff1a;《有趣的矩阵…

【简单DP】CF1420 C1

昨天的CF心态又打崩了 好久没写DP了这道题一发过了 但是大家都会qwq 烦死 Problem - C1 - Codeforces 题意&#xff1a; 给定一个序列&#xff0c;让你找出一个子序列 使得 这个最大&#xff0c;a是子序列 思路&#xff1a; 首先子序列&#xff0c;自然就是DP 然后每个…

品牌活动如何策划,更利于传播?(吸引媒体报道)

传媒如春雨&#xff0c;润物细无声&#xff0c;大家好&#xff0c;我是51媒体网胡老师。 之前做媒体的时候&#xff0c;参加过无数的媒体活动&#xff0c;现在做媒体传播也给了许多品牌一些建议&#xff0c;有的活动设计的很有趣&#xff0c;有的活动设计的很巧妙&#xff0c;…

响应式设计 MediaQuery和flex

一、MediaQuery(媒体查询)的概念 为不同尺寸的屏幕设定不同的css样式 示例 二、media常用参数 三、媒体查询代码示例 MediaQuery在浏览器中的显示示例 MediaQuery综合案例 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8…

Go语言笔记:使用ssh包作为客户端与SSH服务器交互

文章目录 目的基础说明使用演示单次通讯连续通讯&#xff08;远程终端&#xff09; 总结 目的 Golang中可以使用 golang.org/x/crypto/ssh 包作为SSH客户端或者SSH服务使用。这篇文章将简单记录下作为客户端使用的一些内容。 Package ssh implements an SSH client and server…

QT自定义控件折线图、趋势图。

这里提供两种实现方式&#xff0c;一直自绘的自定义控件&#xff0c;一直三方SDK&#xff08;qcustomplot&#xff09;。 这里主要介绍自绘的&#xff0c;它的优点是结构简单&#xff0c;代码逻辑好修改&#xff0c;容易定制&#xff0c;缺点是功能相对单一。三方的qcustomplot…

循迹模块(应用于小车)

1.1循迹模块使用 TCRT5000传感器的红外发射二极管不断发射红外线 当发射出的红外线没有被反射回来或被反射回来但强度不够大时&#xff0c; 红外接收管一直处于关断状态&#xff0c;此时模块的输出端为高电平&#xff0c;指示二极管一直处于熄灭状态 被检测物体出现在检测范…