Mixtral模型解读

news2025/1/12 18:11:25

Mixtral 8x7B(Mistral MoE)

1.Mistral 7B模型

Mistral 7B模型与Llama2 7B模型结构整体上是相似的,其结构参数如下所示。

image-20240302225715252

细节上来说,他有两点不同。

1.1SWA(Sliding Window Attention)

​ 一般的Attention来说,是Q与KV-Cache做内积,然后求出新的KV。其中KV.shape = [batch_size, num_heads, seq_len, dim]Llama2中的GQA是在多头上做文章,即让多组Q共享一组KV。而SWA是在seq_len上做文章,即限定attention的视野范围在Window中。

​ 这是原本一个7x7的mask矩阵

image-20240302230327377

​ 而这时Slide Window=3的mask矩阵

image-20240302230459982

if input_ids.shape[1] > 1:
    # seqlen推理时在prompt阶段为n,在generation阶段为1
    seqlen = input_ids.shape[1]
    # mask在推理时也只在prompt阶段有,
    #定义一个全1方阵
    tensor = torch.full((seqlen, seqlen),fill_value=1)
    # 上三角部分全为0
    mask = torch.tril(tensor, diagonal=0).to(h.dtype)
    # 这里代码diagonal应该等于(-self.args.sliding_window+1)才能满足window size为 sliding_window
    mask = torch.triu(mask, diagonal=-self.args.sliding_window)
    mask = torch.log(mask)

而在generation阶段,因为是自回归生成所以mask起不到作用。此时mistral则使用了RotatingBufferCache来实现此操作。

image-20240302230629896

# The cache is a rotating buffer
# positions[-self.sliding_window:] 取最后w个位置的索引,取余
# [None, :, None, None]操作用于扩维度[1,w,1,1]
scatter_pos = (positions[-self.sliding_window:] % self.sliding_window)[None, :, None, None]
# repeat操作repeat维度 [bsz, w, kv_head, head_dim]
scatter_pos = scatter_pos.repeat(bsz, 1, self.n_kv_heads, self.args.head_dim)
# src取[:,-w,:,:] 所以src.shape=[bsz,w,kv_head,head_dim]
# 根据scatter_pos作为index 将src写入cache
self.cache_k[:bsz].scatter_(dim=1, index=scatter_pos, src=xk[:, -self.sliding_window:])
self.cache_v[:bsz].scatter_(dim=1, index=scatter_pos, src=xv[:, -self.sliding_window:])

自然,我们会有疑问。只让Q与前面Window Size的KV计算Attention,不会影响最终的预测精度吗?

image-20240302230803700

官方给出了这张图片。具体来说窗口确实限制了attention的视野范围。但是并不是完全无法观察到窗口外的信息。

例子:

  1. 第1层

    • "A"只能看到自己。
    • “B"看到"A"和"B”。
    • “C"看到"B"和"C”。
    • “D"看到"C"和"D”。
    • “E"看到"D"和"E”。

    在第1层结束时,“E"直接关注了"D”,间接接收了"C"的影响。

  2. 第2层

    • "A"依然只能看到自己。
    • “B"现在可以看到经过第1层处理的"A"和"B”。
    • “C"可以看到经过第1层处理的"B"和"C”,这意味着"C"现在间接包含了"A"的信息。
    • “D"可以看到经过第1层处理的"C"和"D”,间接包含了"B"的信息。
    • “E"可以看到经过第1层处理的"D"和"E”,间接包含了"C"的信息,而"C"又包含了"A"的信息。

    在第2层结束时,“E"直接关注了"D”(包含"C"的信息),间接关注了"C"(包含"A"和"B"的信息)。

    ​ 综上所述,对于 l a y e r t layer_t layert而言虽然 Q H Q_H QH只能直接与 t o k e n s − F , G , H tokens - {F,G,H} tokensF,G,H直接进行注意力机制计算,但是却可以间接与更早 t o k e n s − G , F , E , D . . . tokens - {G,F,E,D...} tokensG,F,E,D...参与注意力机制运算,以此类推,只要层数足够大,配合这种传递方式就可以覆盖整个序列。论文中还举例说明,对于一个序列长度是16k,Window Size为4K的SWA,只需要四层,最后一个token就能看到之前的全部token信息。

1.2MoE

MoE简单来说就是让一个网络模型有多条分支,每条分支代表一个Expert(专家),每个Expert有擅长的领域,当面对不同的具体任务,可以通过一个门控单元来选择哪或哪几个Expert进行计算。当然在训练MoE模型时也要注意各个Experts负载均衡,防止赢者通吃,达不到想要的目的。
在这里插入图片描述
​ 左边是Llama结构,右边是Mixtral。

"""
			代码中 epxerts 是 self.feed_forward = MoeLayer(
                experts=[FeedForward(args=args) for _ in range(args.moe.num_experts)],
                gate=nn.Linear(args.dim, args.moe.num_experts, bias=False),
                moe_args=args.moe)
                
                gate 是 linear(dim, num_experts)
""" 
class MoeLayer(nn.Module):
    def __init__(self, experts: List[nn.Module], gate: nn.Module, moe_args: MoeArgs):
        super().__init__()
        assert len(experts) > 0
        self.experts = nn.ModuleList(experts)
        self.gate = gate
        self.args = moe_args
        
    def forward(self, inputs: torch.Tensor):
        # inputs.shape = [B*L, D]
        gate_logits = self.gate(inputs)
        # 从 gate_logits 中选出, 得分最高的 topk 个专家
        weights, selected_experts = torch.topk(gate_logits, self.args.num_experts_per_tok)
        # 确保每个输入分配给专家的权重和为1,并将结果转换回输入的数据类型
        weights = F.softmax(weights, dim=1, dtype=torch.float).to(inputs.dtype)
        """
        	假设 softmax 后 weights 为 [0.2, 0.1, 0.7], selected_experts 为 [1,3,4] 表明
        	模型认为第一个,第三个,第四个专家的意见占比分别 .2, .1, .7
        """
        results = torch.zeros_like(inputs)
        for i, expert in enumerate(self.experts):
            # torch.where : 给出 i 在 selected_experts 中的行和列
            batch_idx, nth_expert = torch.where(selected_experts == i)
            # 加权和,将输入放入专家网络 与 权值加权求和
            results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(inputs[batch_idx])
        return results

绝大部分参考

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1486051.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

23端口登录的Telnet命令+传输协议FTP命令

一、23端口登录的Telnet命令 Telnet是传输控制协议/互联网协议(TCP/IP)网络(如Internet)的登录和仿真程序,主要用于Internet会话。基本功能是允许用户登录进入远程主机程序。 常用的Telnet命令 Telnet命令的格式为&…

基础算法(四)(递归)

1.递归算法的介绍: 概念:递归是指函数直接或间接调用自身的过程。 解释递归的两个关键要素: 基本情况(递归终止条件):递归函数中的一个条件,当满足该条件时,递归终止,避…

C++11中的auto、基于范围的for循环、指针空值nullptr

目录 auto关键字 使用原因 历史背景 C11中的auto auto的使用案例 auto 指针/引用 同一行定义多个变量 typeid关键字 基于范围的for循环 范围for的语法 范围for的使用条件 指针空值nullptr C98中的指针空值 C11中的指针空值 auto关键字 使用原因 随着程序越…

Decoupled Knowledge Distillation解耦知识蒸馏

Decoupled Knowledge Distillation解耦知识蒸馏 现有的蒸馏方法主要是基于从中间层提取深层特征,而忽略了Logit蒸馏的重要性。为了给logit蒸馏研究提供一个新的视角,我们将经典的KD损失重新表述为两部分,即目标类知识蒸馏(TCKD&a…

JavaSec 基础之五大不安全组件

文章目录 不安全组件(框架)-Shiro&FastJson&Jackson&XStream&Log4jLog4jShiroJacksonFastJsonXStream 不安全组件(框架)-Shiro&FastJson&Jackson&XStream&Log4j Log4j Apache的一个开源项目,是一个基于Java的日志记录框架。 历史…

python学习笔记------元组

元组的定义 定义元组使用小括号,且使用逗号隔开各个数据,数据是不同的数据类型 定义元组字面量:(元素,元素,元素,......,元素) 例如:(1,"hello") 定义元组变量:变量名称(元素,元素,元素,......,元素)…

哈希表是什么?

一、哈希表是什么? 哈希表,也称为散列表,是一种根据关键码值(Key value)直接进行访问的数据结构。它通过把关键码值映射到表中一个位置来访问记录,从而加快查找速度。这个映射函数叫做散列函数&#xff08…

C#与VisionPro联合开发——单例模式

单例模式 单例模式是一种设计模式,用于确保类只有一个实例,并提供一个全局访问点来访问该实例。单例模式通常用于需要全局访问一个共享资源或状态的情况,以避免多个实例引入不必要的复杂性或资源浪费。 Form1 的代码展示 using System; usi…

初阶数据结构之---栈和队列(C语言)

引言 在顺序表和链表那篇博客中提到过,栈和队列也属于线性表 线性表: 线性表(linear list)是n个具有相同特性的数据元素的有限序列。 线性表是一种在实际中广泛使用的数据结构。线性表在逻辑上是线性结构,也就是说是连…

c++之拷贝构造和赋值

如果一个构造函数中的第一个参数是类本身的引用,或者是其他的参数都有默认值,则该构造函数为拷贝构造函数。 那么什么是拷贝构造呢?利用同类对象构造一个新对象。 1,函数名和类必须同名。 2,没有返回值。 3&#x…

差分题练习(区间更新)

一、差分的特点和原理 对于一个数组a[],差分数组diff[]的定义是: 对差分数组做前缀和可以还原为原数组: 利用差分数组可以实现快速的区间修改,下面是将区间[l, r]都加上x的方法: diff[l] x; diff[r 1] - x;在修改完成后,需要做前缀和恢复…

4.关联式容器

关联式container STL中一些常见的容器: 序列式容器(Sequence Containers): vector(动态数组): 动态数组,支持随机访问和在尾部快速插入/删除。list(链表)&am…

奇舞周刊第521期:“一切非 Rust 项目均为非法”

奇舞推荐 ■ ■ ■ 拜登:“一切非 Rust 项目均为非法” 科技巨头要为Coding安全负责。这并不是拜登政府对内存安全语言的首次提倡。“程序员编写代码并非没有后果,他们的⼯作⽅式于国家利益而言至关重要。”白宫国家网络总监办公室(ONCD&…

Python3零基础教程之数学运算专题进阶

大家好,我是千与编程,今天已经进入我们Python3的零基础教程的第十节之数学运算专题进阶。上一次的数学运算中我们介绍了简单的基础四则运算,加减乘除运算。当涉及到数学运算的 Python 3 刷题使用时,进阶课程包含了许多重要的概念和技巧。下面是一个简单的教程,涵盖了一些常…

NOC2023软件创意编程(学而思赛道)python初中组决赛真题

目录 下载原文档打印做题: 软件创意编程 一、参赛范围 1.参赛组别:小学低年级组(1-3 年级)、小学高年级组(4-6 年级)、初中组。 2.参赛人数:1 人。 3.指导教师:1 人(可空缺)。 4.每人限参加 1 个赛项。 组别确定:以地方教育行政主管部门(教委、教育厅、教育局) 认…

嵌入式驱动学习第一周——linux的休眠与唤醒

前言 本文介绍进程的休眠与唤醒。 嵌入式驱动学习专栏将详细记录博主学习驱动的详细过程,未来预计四个月将高强度更新本专栏,喜欢的可以关注本博主并订阅本专栏,一起讨论一起学习。现在关注就是老粉啦! 行文目录 前言1. 阻塞和非阻…

Doris实战——美联物业数仓

目录 一、背景 1.1 企业背景 1.2 面临的问题 二、早期架构 三、新数仓架构 3.1 技术选型 3.2 运行架构 3.2.1 数据模型 纵向分域 横向分层 数据同步策略 3.2.2 数据同步策略 增量策略 全量策略 四、应用实践 4.1 业务模型 4.2 具体应用 五、实践经验 5.1 数据…

【Java EE】线程安全的集合类

目录 🌴多线程环境使用 ArrayList🎍多线程环境使⽤队列🍀多线程环境使⽤哈希表🌸 Hashtable🌸ConcurrentHashMap ⭕相关面试题🔥其他常⻅问题 原来的集合类, 大部分都不是线程安全的. Vector, Stack, HashT…

EndNote 21:文献整理与引用,一键轻松搞定 mac/win版

EndNote 21是一款功能强大的文献管理软件,专为学术研究者、学生和教师设计。它提供了全面的文献管理解决方案,帮助用户轻松整理、引用和分享学术文献。 EndNote 21软件获取 EndNote 21拥有直观的用户界面和强大的文献检索功能,用户可以轻松地…

昇腾ACL应用开发之硬件编解码dvpp

1.前言 在我们进行实际的应用开发时,都会随着对一款产品或者AI芯片的了解加深,大家都会想到有什么可以加速预处理啊或者后处理的手段?常见的不同厂家对于应用开发的时候,都会提供一个硬件解码和硬件编码的能力,这也是抛…