【LLM系列之PaLM】PaLM: Scaling Language Modeling with Pathways

news2025/2/24 3:03:59
论文题目:《Scaling Instruction-Finetuned Language Models》
论文链接:https://arxiv.org/abs/2204.02311
github链接1:https://github.com/lucidrains/PaLM-pytorch/tree/main;
github链接2:https://github.com/conceptofmind/PaLM
huggingface链接:https://huggingface.co/conceptofmind/palm-1b

1 主要贡献

  • 提出了 Pathways Language Model (PaLM),这是一个 5400 亿参数、密集激活的 Transformer 语言模型。
  • PaLM 使用 Pathways 在 6144 TPU v4 芯片上进行训练,Pathways 是一种新的 ML 系统,可以跨多个 TPU Pod 进行高效训练。
  • 它通过在数百种语言理解和生成基准上实现小样本学习sota结果,证明了scaling的良好效果。

2 PaLM模型

在这里插入图片描述

2.1 模型结构

PaLM 在decoder-only架构中使用标准的 Transformer 模型架构(即每个时间步只能关注其自身和过去的时间步),并进行以下修改:
(1)采用SwiGLU激活函数:用于 MLP 中间激活,因为与标准 ReLU、GELU 或 Swish 激活相比,《GLU Variants Improve Transformer》论文里提到:SwiGLU 已被证明可以显著提高模型效果。

我们回顾下上面提到的激活函数:
ReLU激活函数:
R e L U ( x ) = m a x ( 0 , x ) ReLU(x)=max(0,x) ReLU(x)=max(0,x)

GeLU激活函数:
G e L U ( x ) = x Φ ( x ) = x ∫ − ∞ x 1 2 π e − t 2 2 d t = x ⋅ 1 2 [ 1 + e r f ( x 2 ) ] GeLU(x)=x\Phi(x)=x\int_{-\infty}^{x}\frac{1}{\sqrt{2\pi}}e^{-\frac{t^{2}}{2}}dt=x\cdot \frac{1}{2}[1+erf(\frac{x}{\sqrt{2}})] GeLU(x)=xΦ(x)=xx2π 1e2t2dt=x21[1+erf(2 x)]

其中erf为误差函数。

Swish激活函数:
S w i s h = x ⋅ s i g m o i d ( β x ) Swish=x\cdot sigmoid(\beta x) Swish=xsigmoid(βx)

我们不难发现,激活函数就是对x乘以一些数,以对某些值进行约束。

G L U ( x ) = σ ( W x + b ) ⊗ ( V x + c ) GLU(x)=\sigma (Wx+b) \otimes (Vx+c) GLU(x)=σ(Wx+b)(Vx+c)

三种 GLU 变体如下:

SwiGLU实现如下:

class SwiGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * x

按照下面的方法用于FNN:

(2)提出Parallel Layers:每个 Transformer 结构中的“并行”公式:与 GPT-J-6B 中一样,使用的是标准“序列化”公式。具体来说,标准公式可以写成:

并行公式可以写成:

并行公式使大规模训练速度提高了大约 15%。消融实验显示在 8B 参数量下模型效果下降很小,但在 62B 参数量下没有模型效果下降的现象。

# parallel attention and feedforward with residual
# discovered by Wang et al + EleutherAI from GPT-J fame


class ParallelTransformerBlock(nn.Module):
    def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
        super().__init__()
        self.norm = LayerNorm(dim)

        attn_inner_dim = dim_head * heads
        ff_inner_dim = dim * ff_mult
        self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))

        self.heads = heads
        self.scale = dim_head**-0.5
        self.rotary_emb = RotaryEmbedding(dim_head)

        self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
        self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)

        self.ff_out = nn.Sequential(
            SwiGLU(),
            nn.Linear(ff_inner_dim, dim, bias=False)
        )

        # for caching causal mask and rotary embeddings

        self.register_buffer("mask", None, persistent=False)
        self.register_buffer("pos_emb", None, persistent=False)

    def get_mask(self, n, device):
        if self.mask is not None and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]

        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.register_buffer("mask", mask, persistent=False)
        return mask

    def get_rotary_embedding(self, n, device):
        if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
            return self.pos_emb[:n]

        pos_emb = self.rotary_emb(n, device=device)
        self.register_buffer("pos_emb", pos_emb, persistent=False)
        return pos_emb

    def forward(self, x):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        n, device, h = x.shape[1], x.device, self.heads

        # pre layernorm

        x = self.norm(x)

        # attention queries, keys, values, and feedforward inner

        q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)

        # split heads
        # they use multi-query single-key-value attention, yet another Noam Shazeer paper
        # they found no performance loss past a certain scale, and more efficient decoding obviously
        # https://arxiv.org/abs/1911.02150

        q = rearrange(q, "b n (h d) -> b h n d", h=h)

        # rotary embeddings

        positions = self.get_rotary_embedding(n, device)
        q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))

        # scale

        q = q * self.scale

        # similarity

        sim = einsum("b h i d, b j d -> b h i j", q, k)

        # causal mask

        causal_mask = self.get_mask(n, device)
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention

        attn = sim.softmax(dim=-1)

        # aggregate values

        out = einsum("b h i j, b j d -> b h i d", attn, v)

        # merge heads

        out = rearrange(out, "b h n d -> b n (h d)")
        return self.attn_out(out) + self.ff_out(ff)

(3)Multi-Query Attention:每个头共享键/值的映射,即“key”和“value”被投影到 [1, h],但“query”仍被投影到形状 [k, h],这种操作对模型质量和训练速度没有影响,但在自回归解码时间上有效节省了成本。
(4) 使用RoPE embeddings:使用的不是绝对或相对位置嵌入,而是RoPE,是因为 RoPE 嵌入在长文本上具有更好的性能 ,具体原理可看苏神文章《Transformer升级之路:2、博采众长的旋转式位置编码》

# rotary positional embedding
# https://arxiv.org/abs/2104.09864
class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, max_seq_len, *, device):
        seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = einsum("i , j -> i j", seq, self.inv_freq)
        return torch.cat((freqs, freqs), dim=-1)


def rotate_half(x):
    x = rearrange(x, "... (j d) -> ... j d", j=2)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(pos, t):
    return (t * pos.cos()) + (rotate_half(t) * pos.sin())

(5) 采用Shared Input-Output Embeddings:输入和输出embedding矩阵是共享的,这个我理解类似于word2vec的输入W和输出W’:

上图来自《Using the Output Embedding to Improve Language Models》
(6)不使用偏置项:在dense kernel或layer norm中都没有使用偏差,这种操作提高了大模型的训练稳定性。
(7) 词汇表:使用具有 256k 标记的 SentencePiece 词汇表,选择它来支持训练语料库中的大量语言,而无需过度标记化。

2.2 模型变体

在这里插入图片描述
考虑了三种不同的模型尺度:540B、62B 和 8B 参数。

2.3 训练数据

  • PaLM 预训练数据集包含一个包含 7800 亿个标记的高质量语料库,代表了广泛的自然语言用例。 该数据集是经过过滤的网页、书籍、维基百科、新闻文章、源代码和社交媒体对话的混合体。 该数据集基于用于训练 LaMDA(Thoppilan 等人,2022 年)和 GLaM(Du 等人,2021 年)的数据集。
  • 所有三个模型都只在一个时期的数据上进行训练(所有模型的数据清洗方式都相同)。
  • 除了自然语言数据,预训练数据集还包含 196GB 代码,从 GitHub 上的开源存储库获取,包括 Java、HTML、Javascript、Python、PHP、C#、XML、C++ 和 C。

最终的 PaLM 数据集混合如上表所示

2.4 训练硬件资源


总体来说,该程序包含用于 pod 内前向+反向计算(包括 pod 内梯度减少)的组件 A,用于跨 pod 梯度传输的传输子图,以及用于优化器更新的组件 B(包括本地和远程梯度的求和) ).

Pathways 程序在每个 pod 上执行组件 A,然后将输出梯度传输到另一个 pod,最后在每个 pod 上执行组件 B。

因此,它掩盖了延迟。 此外,它还分摊了管理数据传输的成本。

作者还详细提到了实际设置,例如两个 pod 之间的主机通过 Google 数据中心网络连接。 (感兴趣的请直接阅读论文。)

这块理解为 他们在TPU训练架构,不是单纯的多机多GPU,反正没TPU可以用看了。。

PaLM 代表了 LLM 训练效率向前迈出的重要一步。

2. 英语NLP任务效果

PaLM 模型在与 Du 等人相同的一组 29 个英语基准上进行评估。 (2021) 和布朗等人。 (2020)。

PaLM 540B 在 1-shot 设置的 29 个任务中的 24 个和在 few-shot 设置的 29 个任务中的 28 个任务上优于之前的 SOTA。 有趣的是,PaLM 540B 在一些阅读理解和 NLI 任务的小样本设置中比之前的 SOTA 高出 10 多分。

PaLM 540B 在所有基准测试中都优于类似尺寸的模型(Megatron-Turing NLG 530B)。 这表明预训练数据集、训练策略和训练期间观察到的标记数量在实现这些结果方面也起着重要作用。

3 BIG-Bench 效果

在 58 项任务中,PaLM 的表现明显优于 GPT-3、Gopher 和 Chinchilla,并且 5-shot PaLM 540B 的得分高于要求解决相同任务的人类的平均得分。


5-shot PaLM 540B 在 58 个常见任务中的 44 个上优于之前的 SOTA,每个任务的结果如上所示

4 逻辑推理效果


推理任务是需要多步算术或常识性逻辑推理才能产生正确答案的任务。

PaLM 540B 实现了 58% 的性能,优于 Cobbe 等人之前 55% 的 SOTA。

5 代码生成效果

来自 PaLM-Coder 540B 型号的示例。 (左上)从 OpenAI GSM8K 数学数据集转换而来的 GSM8K-Python 问题。 (左下)将一个简单函数从 C++ 转换为 Python 的 TransCoder 示例。 (右)转换后的 HumanEval 示例。

上面显示了代码任务数据集的一些示例。

PaLM-Coder 是 PaLM,具有 2 个阶段的代码进一步微调。

PaLM-Coder 540B 的性能进一步提高,在 HumanEval 上达到 88.4% pass@100,在 MBPP 上达到 80.8% pass@80。

6 翻译效果

左图:PaLM 优于所有基线,有时非常果断,差异高达 13 BLEU。 右图:将 PaLM 从 62B 缩放到 540B 会导致 BLEU 分数出现几次急剧跳跃,这不符合“幂律”经验法则。

7 其他

还有其他结果,例如:多语言自然语言生成,多语言问答。
此外,还讨论了其他问题,例如:记忆、数据集污染、偏见、伦理问题、未决问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/522703.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

每天一个提高效率的Matlab编程小技巧(1)-dbstop if error

相信在matlab调试程序的时候都遇到过这种情况:运行程序时命令行报错,而且出错的位置在我们自己定义的函数里,比如下面这个例子: 主函数main.m: a[1 2 3]; b[4 5]; csum_squares(a,b); 子函数sum_squares.m function csum_squa…

AI孙燕姿 ?AI东雪莲 !—— 本地部署DDSP-SVC一键包,智能音频切片,本地训练,模型推理,为你喜欢的角色训练AI语音模型小教程

目录 感谢B站UP羽毛布团 演示视频 稻香——东雪莲 虚拟——东雪莲 反方向的钟——东雪莲 晴天龙卷风——东雪莲 DDSP-SVC 3.0 (D3SP) 是什么? 下载资源: 解压整合包 准备数据集 智能音频切片 数据集准备 填写训练设置和超参数 开始训练 推…

这个抓包工具太强了,科来网络分析系统强烈推荐

一直以来抓包工具,都推荐和使用wireshark,简单好用。最近发现一款更强大好用的网络分析工具,科来网络分析系统。且技术交流版是完全免费的,无需注册激活。这里强烈推荐和分享给大家。这可是个网络报文分析和监控神器。有多强大&am…

【CSS系列】第七章 · CSS盒子模型,看这一篇就够了

写在前面 Hello大家好, 我是【麟-小白】,一位软件工程专业的学生,喜好计算机知识。希望大家能够一起学习进步呀!本人是一名在读大学生,专业水平有限,如发现错误或不足之处,请多多指正&#xff0…

Protobuf-net3.2.8中的protogen.exe之使用

目录 protobuf是个好东西 遇到问题 顺便研究一下命令行程序如何调试 protobuf是个好东西 protobuf是一个轻量级的数据格式,相比json,它的数据量为json的1/3,且存储方式为2进制,并进行了压缩,序列化和反序列化更快&…

效率与性能并存——离不开 Visual Studio Code 的前端开发与我

文章目录 📋前言🎯题外话:我与 VSCode 的那些事🎯VSCode 的强大之处🧩VSCode 的诞生🧩VSCode 的一些功能 🎯优与劣(简单小结)📝最后 📋前言 许久…

JVM 原理简介

JVM一直是java知识里面进阶阶段的重要部分,如果希望在java领域研究的更深入,则JVM则是如论如何也避开不了的话题,本系列试图通过简洁易读的方式,讲解JVM必要的知识点。 运行流程 我们都知道java一直宣传的口号是:一次编…

股票K线基础知识1

K线图 K线图是反映价格在某一时间周期内波动情况的图表,它由开盘价、收盘价、最高价、最低价四个要素构成,若当日收盘价高于开盘价,这表明价格处于上涨状态,此时K线图多用红色表示;若当日收盘价低于开盘价&#xff0c…

(转载)从0开始学matlab(第1天)—变量和数组

MATLAB 程序的基本数据单元是数组。一个数组是以行和列组织起来的数据集合,并且拥有一个数组名。数组中的单个数据是可以被访问的,访问的方法是数组名后带一个括号,括号内是这个数据所对应行标和列标。标量在 MATLAB 中也被当作数组来处理——…

JavaScript实现输入文字,指定输出遍数的代码

以下为实现输入文字,指定输出遍数的程序代码和运行截图 目录 前言 一、实现输入文字,指定输出遍数 1.1 运行流程及思想 1.2 代码段 1.3 JavaScript语句代码 1.4 运行截图 前言 1.若有选择,您可以在目录里进行快速查找; 2.…

Prometheus+Alertmanager+webhook-dingtalk实现钉钉告警

文章目录 一、前提准备及规划二、安装及启动2.1 Prometheus安装启动2.2 Node_export安装启动2.3 Alertmanager安装启动2.4 Webhook-dingtalk安装启动 三、配置及测试3.1 Webhook-dingtalk配置钉钉webhook地址3.2 Alertmanager配置钉钉告警3.3 Prometheus集成Alertmanager及告警…

基于Docker的深度学习环境部署以及WSL和linux镜像问题

基于Docker的深度学习环境部署 1. 什么是Docker?2. 深度学习环境的基本要求3. Docker的基本操作3.1 在Windows上安装Docker3.2 在Ubuntu上安装Docker3.3 拉取一个pytorch的镜像3.4 部署自己的项目3.5 导出配置好项目的新镜像 4. 分享新镜像4.1 将镜像导出为tar分享给…

android应用的一种图标隐藏

在Android10之前,应用程序通过调用PackageManager.setComponentEnabledSetting(componentName, PackageManager.COMPONENT_ENABLED_STATE_DISABLED, PackageManager.DONT_KILL_APP)函数来实现图标隐藏。 但是在android10之后,所有具有四大组件和需要申请…

C语言函数

C语言函数 一 函数的分类举例:*比较两个整数的大小**交换两个整数的值*(传地址) 二 参数实参形参 三 练习1.写一个函数判断一个数是不是素数2.写一个函数判断这一年是不是闰年3.写一个函数实现一个整型有序数组的二分查找4.写一个函数&#x…

两种方法教你在postman设置请求里带动态token

问题描述 在使用postman调试接口时,遇到一些需要在请求里加上token的接口,若token出现变化,需要手动修改接口的token值,带来重复的工作量,翻看postman使用手册后,我发现了两种方法可以解决这个问题。 01 …

自动化测试开发年薪30w+?我对自己的职业规划产生了质疑

咱们还是开门见山,今天我们主要讲这几个问题: 1-测试开发都干些啥? 2-为什么那么多公司都要招聘测试开发? 3-测试开发的薪资 一、测试开发是什么? 所谓测试开发,是用更为全面的技术手段来提高测试效率&…

java学习笔记——线程池、Lambda表达式

第一章 等待唤醒机制 1.1 线程间通信 概念:多个线程在处理同一个资源,但是处理的动作(线程的任务)却不相同。 比如:线程A用来生成包子的,线程B用来吃包子的,包子可以理解为同一资源&#xff0…

小米刷机小白教程最新详细版

★本篇为线刷(以修补boot的方式刷入面具) 如果你用的是小米手机,想获取面具root,看这一篇就够了,即使你是小白 必应搜索醉里博客http://202271.xyz?xiaomi 原创不易,谢绝转载,如果本教程有帮…

Linux系统优化

一、系统启动流程 1.centos6 centos6开机启动流程,传送门 2.centos7启动流程 二、系统启动运行级别 2.1 什么是运行级别 运行级别:指操作系统当前正在运行的功能级别; [rootweb01 ~]# ll /usr/lib/systemd/system lrwxrwxrwx. 1 root root…

Linux指令2

目录 一、 more指令二、 less指令(非常重要)三、时间相关的指令四、cal指令五、find指令(非常重要)六、grep命令七、zip和unzip指令八、tar指令(十分重要)打包/解包,不解压它,直接看…