Mistral MOE架构全面解析

news2024/11/27 10:38:46

从代码角度理解Mistral架构

  • Mistral架构全面解析
    • 前言
    • Mistral 架构分析
      • 分词
      • 网络主干
        • MixtralDecoderLayer
          • Attention
          • MOE
          • MLP
      • 下游任务
        • 因果推理
        • 文本分类

Mistral架构全面解析

前言

Mixtral-8x7B 大型语言模型 (LLM) 是一种预训练的生成式稀疏专家混合模型。在大多数基准测试中,Mistral-8x7B 的性能优于 Llama 2 70B。

Mixtral 8x7B 是 Mistral AI 全新发布的 MoE 模型,MoE 是 Mixture-of-Experts 的简称,具体的实现就是将 Transformer 中的 FFN 层换成 MoE FFN 层,其他部分保持不变。在训练过程中,Mixtral 8x7B 采用了 8 个专家协同工作,而在推理阶段,则仅需激活其中的 2 个专家。这种设计巧妙地平衡了模型的复杂度和推理成本,即使在拥有庞大模型参数的情况下,也能保证高效的推理性能,使得 MoE 模型在保持强大功能的同时,也具备了更优的实用性和经济性。

  • 在大多数基准测试中表现优于Llama 2 70B
  • 甚至足以击败GPT-3.5上下文窗口为32k
  • 可以处理英语、法语、意大利语、德语和西班牙语
  • 在代码生成方面表现优异

huggingface上给出基本的加载方法

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mixtral-8x7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(model_id)

text = "Hello my name is"
inputs = tokenizer(text, return_tensors="pt")

outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

在这里插入图片描述

Mistral 架构分析

其它结构和llama的一模一样,知道llama结构的话,省流直接看MOE部分。

分词

分词部分主要做的是利用文本分词器对文本进行分词

在这里插入图片描述

tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
text = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(text, return_tensors="pt")

网络主干

主干网络部分主要是将分词得到的input_ids输入到embedding层中进行文本向量化,得到hidden_states(中间结果),然后输入到layers层中,得到hidden_states(中间结果),用于下游任务。

在这里插入图片描述

self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
MixtralDecoderLayer

主干网络的layers层就是由多个MixtralDecoderLayer组成的,由num_hidden_layers参数决定,一般我们说的模型量级就取决于这个数量,7b的模型DecoderLayer层的数量是32。

MixtralDecoderLayer层中又包含了Attention层和MOE层,主要的一个思想是利用了残差结构。

如下图所示,分为两个部分

第一部分

  • 首先,将hidden_states(文本向量化的结构)进行复制,即残差
  • 归一化
  • 注意力层
  • 残差相加

第二部分

  • 首先将第一部分得到的hidden_states进行复制,即残差
  • 归一化
  • MLP层
  • 残差相加

在这里插入图片描述

#复制一份
residual = hidden_states
#归一化
hidden_states = self.input_layernorm(hidden_states)

#注意力层
hidden_states, self_attn_weights, present_key_value = self.self_attn(
    hidden_states=hidden_states,
    attention_mask=attention_mask,
    position_ids=position_ids,
    past_key_value=past_key_value,
    output_attentions=output_attentions,
    use_cache=use_cache,
    padding_mask=padding_mask,
)
#加上残差
hidden_states = residual + hidden_states

#复制一份
residual = hidden_states
#归一化
hidden_states = self.post_attention_layernorm(hidden_states)
#mlp
hidden_states = self.mlp(hidden_states)
#加上残差
hidden_states = residual + hidden_states

outputs = (hidden_states,)

if output_attentions:
    outputs += (self_attn_weights,)

if use_cache:
        outputs += (present_key_value,)

return outputs
Attention

进行位置编码,让模型更好的捕捉上下文信息

在这里插入图片描述

#经过线性层
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

#多头注意力形状变换
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]

#计算cos、sin
#计算旋转位置嵌入
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

#计算权重
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

#加上掩码
attn_weights = attn_weights + attention_mask
#计算softmax
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)

attn_output = self.o_proj(attn_output)

MOE

MOE层,也就是我们的专家模块,简单来说,主要干的就是通过一个线性层,得到8个专家,从这8个专家中选出最专业的2个,把他们的权重相加,输入到MLP层,得到最终的结果。

  • attention层得到的hidden_states经过控制门(nn.Linear)得到8个输出。(有点像多分类)
  • t通过softmax计算8个输出的概率值
  • 从8个中选择概率值最高的两个专家
  • 概率最高的两个专家进行权重相加,并计算相对概率值
  • 这两个专家输入到MLP层中进行一系列计算得到最后结果

在这里插入图片描述

batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
#这里通过一个线性层,得到8个输出(n_experts),也就是所谓的专家。
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
#这里通过softmax计算8个输出的概率值,如(0.2,0.3,0.0833,0.0833,0.0833,0.0833,0.0833,0.0833)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
#从8个中选择概率值最高的两个专家((0.2,0.3)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
#概率最高的两个专家进行权重相加,并计算相对概率值((0.2,0.3)->(0.4,0.6)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(hidden_states.dtype)
#初始化最终结果
final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )
#掩码
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

for expert_idx in range(self.num_experts):
    expert_layer = self.experts[expert_idx]
    #通过掩码找到top2的位置
    idx, top_x = torch.where(expert_mask[expert_idx])

    if top_x.shape[0] == 0:
        continue
        top_x_list = top_x.tolist()
        idx_list = idx.tolist()

        #top2对应的向量
        current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
        #经过mlp
        current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]

        #加到final_hidden_states中
        final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits

MLP

mlp层的主要作用是应用非线性激活函数和线性投影。

  • 首先将attention层得到的结果经过两个线性层得到gate_proj和up_proj
  • gate_proj经过激活函数,再和up_proj相乘
  • 最后经过一个线性层得到最后的结果

在这里插入图片描述

self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

下游任务

因果推理

所谓因果推理,就是回归任务。

在这里插入图片描述

self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
文本分类

即分类任务

在这里插入图片描述

self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1318971.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Android 架构 - MVVM

一、概念 概念基于观察者模式,数据的变化会自动更新到UI。通信 View→ViewModel:View作为观察者,监听ViewModel中数据(LiveData、Flow)的变化从而自动更新UI。 ViewModel→Model:ViewModel调用Model获取数据…

FPGA设计与实战之时钟及时序简介1

文章目录 一、时钟定义二、基本时序三、总结一、时钟定义 我们目前设计的电路以同步时序电路为主,时钟做为电路工作的基准而显得非常重要。 简单的接口电路比如I2C、SPI等,复杂一点接口比如Ethernet的MII、GMII等接口,它们都有一个或多个时钟信号。 那么什么是时钟信号?它…

【华为】文档中命令行约定格式规范(命令行格式规范、命令行行为规范、命令行参数格式、命令行规范)

文章目录 命令行约定格式**粗体&#xff1a;命令行关键字***斜体&#xff1a;命令行参数*[ ]&#xff1a;可选配置{ x | y | ... } 和 [ x | y | ... ]&#xff1a;选项{ x | y | ... }* 和 [ x | y | ... ]*&#xff1a;多选项&<1-n>&#xff1a;重复参数#&#xff…

算法-动态规划

动态规划算法 应用场景-背包问题 介绍 动态规划(Dynamic Programming)算法的核心思想是&#xff1a;将大问题划分为小问题进行解决&#xff0c;从而一步步获取最优解的处理算法动态规划算法与分治算法类似&#xff0c;其基本思想也是将待求解问题分解成若干个子问题&#xff0…

算法:最短路径

文章目录 Dijkstra算法Bellman-Ford算法Floyd-Warshall 本篇总结的是图当中的最短路径算法 Dijkstra算法 单源最短路径问题&#xff1a;给定一个图G ( V &#xff0c; E ) G(V&#xff0c;E)G(V&#xff0c;E)&#xff0c;求源结点s ∈ V s∈Vs∈V到图中每个结点v ∈ V v∈V…

H266/VVC标准的编码结构介绍

概述 CVS&#xff1a; H266的编码码流包含一个或多个编码视频序列&#xff08;Coded Video Swquence&#xff0c;CVS&#xff09;&#xff0c;每个CVS以帧内随机接入点&#xff08;Intra Random Access Point&#xff0c; IRAP&#xff09;或逐渐解码刷新&#xff08;Gradual …

力扣题:数字与字符串间转换-12.18

力扣题-12.18 [力扣刷题攻略] Re&#xff1a;从零开始的力扣刷题生活 力扣题1&#xff1a;38. 外观数列 解题思想&#xff1a;进行遍历然后对字符进行描述即可 class Solution(object):def countAndSay(self, n):""":type n: int:rtype: str""&quo…

小程序静默登录-登录拦截实现方案【全局loginPromis加页面拦截】

实现效果&#xff1a; 用户进入小程序访问所有页面运行onload、onShow、onReady函数时保证业务登录态是有效的 实现难点&#xff1a; 由于小程序的启动流程中&#xff0c;页面级和组件级的生命周期函数都不支持异步阻塞&#xff1b;因此会造成一个情况&#xff0c;app.onLau…

【从零开始学习--设计模式--策略模式】

返回首页 前言 感谢各位同学的关注与支持&#xff0c;我会一直更新此专题&#xff0c;竭尽所能整理出更为详细的内容分享给大家&#xff0c;但碍于时间及精力有限&#xff0c;代码分享较少&#xff0c;后续会把所有代码示例整理到github&#xff0c;敬请期待。 此章节介绍策…

植物分类-PlantsClassification

一、模型配置 一、backbone resnet50 二、neck GlobalAveragePooling 三、head fc 四、loss type‘LabelSmoothLoss’, label_smooth_val0.1, num_classes30, reduction‘mean’, loss_weight1.0 五、optimizer lr0.1, momentum0.9, type‘SGD’, weight_decay0.0001 六、sche…

磁力计LIS2MDL开发(3)----九轴姿态解算

磁力计LIS2MDL开发.3--九轴姿态解算 概述视频教学样品申请完整代码下载使用硬件欧拉角万向节死锁四元数法姿态解算双环PI控制器偏航角陀螺仪解析代码 概述 LIS2MDL 包含三轴磁力计。 lsm6ds3trc包含三轴陀螺仪与三轴加速度计。 姿态有多种数学表示方式&#xff0c;常见的是四元…

【运维笔记】mvware centos挂载共享文件夹

安装mvware-tools 这里用的centos安装 yum install open-vm-tools 设置共享文件夹 依次点击&#xff1a;选项-共享文件夹-总是启用-添加&#xff0c;安装添加向导操作添加自己想共享的文件夹后。成功后即可在文件夹栏看到自己共享的文件夹 挂载文件夹 临时挂载 启动虚拟机&…

lvs-nat部署

LVS负载均衡群集部署——NAT模式 实验环境&#xff1a; 负载调度器&#xff1a;内网关 lvs&#xff0c;ens33&#xff1a;172.16.23.10&#xff1b;外网关&#xff1a;ens36&#xff1a;12.0.0.1 Web服务器1&#xff1a;172.16.23.11 Web服务器2&#xff1a;172.16.23.12 NFS…

【Spring】09 BeanClassLoaderAware 接口

文章目录 1. 简介2. 作用3. 使用3.1 创建并实现接口3.2 配置 Bean 信息3.3 创建启动类3.4 启动 4. 应用场景总结 Spring 框架为开发者提供了丰富的扩展点&#xff0c;其中之一就是 Bean 生命周期中的回调接口。本文将聚焦于其中的一个接口 BeanClassLoaderAware&#xff0c;介…

数据仓库与数据挖掘小结

更加详细的只找得到pdf版本 填空10分 判断并改错10分 计算8分 综合20分 客观题 填空10分 判断并改错10分--错的要改 mooc中的--尤其考试题 名词解释12分 4个&#xff0c;每个3分 经常碰到的专业术语 简答题40分 5个&#xff0c;每道8分 综合 画roc曲线 …

机器视觉技术与应用实战(开运算、闭运算、细化)

开运算和闭运算的基础是膨胀和腐蚀&#xff0c;可以在看本文章前先阅读这篇文章机器视觉技术与应用实战&#xff08;Chapter Two-04&#xff09;-CSDN博客 开运算&#xff1a;先腐蚀后膨胀。开运算可以使图像的轮廓变得光滑&#xff0c;具有断开狭窄的间断和消除细小突出物的作…

C语言数据结构-----二叉树(3)二叉树相关练习题

前言 前面详细讲述了二叉树的相关知识&#xff0c;为了巩固&#xff0c;做一些相关的练习题 文章目录 前言1.某二叉树共有 399 个结点&#xff0c;其中有 199 个度为 2 的结点&#xff0c;则该二叉树中的叶子结点数为&#xff1f;2.下列数据结构中&#xff0c;不适合采用顺序存…

【MySQL】MySQL表的操作-创建查看删除和修改

文章目录 1.创建表2.查看表结构3.修改表4.删除表 1.创建表 语法&#xff1a; CREATE TABLE table_name (field1 datatype,field2 datatype,field3 datatype ) character set 字符集 collate 校验规则 engine 存储引擎;说明&#xff1a; field 表示列名datatype 表示列的类型…

GitHub推荐:下载工具-Motrix

项目地址 GitHub - agalwood/Motrix: A full-featured download manager. 项目简介 Motrix是一个开源的下载工具&#xff0c;支持BT下载、Magnet下载。且下载支持最高64个线程&#xff0c;基本可以说下载速度的上限取决于你的带宽。是一款很不错的下载工具。 项目截图

机器视觉技术与应用实战(Chapter Two-03)

2.5 图像滤波和增强 滤波的作用是&#xff1a;图像中包含需要的信息&#xff0c;也包含我们不感兴趣或需要屏蔽的干扰&#xff0c;去掉这些干扰需要使用滤波。 增强的作用是&#xff1a;通过突出或者抑制图像中某些细节&#xff0c;减少图像的噪声&#xff0c;增强图像的视觉效…