Llama 架构分析

news2024/9/20 15:23:38

从代码角度进行Llama 架构分析

  • Llama 架构分析
    • 前言
    • Llama 架构分析
      • 分词
      • 网络主干
        • DecoderLayer
          • Attention
          • MLP
      • 下游任务
        • 因果推理
        • 文本分类

Llama 架构分析

前言

Meta 开发并公开发布了 Llama系列大型语言模型 (LLM),这是一组经过预训练和微调的生成文本模型,参数规模从 70 亿到 700 亿不等。

在大多数任务中,LLaMA-13B要比GPT-3(175B)的性能要好,LLaMA-65B和组好的模型Chinchilla-70B以及PaLM-540B的实力相当。

Llama 架构分析

分词

分词部分主要做的是利用文本分词器对文本进行分词

在这里插入图片描述

tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
text = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(text, return_tensors="pt")

网络主干

主干网络部分主要是将分词得到的input_ids输入到embedding层中进行文本向量化,得到hidden_states(中间结果),然后输入到layers层中,得到hidden_states(中间结果),用于下游任务。

在这里插入图片描述

self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
DecoderLayer

主干网络的layers层就是由多个DecoderLayer组成的,由num_hidden_layers参数决定,一般我们说的模型量级就取决于这个数量,7b的模型DecoderLayer层的数量是32。

DecoderLayer层中又包含了Attention层和MLP层,主要的一个思想是利用了残差结构。

如下图所示,分为两个部分

第一部分

  • 首先,将hidden_states(文本向量化的结构)进行复制,即残差
  • 归一化
  • 注意力层
  • 残差相加

第二部分

  • 首先将第一部分得到的hidden_states进行复制,即残差
  • 归一化
  • MLP层
  • 残差相加

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

#复制一份
residual = hidden_states
#归一化
hidden_states = self.input_layernorm(hidden_states)

#注意力层
hidden_states, self_attn_weights, present_key_value = self.self_attn(
    hidden_states=hidden_states,
    attention_mask=attention_mask,
    position_ids=position_ids,
    past_key_value=past_key_value,
    output_attentions=output_attentions,
    use_cache=use_cache,
    padding_mask=padding_mask,
)
#加上残差
hidden_states = residual + hidden_states

#复制一份
residual = hidden_states
#归一化
hidden_states = self.post_attention_layernorm(hidden_states)
#mlp
hidden_states = self.mlp(hidden_states)
#加上残差
hidden_states = residual + hidden_states

outputs = (hidden_states,)

if output_attentions:
    outputs += (self_attn_weights,)

if use_cache:
        outputs += (present_key_value,)

return outputs
Attention

进行位置编码,让模型更好的捕捉上下文信息

在这里插入图片描述

#经过线性层
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

#多头注意力形状变换
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]

#计算cos、sin
#计算旋转位置嵌入
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

#计算权重
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

#加上掩码
attn_weights = attn_weights + attention_mask
#计算softmax
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)

attn_output = self.o_proj(attn_output)

MLP

mlp层的主要作用是应用非线性激活函数和线性投影。

  • 首先将attention层得到的结果经过两个线性层得到gate_proj和up_proj
  • gate_proj经过激活函数,再和up_proj相乘
  • 最后经过一个线性层得到最后的结果

在这里插入图片描述

self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

下游任务

因果推理

所谓因果推理,就是回归任务。

在这里插入图片描述

self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
文本分类

即分类任务

在这里插入图片描述

self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1316575.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

AWS向量数据库Amazon OpenSearch Service使用测评

前言 在大模型盛行的当今,选择适宜的数据库显得尤为重要。因为你需要面对海量训练数据,快速的检索至关紧要,以及对于存储的要求也是至关重要的。对于海量的数据查询和存储是需要巨大的算力支持。向量数据库常用在一些图像文本或者视频的生成…

了解 Flutter 3.16 功能更新

作者 / Kevin Chisholm 我们在季度 Flutter 稳定版发布会上带来了 Flutter 3.16,此版本包含诸多更新: Material 3 成为新的默认主题、为 Android 带来 Impeller 的预览版、允许添加适用于 DevTools 的扩展程序等等,以及同步推出 Flutter 休闲游戏工具包重…

php查询数据库,并通过表格展示

第一步:创建数据库 创建一个数据库php-crud 第二步:创建数据库表 在数据库php-crud下创建一个歌曲表song /*Navicat Premium Data TransferSource Server : MariaDBSource Server Type : MariaDBSource Server Version : 100605 (10.6.5-M…

PrimDiffusion:3D 人类生成的体积基元扩散模型NeurIPS 2023

NeurIPS2023 ,这是一种用于 3D 人体生成的体积基元扩散模型,可通过离体拓扑实现明确的姿势、视图和形状控制。 PrimDiffusion 对一组紧凑地代表 3D 人体的基元执行扩散和去噪过程。这种生成建模可以实现明确的姿势、视图和形状控制,并能够在…

linux 开机启动流程

1.打开电源 2.BIOS 有时间和启动方式 3.启动Systemd 其pid为1 4.挂载引导分区 /boot 5.启动各种服务 如rc.local

Ps:形状工具 - 描边选项

在形状工具的工具选项栏或“属性”面板中,单击“设置形状描边类型” Set shape stroke type菜单图标可打开“描边选项” Stroke Options面板。 描边预设 Stroke Type 默认列出了实线、虚线和点线三种类型的描边,单击可应用。 自己创建并存储的描边类型&a…

蓝桥杯专题-真题版含答案-【国庆星期日】【三色棋】【蒙地卡罗法求 PI】【格雷码(Gray Code)】

Unity3D特效百例案例项目实战源码Android-Unity实战问题汇总游戏脚本-辅助自动化Android控件全解手册再战Android系列Scratch编程案例软考全系列Unity3D学习专栏蓝桥系列ChatGPT和AIGC 👉关于作者 专注于Android/Unity和各种游戏开发技巧,以及各种资源分…

selenium-grid4.3.0两种模式记录

selenium-grid4.3.0两种模式记录 本文运行,需要提前配置好Java11以及安装好Chrom、Firefox、Safari其中一个浏览器,如果是Chrom、Firefox需要下载对应版本的驱动,并给 webdriver 配置环境变量,Safari浏览器Mac系统会自带&#xf…

SQL进阶理论篇(八):SQL查询的IO成本

文章目录 简介数据库缓冲池查看缓冲池的大小数据页加载的三种方式通过 last_query_cost 统计 SQL 语句的查询成本总结参考文献 简介 本节将介绍磁盘IO是如何加载数据的,重点介绍一下数据库缓冲池的概念。主要包括: 什么是数据库缓冲池,它在…

CSS学习笔记整理

CSS 即 层叠样式表/CSS样式表/级联样式表,也是标记语言, 用于设置HTML页面中的文本内容(字体、大小、对齐方式等)、图片的外形(宽高、边框样式、边距)以及版面的布局和外观显示样式 目录 准备工作 Chrome调…

关于反射机制的简单理解

1、反射的简单认识 1.1 定义 Java的反射(reflection)机制是在运行状态中,对于任意一个类,都能够知道这个类的所有属性和方法;对于任意一个对象,都能够调用它的任意方法和属性,既然能拿到,那么我…

持续集成交付CICD:Jenkins使用GitLab共享库实现基于Ansible的CD流水线部署前后端应用

目录 一、实验 1.部署Ansible自动化运维工具 2.K8S 节点安装nginx 3.Jenkins使用GitLab共享库实现基于Ansible的CD流水线部署前后端应用 二、问题 1.ansible安装报错 2.ansible远程ping失败 3. Jenkins流水线通过ansible命令直接ping多台机器的网络状态报错 一、实验 …

Photoshop插件3D Map Generator Geo的使用记录1(版本说明、安装卸载使用和高程数据生成3D地形图的准备工作)

3D Map Generator是一款强大的地图创建和定制化工具,具有以下特点和功能: 快速创建3D地图:用户可以通过该工具快速创建出高质量的3D地图,而无需具备专业的GIS或PS技能。支持多种图层类型:3D Map Generator支持多种图层…

pytest之allure测试报告03:allure动态自定义报告

1、测试用例模块中引入allure:import allure 2、yaml文件中定义添加title、story的值: 3、测试用例中读取调用。eg:allure.dynamic.title() 4、运行报告查看:成功动态展示yaml文件中配置的story、title

WPF-UI HandyControl 控件简单实战

文章目录 前言UserControl简单使用新建项目直接新建项目初始化UserControlGeometry:矢量图形额外Icon导入最优解决方案 按钮Button切换按钮ToggleButton默认按钮图片可切换按钮加载按钮切换按钮 单选按钮和复选按钮没有太大特点,就不展开写了总结 DataGrid数据表格G…

用标记接口定义类型

标记接口是不含有任何方法的接口,它的目的是通过将特定接口应用于类来为该类添加类型信息。以下是一个示例: public interface Drawable {// 标记接口,不包含任何方法 }public class Circle implements Drawable {private int radius;public…

过滤器和监听器及应用

Filter及应用 Filter有什么用?一、Filter处理中文乱码二、监听器,统计网站在线人数1.监听器引入2.统计网站在线人数 三、Filter实现权限拦截 Filter有什么用? Filter:过滤器,可以用来过滤网站的数据。 比如处理中文乱码,每次写servlet&…

k8syaml提供的几个有意思的功能,Kubernetes在线工具网站

k8syaml.cn 提供的几个有意思的功能。 一、yaml资源快速生成 之前编写operator的helm的时候就需要自己写deployment、service、configmap这些资源,那么多字段也记不清,都是先找个模版,然后copy改改,再看官方文档,添加…

LeetCode(66)二叉树的最大深度【二叉树】【简单】

目录 1.题目2.答案3.提交结果截图 链接: 二叉树的最大深度 1.题目 给定一个二叉树 root ,返回其最大深度。 二叉树的 最大深度 是指从根节点到最远叶子节点的最长路径上的节点数。 示例 1: 输入:root [3,9,20,null,null,15,7]…

[已解决】uniapp内置插件,editor富文本报错(附quill.min.js、image-resize.min.js文件)

在使用uni-app运行内置插件editor时,无法输入内容,控制台报错 原因:查看官网得知,需动态引入quill.min.js、image-resize.min.js文件 解决方法: 1.下载quill.min.js、image-resize.min.js到项目static/eidtor文件中 链…