llama3 结构详解

news2025/2/6 8:45:53

文章目录

  • 1. Llama3 整体结构
  • 2. 模块详解
    • 2.1 模块1: Embeddings
    • 2.2 模块2: RoPE
    • 2.3 模块3: Transformer Block
    • 2.4 模块4: RMSNorm
    • 2.5 模块5: Attention
    • 2.6 模块6: ADD
    • 2.7 模块7: FFN
    • 2.8 模块8: Linear

1. Llama3 整体结构

  llama3 的整体结构还是延续transformer decoder 架构,其整体架构如下图左侧蓝色虚线框中所示。模型结构并不复杂,其主要组件为32个Transformer Block(32 为meta llama3 中的默认值)(见下图红色虚线框中所示)。

在这里插入图片描述

注: 下一节中会参照上图中 红色圆形序号 讲解各模块。

2. 模块详解

2.1 模块1: Embeddings

  llama3 的embedding 使用的是VocabParallelEmbedding这个类进行的向量转换,这个类是meta的fairscale包中的一个类,可以理解为对torch.nn.embedding做了并行化。

2.2 模块2: RoPE

  这部分今天先不写,主要是写不完,公式太多了。。。

2.3 模块3: Transformer Block

  Transformer Block 模块是llama3的核心模块,或者说,llama3为Transformer Block模块堆叠而成。Transformer Block有模块4、5、6、7组成,具体内容见对应模块。

2.4 模块4: RMSNorm

  RSMNorm 是在 layer normalization 基础上优化而来,所以先简单回顾下layer normalization。(详细介绍见《Transformer(二)–论文理解:transformer 结构详解》 2.4节)
  layer normalization 是根据下面的公式对 x x x的分布进行调整。
x = a ∗ x − x ‾ s t d + e p s + b x = a * \frac{x - \overline{x}}{std + eps} + b x=astd+epsxx+b
其中, x ‾ \overline{x} x是均值, s t d std std是标准差, e p s eps eps为一个很小的数,防止分母为零。 a a a b b b为参数, b b b可以为零。
  我们现在来看看RMSNorm做了什么优化呢,其实他对上面的试子 x = a ∗ x − x ‾ s t d + e p s + b x = a * \frac{x - \overline{x}}{std + eps} + b x=astd+epsxx+b进行了简化。RMSNorm的计算公式如下:
a ‾ i = a i R M S ( a ) g i , w h e r e R M S ( a ) = 1 n Σ i = 1 n a i 2 \overline{a}_i=\frac{a_i}{RMS(a)}g_{i}, \quad where \quad RMS(a) = \sqrt{\frac{1}{n}\Sigma^n_{i=1}{a^{2}_{i}}} ai=RMS(a)aigi,whereRMS(a)=n1Σi=1nai2

  从上式可以看出,RMSNorm移除了LayerNorm中的均值项(原式中的 x ‾ \overline{x} x项), s t d std std的计算中,也没有做减去均值的操作( s t d = 1 n Σ i = 1 n ( a i − a ‾ ) std=\sqrt{\frac{1}{n}\Sigma^n_{i=1}({a_i - \overline{a})}} std=n1Σi=1n(aia) )。这种简化在计算效率上有一定提高,且原始论文也说了,在效果上没有明显影响。

下面附上meta llama3中RMSNorm的源码,方便大家理解。

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

2.5 模块5: Attention

  llama3中的attention模块与《Attention is all you need》中使用的attention技术有些许优化。同样是使用Scaled Dot-Product Attention来计算attention score,但分组优化这块没有延续使用MHA(Multi-head Attention)技术,而是使用了GQA(Grouped-Query Attention)分组技术。具体的Scaled Dot-Product Attention 与MHA我之前在《Transformer(二)–论文理解:transformer 结构详解》一文的2.2节中,已经写的非常详细了,所以这里不再展开,只讲解下GQA。
  我们知道,在《Attention is all you need》一文中,作者为了提高计算效率,提出了MHA技术,思想是采用分而治之的策略,把K、Q、V 对应的切分为若干个短向量,然后使用Scaled Dot-Product Attention 计算出attention score后,再把结果拼接起来,从而避免了超大向量乘法的计算消耗,从而提高了计算效率。如下图所示。
在这里插入图片描述

  然而,在MHA中,由于每个head都有独立的键和值,内存和计算成本较高,特别是在处理长序列或大批量数据时。然后就有大牛Noam Shazeer提出了MQA(Multi Query Attention)方法,将原来的h个KV对缩减为1个,所有query只使用一个共享的KV对,这种改造虽然大大减少了显存消耗,但其特征捕捉能力也受到影响。因此又提出了GQA(Grouped-Query Attention ), 将query 进行分组,每组共享一个KV对。下面是GQA原始论文中给出的对比图。
在这里插入图片描述
  说了半天,其实在源码层次来就,就是在计算Scaled Dot-Product Attention之前对query进行个分组,组内共享一套Key和value。下面是meta llama3中的Attention类,方便大家理解。

class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
		.
		.
		.
    

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(
            keys, self.n_rep
        )  # (bs, cache_len + seqlen, n_local_heads, head_dim)
        values = repeat_kv(
            values, self.n_rep
        )  # (bs, cache_len + seqlen, n_local_heads, head_dim)

        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        values = values.transpose(
            1, 2
        )  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        # 以下是Scaled Dot-Product Attention的计算
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)

2.6 模块6: ADD

   此模块做了个类似残差的操作,但与残差不同的是,不是用输入减去输出,而是用输入加上输出。具体操作就是把模块4的输入与模块5的输出做加法运算。

2.7 模块7: FFN

  由3个Linear组成的FeedForward网络,这里的激活函数使用的siLU。siLU的数学公式如下:
s i l u ( x ) = x ∗ σ ( x ) ,    w h e r e   σ ( x )   i s   t h e   l o g i s t i c   s i g m o i d . silu(x)=x*\sigma(x), \ \ where\ \sigma(x)\ is\ the\ logistic\ sigmoid. silu(x)=xσ(x),  where σ(x) is the logistic sigmoid.

函数的激活曲线如下图:
在这里插入图片描述
在里注意下,siLU 还有一个名字叫“swish function”,这个在 pytorch 的官方文档中有说明。

下面给出主要源码。


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        super().__init__()

        self.w1 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )
        .
        .
        .
  

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

2.8 模块8: Linear

  此模块的目的是把模型中 decoder的输出从 d m o d e l d_{model} dmodel维度映射到词表大小的维度。下面是meta llama中的linear层的初始化。

 self.output = ColumnParallelLinear(
            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
        )

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2056503.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【计算机组成原理】二、数据的表示和运算:1.数值与编码(十进制二进制转换、BCD码、ASCII码、汉字编码、奇偶校验码、循环冗余检测CRC、海明码)

二、数据的表示和运算 文章目录 二、数据的表示和运算1.数值与编码1.1数据存储和排列❗1.2十进制转换1.2.1整数1.2.2小数 1.3二进制转换1.3.1 B->O1.3.2 B->H 1.4真值&机器数1.5 BCD码1.6 ASCII码1.7汉字与GBK1.8 UTF1.9检错码1.9.1奇偶校验码1.9.2循环冗余检测CRC1.…

鸿蒙Harmony实战:常用命令交互工具—“hvigorw”

hvigor通过hvigorw工具&#xff0c;实现命令行交互。 命令行使用方式 hvigorw [taskNames...] <options> 常用命令 查询 选项 说明 -h, --help 打印hvigor的命令帮助信息。 -v, --version 打印hvigor版本信息。 编译构建 选项 说明 clean 清理构建产物buil…

启动团队活力:5款互动游戏助力新人快速融入

在加入新团队时&#xff0c;很多人都会感到尴尬和不适应。作为团队的领导者&#xff0c;帮助新成员顺利融入团队是至关重要的。组织一场“破冰游戏”是一个有效的策略&#xff0c;不仅可以活跃团队气氛&#xff0c;还能促进成员之间的交流和理解。这时候&#xff0c;团队的领导…

ReFT: reasoning with reinforced Fine-Tuning

从一个question中看到多种多样的cot&#xff0c;都可以从中学习。 offline self-training 数据的质量是模型自己来定义的。 思考增加或者减少一条数据&#xff0c;对于模型训练的影响。 用influence function来衡量新增一条数据对于模型训练的整体的影响。 高质量的数据能够…

深度学习Day-30:CGAN入门丨生成手势图像丨可控制生成

&#x1f368; 本文为&#xff1a;[&#x1f517;365天深度学习训练营] 中的学习记录博客 &#x1f356; 原作者&#xff1a;[K同学啊 | 接辅导、项目定制] 要求&#xff1a; 结合代码进一步了解CGAN学习如何运用生成好的生成器生成指定图像 一、 基础配置 语言环境&#x…

功能测试与自动化测试详解

&#x1f345; 点击文末小卡片&#xff0c;免费获取软件测试全套资料&#xff0c;资料在手&#xff0c;涨薪更快 什么是自动化测试? 自动化测试是指利用软件测试工具自动实现全部或部分测试&#xff0c;它是软件测试的一个重要组成 部分&#xff0c;能完成许多手工测试无法实…

【C++】————智能指针

作者主页&#xff1a; 作者主页 本篇博客专栏&#xff1a;C 创作时间 &#xff1a;2024年8月20日 一&#xff0c;什么是智能指针 在C中没有垃圾回收机制&#xff0c;必须自己释放分配的内存&#xff0c;否则就会造成内存泄露。解决这个问题最有效的方法是使用智能指针&…

传染病防控宣传小程序的设计

管理员账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;用户管理&#xff0c;防控知识管理&#xff0c;医院信息管理&#xff0c;健康上报管理&#xff0c;医疗捐赠管理&#xff0c;捐赠信息管理&#xff0c;系统管理 微信端账号功能包括&#xff1a;系统首…

力扣面试经典算法150题:买卖股票的最佳时机 II

买卖股票的最佳时机 II 今天的题目是力扣面试经典150题中的数组的中等难度题&#xff1a;买卖股票的最佳时机 II。 题目链接&#xff1a;https://leetcode.cn/problems/best-time-to-buy-and-sell-stock-ii/description/?envTypestudy-plan-v2&envIdtop-interview-150 问…

EfficientFormer 系列算法

1. EfficientFormer V1 模型 论文地址&#xff1a;https://proceedings.neurips.cc/paper_files/paper/2022/file/5452ad8ee6ea6e7dc41db1cbd31ba0b8-Paper-Conference.pdf EfficientFormer V1 基于 ViT 的模型中使用的网络架构和具体的算子&#xff0c;找到端侧低效的原因。然…

深入剖析资产负债率与净资产收益率,掌握财务报表解读技巧

一、概述 财务报表中蕴含了丰富的信息&#xff0c;如果我们在解读时没有清晰的思路&#xff0c;忽略重点&#xff0c;就很容易被庞杂的数据搞得晕头转向。本文将从几个关键指标出发&#xff0c;包括资产负债率的分析、净资产收益率的解读&#xff0c;以及如何计算销售复合增长…

企业高性能web服务器——nginx

一、web基础介绍 Apache 和 Nginx 是当今为互联网提供动力的最流行的Web 服务器。 1.1、apache服务器 1.1.1、Apache prefork 模型 预派生模式&#xff0c;有一个主控制进程&#xff0c;然后生成多个子进程&#xff0c;使用select模型&#xff0c;最大并发1024每个子进程有一…

萌啦数据ozon怎么用,萌啦数据ozon使用教程

在跨境电商的浩瀚蓝海中&#xff0c;Ozon作为俄罗斯及独联体地区领先的电商平台&#xff0c;正吸引着越来越多中国卖家的目光。而“萌啦数据”作为专为跨境电商卖家打造的数据分析工具&#xff0c;其针对Ozon平台的功能更是让众多商家如虎添翼。今天&#xff0c;我们就来详细探…

后悔和父母出游的年轻人,正在计划带宠物旅行

文 | 螳螂观察 作者 | 青月 美编 |赵倩 相比于和父母一起出门远游&#xff0c;现在越来越多的95后“铲屎官”似乎更愿意和自家的宠物们组“旅游搭子”。 这听起来可能有些刺耳&#xff0c;但其实是当下很多年轻人的心声。 “带父母一起去北京玩&#xff0c;本来打算第二天…

【 每日一题 | 计算机网络】定长子网划分

重要知识点讲解 我们首先需要了解一下无分类CIDR的编址格式x.x.x/24&#xff0c;表示有24位的网路号&#xff0c;那么相应的主机号为32-248位子网掩码&#xff08;很重要&#xff09;&#xff0c;用来表示IP地址中标识网络号以及子网号的&#xff0c;也就是说如果要进行子网划…

鸿蒙内核源码分析(中断切换篇) | 系统因中断活力四射

关于中断部分系列篇将用三篇详细说明整个过程. 中断概念篇 中断概念很多&#xff0c;比如中断控制器&#xff0c;中断源&#xff0c;中断向量&#xff0c;中断共享&#xff0c;中断处理程序等等.本篇做一次整理.先了解透概念才好理解中断过程.用海公公打比方说明白中断各个概念…

Windows 环境下 Go 语言使用第三方压缩包 gozstd 的报错处理

该文章主要记录在windows平台用go语言使用gozstd包时&#xff0c;遇到的错误及处理过程&#xff08;踩坑之旅&#xff09;&#xff01; 一、gozstd简介 gozstd是一个针对Zstandard&#xff08;简称Zstd&#xff09;的Go语言包装器&#xff0c;它提供了简单且高效的API&#xf…

金山云Q2调整后EBITDA率提升至3.2% 高质量发展驱动经营质效双增

8月20日&#xff0c;金山云公布了2024年第二季度业绩。 季度内&#xff0c;金山云整体业绩延续向好态势&#xff0c;实现收入规模、盈利能力、经营现金流的联动共赢。财报显示&#xff0c;金山云Q2营收18.9亿元&#xff0c;公有云实现收入12.3亿元&#xff0c;行业云实现收入6…

The Sandbox 新提案: 2024 年亚洲和拉丁美洲区块链活动预算

理事会建议&#xff1a; 积极 &#x1f642; 内容 此提案请求为2024年第四季度&#xff0c;The Sandbox 在东南亚和拉丁美洲的主要区块链活动中的激活分配 94,500 美元的 SAND 倡议预算。&#xff08;具体活动列表见下方活动描述&#xff09; 原因 区域团队希望在这些现场活…

国际校企合作|深信服、常州信息职业技术学院、马来西亚汽车工业大学三方国际化人才培养合作签约仪式圆满成功

2024年8月19日&#xff0c;深信服科技股份有限公司与常州信息职业技术学院、马来西亚汽车工业大学正式签署了具有里程碑意义的国际校企合作协议。此次签约不仅是“教随产出、校企同行”理念的一次成功实践&#xff0c;更是中马两国友谊与合作的象征。 常州信息职业技术学院党委…