Llama模型结构解析(源码阅读)

news2025/1/17 1:34:42

目录

  • 1. LlamaModel整体结构流程图
  • 2. LlamaRMSNorm
  • 3. LlamaMLP
  • 4. LlamaRotaryEmbedding

  • 参考资料:
    https://zhuanlan.zhihu.com/p/636784644
    https://spaces.ac.cn/archives/8265 ——《Transformer升级之路:2、博采众长的旋转式位置编码》

前言:本次阅读代码位置,在transformers库底下的modeling_llama.py,具体位置在:transformers/models/llama/modeling_llama.py,如下图所示:在这里插入图片描述

1. LlamaModel整体结构流程图

在这里插入图片描述

2. LlamaRMSNorm

  • 代码如下
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        return (self.weight * hidden_states).to(input_dtype)
  • RMSNorm的公式如下所示:
    x i 1 n ∑ i = 1 n x i 2 + e p s ∗ w e i g h t i \frac{x_i}{\sqrt{\frac{1}{n}\sum\limits_{i=1}^{n}{x_i}^2 + eps}} * weight_i n1i=1nxi2+eps xiweighti

    • 其中,公式与代码的对应关系如下:
      在这里插入图片描述

3. LlamaMLP

  • 代码如下:
class LlamaMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  • 流程图:
    在这里插入图片描述

  • 其中输入为x,输出为y

  • 代码中intermediate_size一般比hidden_size大,我们通过在jupyter notebook中打印Llama-13B的模型,可以看到如下所示:
    在这里插入图片描述

  • 总结:MLP模块就是几个nn.Linear的组合

4. LlamaRotaryEmbedding

  • 代码如下

class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
  • 具体的使用,还调用了另外两个函数,如下所示:
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
    
  • 注意这里的实现跟原始推导有点区别,这里实现的方式如下图所示:
    在这里插入图片描述

  • 原始推导如下图所示:
    在这里插入图片描述
    具体可以查看作者的博客:👉戳我👈

  • 总结:RoPE就是在attention计算时,K跟Q做内积之前,先给各自注入位置信息。

结束。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/944220.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

安科瑞风力发电场集中监控系统解决方案-安科瑞黄安南

作为清洁能源之一,风力发电场近几年装机容量快速增长。8月17日,国家能源局发布1-7月份全国电力工业统计数据。截至7月底,全国累计发电装机容量约27.4亿千瓦,同比增长11.5%。其中,太阳能发电装机容量约4.9亿千瓦&#x…

Oracle数据传输加密方法

服务器端“dbhome_1\NETWORK\ADMIN\”sqlnet.ora文件中添加 SQLNET.ENCRYPTION_SERVER requested SQLNET.ENCRYPTION_TYPES_SERVER (RC4_256) 添加后新的链接即刻生效,服务器无需重新启动。 也可以通过Net manager管理工具添加 各个参数含义如下: 是…

Web开发模式、API接口、restful规范、序列化和反序列化、drf安装和快速使用

一 Web开发模式 1. 前后端混合开发模式 前后端混合开发模式是一种开发方式,将前端和后端的开发工作结合在一起,以加快项目的开发速度和 提高协作效率。这种模式通常用于快速原型开发、小型项目或敏捷开发中。在前后端混合开发模式中,前端和…

【MyBatis】自定义resultMap三种映射关系

目录 一、一对一映射(One-to-One) 1.1 表关系 1.2 resultMap设置自定义映射 二、一对多映射(One-to-Many) 2.1 创建实体 2.2 级联方式处理映射关系 2.3 定义SQL 2.4 OrderMapper接口 2.5 编写业务逻辑层 2.6 Junit测试…

港联证券:游资爆炒中电环保,还有谁在蹭核污染防治概念?

8月28日,核污染防治概念股持续大涨,建工修复(300958.SZ)、捷强配备(300875.SZ)、东方园林(002310.SZ)、华盛昌(002980.SZ)等涨停。 中小市值的概念股成为游资…

人工智能学习专栏

这个专栏就专门用来记录自己的深度学习的历程吧。从做MCU开始、Soc、Linux系统转行到AI领域,其过程是痛苦的。至少数学这块,那是花了很多时间去从头去学。但是还是有很多不懂的地方。坚持!!!!

03 最长连续序列

最长连续序列 题解 哈希(O(n)) 给定一个未排序的整数数组 nums ,找出数字连续的最长序列(不要求序列元素在原数组中连续)的长度。 请你设计并实现时间复杂度为 O(n) 的算法解决此问题。 题解 哈希(O(n)) class Solution { public:int long…

升级iOS17后iPhone无法连接App Store怎么办?

最近很多用户反馈,升级最新iOS 17系统后打开App Store提示"无法连接",无法正常打开下载APP。 为什么升级后无法连接到App Store?可能是以下问题导致: 1.网络问题导致App Store无法正常打开 2.网络设置问题 3.App Sto…

微信报修系统有什么优势?怎么提升企业维修工作效率与管理水平?

随着智能化时代的到来,企业、事业单位的现代化设备数量和种类不断增加,原本繁琐的报修、填写记录、检修管理等工作得以简化。从发起报修到维修,以及维修之后给予评价的整个过程,通过手机微信报修系统均能看到,既省时又…

算法---二叉树中的最大路径和

题目 二叉树中的 路径 被定义为一条节点序列,序列中每对相邻节点之间都存在一条边。同一个节点在一条路径序列中 至多出现一次 。该路径 至少包含一个 节点,且不一定经过根节点。 路径和 是路径中各节点值的总和。 给你一个二叉树的根节点 root &…

IC698CRE040 GE 实现跨多个UPS设备的随处可见性

IC698CRE040 GE 实现跨多个UPS设备的随处可见性 通过人工智能、digital twin技术、由高级分析支持的人类洞察力以及独立于供应商的工业软件,效率和敏捷性正在发生巨大变化。施耐德电气通过为未来打造的弹性和可持续解决方案实现下一代工业自动化。 EcoStruxure自…

出现ZooKeeper JMX enabled by default这种错误的解决方法

系列文章专栏 学习以来遇到的bug/问题专栏 文章目录 系列文章专栏 前言 一 问题描述 二 解决方法 2.1 可能的原因分析 2.2 小编的问题解决方法 First:检查/etc/profile里面zookeeper的环境变量配置 Second:检查 zookeeper/conf/zoo.cfg里面的d…

拆解即时通讯行销,如何提升讯息开启率达300%?

图片来源:SaleSmartly官网 科技日新月异,今时今日商家均转战网络世界,开设网店售卖产品或服务,不少人都会转用即时通讯(Instant Messaging,简称IM)软件来和客户联络和宣传,因为即时通…

Unity3D 如何在ECS架构下,用Unity引擎进行游戏开发详解

前言 Unity3D是一款强大的游戏引擎,它提供了丰富的功能和工具,可以帮助开发者快速构建高质量的游戏。而Entity Component System(ECS)是Unity3D中一种新的架构模式,它可以提高游戏的性能和可扩展性。本文将详细介绍在…

自动化运维:Ansible脚本之playbook剧本

目录 一、理论 1.playbooks 2.YAML 3.使用ansible批量安装apache服务 4.定义、引用变量 5.指定远程主机sudo切换用户 6.when条件判断 7.迭代 8.Templates 模块 9.tags 模块 10.Roles 模块 二、实验 1.使用ansible批量安装apache服务 2.定义、引用变量…

S型曲线规划

s #include "stdio.h"typedef struct S_CTRL{ #define SSPD_BUF_LEN 100struct{float aMax;float aMin;float vMax;float J;/* 加加速度 */int t[7];int T[7];int tMax;}in;struct{float accBuf[SSPD_BUF_LEN];float decBuf[SSPD_BUF_LEN];long S[7];long V[7];}o…

电商API接口的研发和应用!

API(Application Programming Interface,应用程序编程接口)指的是为不同的软件应用程序提供编程接口的一组协议、规则以及工具的集合,以便它们能够互相交互,实现数据通信和功能调用。API已成为了现代软件开发和商业应用…

Linux系统部署部署excalidraw-cn白板工具

Linux系统部署部署excalidraw-cn白板工具 一、excalidraw-cn介绍二、本地环境介绍2.1 本地环境规划2.2 本次实践介绍2.3 Yarn介绍 三、检查本地环境3.1 检查本地操作系统版本3.2 检查系统内核版本3.3 检查系统是否安装yarn 四、部署Node.js 环境4.1 下载Node.js安装包4.2 解压N…

Docker技术--Docker中的网络问题

1.docker中的网络通信 如果想要弄清楚docker中的网络通信问题,其实需要弄清楚这几个问题就可以:容器与容器之间的通信、容器与外部网络之间的通信、外部网络与容器之间的通信。 -a:容器与容器之间的通信,如下所示: 在默认情况下,docker使用网桥(Bridge模式)与NAT通信。这…

Java监听mysql的binlog 报错解决办法

报错:com.github.shyiko.mysql.binlog.network.AuthenticationException: Client does not support authentication protocol requested by server; consider upgrading MySQL client 解决方案:在mysql中执行以下命令 alter user rootlocalhost identi…