Llama架构及代码详解

news2024/12/28 5:41:52

Llama的框架图如图:
在这里插入图片描述
源码中含有大量分布式训练相关的代码,读起来比较晦涩难懂,所以我们对llama自顶向下进行了解析及复现,我们对其划分成三层,分别是顶层、中层、和底层,如下:

Llama的整体组成

由上图可知,Llama整体是由1个embedding层,n个transformer层,和1个RMSNorm层组成的,所以顶层代码如下:
顶层

class Llama(torch.nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
       self.config = config
        # embedding层
        self.tok_embeddings = torch.nn.Embedding(self.config.vocab_size, self.config.dim)
        # RMSNorm
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        # n层Transformer
        self.layers = torch.nn.ModuleList()
        for i in range(self.config.n_layers):
            self.layers.append(TransformerBlock(config))


    def forward(self, tokens):
        # 进行token的嵌入编码
        h = self.tok_embeddings(tokens)
        # decoder架构需要生成一个mask
        seqlen = h.shape[1]
        mask = torch.full((seqlen, seqlen), float('-inf'), device=tokens.device)
        mask = torch.triu(mask, diagonal=1)
        # 进行n层Transformer
        for i in range(self.config.n_layers):
            h = self.layers[i](h, mask)
        # 进行RMSNorm
        token_embeddings = self.norm(h)
        return token_embeddings

中层
我们首先进行RMSNorm的复现

class RMSNorm(torch.nn.Module):
    def __init__(self, dim, eps):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def _norm(self, tensor):
        return tensor * torch.rsqrt(tensor.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, tensor):
        output = self._norm(tensor)
        return output * self.weight

然后对Transformer进行复现,在Transformer中,Transformer包括两个RMSNorm层,一个多头attention层,一个全连接层。

class TransformerBlock(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # 多头注意力层
        self.attention = Attention(config)
        # Norm层
        self.attention_normal = RMSNorm(config.dim, config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        # 全连接层
        self.ffn = FeedForwad(self.config.dim, self.config.dim * 4)

    def forward(self, embeddings, mask):
        # norm
        h = self.attention_normal(embeddings)
        # attention
        h = self.attention(h, mask)
        # add & norm
        h = self.ffn_norm(h + embeddings)
        # fnn
        f = self.ffn(h)
        # add
        return f + h

底层
在多头attention中,首先需要对token的嵌入进行空间映射,多头拆分,旋转位置编码,分数计算等操作

class Attention(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.n_head = config.n_heads
        self.dim = config.dim // self.n_head

        self.k = torch.nn.Linear(config.dim, config.dim)
        self.q = torch.nn.Linear(config.dim, config.dim)
        self.v = torch.nn.Linear(config.dim, config.dim)

    def forward(self, embeddings, mask):
        bsz, seq_len, dim = embeddings.shape

        k_embeddings = self.k(embeddings)
        q_embeddings = self.q(embeddings)
        v_embeddings = self.v(embeddings)
        n_q_embeddings = q_embeddings.reshape(bsz, -1, self.n_head, self.dim).permute(0, 2, 1, 3)
        n_k_embeddings = k_embeddings.reshape(bsz, -1, self.n_head, self.dim).permute(0, 2, 1, 3)
        n_v_embeddings = v_embeddings.reshape(bsz, -1, self.n_head, self.dim).permute(0, 2, 1, 3)

        rotated_n_q_embeddings = compute_rotated_embedding(n_q_embeddings, self.dim, seq_len, self.config.rope_theta)
        rotated_n_k_embeddings = compute_rotated_embedding(n_k_embeddings, self.dim, seq_len, self.config.rope_theta)

        scores = torch.nn.functional.softmax(mask + rotated_n_q_embeddings @ rotated_n_k_embeddings.transpose(-1, -2)
                               / math.sqrt(self.dim), dim=-1)

        n_embeddings = scores @ n_v_embeddings
        embeddings = n_embeddings.permute(0, 2, 1, 3).reshape(bsz, -1, self.config.dim)

        return embeddings
class FeedForwad(torch.nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.linear1 = torch.nn.Linear(dim, hidden_dim)
        self.linear2 = torch.nn.Linear(dim, hidden_dim)
        self.linear3 = torch.nn.Linear(hidden_dim, dim)

    def forward(self, embeddings):
        gate = torch.nn.functional.silu(self.linear1(embeddings))
        up_proj = self.linear2(embeddings) * gate
        return self.linear3(up_proj)

最后,我们复现旋转位置编码,至此我们捋清了llama的所有结构!

def compute_rotated_embedding(embedding, dim, m, base):
    # 计算所有嵌入位置的旋转角度
    all_theta = compute_all_theta(dim, m, base)
    # 旋转后嵌入位置 = 复数平面上初始位置 * 复数平面上角度坐标
    # 1、将嵌入投影到复数平面
    embedding_real_pair = embedding.reshape(*embedding.shape[:-1], -1, 2)
    embedding_complex_pair = torch.view_as_complex(embedding_real_pair)
    # 2、将旋转角度投影到复数平面
    all_theta = all_theta[: embedding.shape[-2]]
    theta_complex_pair = torch.polar(torch.ones_like(all_theta), all_theta)
    # 3、旋转后嵌入位置 = 复数平面上初始位置 * 复数平面上角度坐标
    rotated_complex_embedding = embedding_complex_pair * theta_complex_pair
    # 4、将复数平面的嵌入投影到实数平面
    rotated_real_embedding = torch.view_as_real(rotated_complex_embedding)
    rotated_real_embedding = rotated_real_embedding.reshape(*embedding.shape[:-1], -1)
    return rotated_real_embedding

def compute_all_theta(dim, m, base):
    theta = 1 / (base ** (torch.arange(0, dim / 2).float() / (dim / 2)))
    m = torch.arange(0, m)
    all_theta = torch.outer(m, theta)
    return all_theta

附录:llama的config参数

@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5
    rope_theta: float = 500000

    max_batch_size: int = 32
    max_seq_len: int = 2048
    use_scaled_rope: bool = True

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2241437.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

stm32在linux环境下的开发与调试

环境安装 注:文末提供一键脚本 下载安装stm32cubeclt 下载地址为:https://www.st.com/en/development-tools/stm32cubeclt.html 选择 linux版本下载安装 安装好后默认在家目录st下 > $ ls ~/st/stm32cubeclt_1.16.0 …

第T7周:Tensorflow实现咖啡豆识别

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 目标 具体实现 (一)环境 语言环境:Python 3.10 编 译 器: PyCharm 框 架: (二)具体步骤 1. 使…

亲测有效:Maven3.8.1使用Tomcat8插件启动项目

我本地maven的settings.xml文件中的配置&#xff1a; <mirror><id>aliyunmaven</id><mirrorOf>central</mirrorOf><name>阿里云公共仓库</name><url>https://maven.aliyun.com/repository/public</url> </mirror>…

LLM - 使用 LLaMA-Factory 微调大模型 Qwen2-VL SFT(LoRA) 图像数据集 教程 (2)

欢迎关注我的CSDN&#xff1a;https://spike.blog.csdn.net/ 本文地址&#xff1a;https://spike.blog.csdn.net/article/details/143725947 免责声明&#xff1a;本文来源于个人知识与公开资料&#xff0c;仅用于学术交流&#xff0c;欢迎讨论&#xff0c;不支持转载。 LLaMA-…

神经网络与Transformer详解

一、模型就是一个数学公式 模型可以描述为:给定一组输入数据,经过一系列数学公式计算后,输出n个概率,分别代表该用户对话属于某分类的概率。 图中 a, b 就是模型的参数,a决定斜率,b决定截距。 二、神经网络的公式结构 举例:MNIST包含了70,000张手写数字的图像,其中…

鲸鱼机器人和乐高机器人的比较

鲸鱼机器人和乐高机器人各有其独特的优势和特点&#xff0c;家长在选择时可以根据孩子的年龄、兴趣、经济能力等因素进行综合考虑&#xff0c;选择最适合孩子的教育机器人产品。 优势 鲸鱼机器人 1&#xff09;价格亲民&#xff1a;鲸鱼机器人的产品价格相对乐高更为亲民&…

Flink Source 详解

Flink Source 详解 原文 flip-27 FLIP-27 介绍了新版本Source 接口定义及架构 相比于SourceFunction&#xff0c;新版本的Source更具灵活性&#xff0c;原因是将“splits数据获取”与真“正数据获取”逻辑进行了分离 重要部件 Source 作为工厂类&#xff0c;会创建以下两…

路漫漫其修远兮,吾将上下而求索---第一次使用github的过程记录和个人感受

文章目录 1.仓库位置2.新建仓库3.配置仓库4.克隆和上传5.推荐文章和我的感受 1.仓库位置 这个仓库的位置就是在我们的这个个人主页的右上角&#xff1b;如果是第一次注册账号的话&#xff0c;这个主页里面肯定是不存在仓库的&#xff0c;需要我们自己手动的进行创建&#xff1…

npm list -g --depth=0(用来列出全局安装的所有 npm 软件包而不显示它们的依赖项)

您提供的命令 npm list -g --depth0 是在 Node Package Manager (npm) 的上下文中使用的&#xff0c;用来列出全局安装的所有 npm 软件包而不显示它们的依赖项。 这是它的运作方式&#xff1a; npm list -g --depth0-g: 指定列表应包括全局安装的软件包。--depth0: 限制树形结…

tdengine学习笔记

官方文档&#xff1a;用 Docker 快速体验 TDengine | TDengine 文档 | 涛思数据 整体架构 TDENGINE是分布式&#xff0c;高可靠&#xff0c;支持水平扩展的架构设计 TDengine分布式架构的逻辑结构图如下 一个完整的 TDengine 系统是运行在一到多个物理节点上的&#xff0c;包含…

K8S单节点部署及集群部署

1.Minikube搭建单节点K8S 前置条件&#xff1a;安装docker&#xff0c;注意版本兼容问题 # 配置docker源 wget https://mirrors.aliyun.com/docker-ce/linux/centos/docker-ce.repo -O /etc/yum.repos.d/docker-ce.repo# 安装docker环境依赖 yum install -y yum-utils device-m…

以往运维岗本人面试真题分享

以下是本人面试运维岗的一些面试经历&#xff0c;在此做个记录分享 目录 TCP/IP三次握手 IPtables IPtables四表五链都是什么&#xff1f; nat端口如何做&#xff1f; 开放本机的80端口该如何做&#xff1f; 如何在单用户模式下引导Centos&#xff1f; nginx轮询模式都有…

STM32 串口输出调试信息

软硬件信息 CubeMX version 6.12.1Keil uVision V5.41.0.0 注意 串口有多种&#xff1a; TTL232485 串口的相关知识&#xff1a; 01-【HAL库】STM32实现串口打印&#xff08;printf方式) &#xff0c; 内含 TTL 和 232 区别。 我把 232 串口连进 STM32 串口助手收到的信息…

Python 三种方式实现自动化任务

在这篇文章中&#xff0c;我们将介绍一些用Python实现机器人过程自动化的包。机器人流程自动化&#xff08;Robotic process automation&#xff0c;简称RPA&#xff09;是指将鼠标点击和键盘按压自动化的过程&#xff0c;即模拟人类用户的操作。RPA用于各种应用程序&#xff0…

Android ART知多少?

Android 虚拟机 ART&#xff08;Android Runtime&#xff09;是 Android 平台上的应用程序运行时环境&#xff0c;用于执行应用程序的字节码。ART 自 Android 5.0&#xff08;Lollipop&#xff09;开始取代了 Dalvik&#xff0c;成为 Android 的默认运行时环境。本文将从以下几…

Vulnhub靶场 Billu_b0x 练习

目录 0x00 准备0x01 主机信息收集0x02 站点信息收集0x03 漏洞查找与利用1. 文件包含2. SQL注入3. 文件上传4. 反弹shell5. 提权&#xff08;思路1&#xff1a;ssh&#xff09;6. 提权&#xff08;思路2&#xff1a;内核&#xff09;7. 补充 0x04 总结 0x00 准备 下载链接&#…

软间隔支持向量机支持向量的情况以及点的各种情况

软间隔支持向量 ​ 这一节我们要回答的问题是&#xff1f;如何判断一个点是软间隔支持向量机中的支持向量&#xff0c;在硬间隔支持向量机中&#xff0c;支持向量只需要满足一个等式&#xff1a; y i ( w T x i b ) − 1 0 y_i(w^Tx_i b) -1 0 yi​(wTxi​b)−10 ​ 在软间…

PCA 原理推导

针对高维数据的降维问题&#xff0c;PCA 的基本思路如下&#xff1a;首先将需要降维的数据的各个变量标准化&#xff08;规范化&#xff09;为均值为 0&#xff0c;方差为 1 的数据集&#xff0c;然后对标准化后的数据进行正交变换&#xff0c;将原来的数据转换为若干个线性无关…

在Ubuntu 24.04 LTS上安装飞桨PaddleX

前面我们介绍了《在Windows用远程桌面访问Ubuntu 24.04.1 LTS》本文接着介绍安装飞桨PaddleX。 PaddleX 3.0 是基于飞桨框架构建的一站式全流程开发工具&#xff0c;它集成了众多开箱即用的预训练模型&#xff0c;可以实现模型从训练到推理的全流程开发&#xff0c;支持国内外多…

Web_前端_HTML入门学习的案例案例1

HTML入门学习的案例 来源: HTML入门学习的案例_给学生讲html内容案例-CSDN博客 案例1&#xff1a;hello.html <html><body><title>html技术</title></body><body>hello</body> </html>&#xff08;但是有乱码&#xff09; …