qwen-moe

news2024/12/24 13:19:29

一、定义

  1. qwen-moe 代码讲解, 代码qwen-moe与Mixtral-moe 一样, 专家模块
  2. qwen-moe 开源教程
  3. Mixture of Experts (MoE) 模型在Transformer结构中如何实现,Gate的实现一般采用什么函数? Sparse MoE的优势有哪些?MoE是如何提高模型容量而不显著增加计算负
    担的?

二、实现

  1. qwen-moe 代码讲解
    参考:https://blog.csdn.net/v_JULY_v/article/details/135176583?utm_medium=distribute.pc_relevant.none-task-blog-2defaultbaidujs_baidulandingword~default-0-135176583-blog-135046508.235v43pc_blog_bottom_relevance_base4&spm=1001.2101.3001.4242.1&utm_relevant_index=3
import torch
from torch import nn
from torch.nn import functional as F
from transformers.activations import ACT2FN

class Qwen2MoeMLP(nn.Module):
    def __init__(self, config, intermediate_size=None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        #self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        return self.down_proj(self.gate_proj(x) * self.up_proj(x))


class Qwen2MoeSparseMoeBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob

        # gating
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        self.experts = nn.ModuleList(
            [Qwen2MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)]
        )

        self.shared_expert = Qwen2MoeMLP(config, intermediate_size=config.shared_expert_intermediate_size)
        self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """ """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        #选取每个token 对应的前k 个专家
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        if self.norm_topk_prob:
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)   #权重归一化  确保每个token的专家权重之和为1
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)
        #全为0的张量
        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )

        # One hot encode the selected experts to create an expert mask
        # this will be used to easily index which expert is going to be sollicitated
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)  #稀疏矩阵

        # Loop over all available experts in the model and perform the computation on each expert
        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]                   # 第idx 专家对应的函数
            idx, top_x = torch.where(expert_mask[expert_idx])         #idx 专家,关注的token, top_x 对应第x 个token
            print(expert_idx,top_x.cpu().tolist() )   #专家,处理的token
            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)   专家输入信息:
            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)             #取出对应的token信息
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]       #专家输出

            # However `index_add_` only support torch tensors for indexing so we'll use
            # the `top_x` tensor here. 使用.index_add_函数后在指定位置(top_x)加上了指定值(current_hidden_states)
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))

        shared_expert_output = self.shared_expert(hidden_states)
        shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output

        final_hidden_states = final_hidden_states + shared_expert_output

        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits

# 假设的配置
class Config:
    def __init__(self):
        self.num_experts = 8
        self.num_experts_per_tok = 2
        self.norm_topk_prob = True
        self.hidden_size = 2
        self.moe_intermediate_size = 209
        self.shared_expert_intermediate_size = 20

# 检查是否有可用的GPU

device = torch.device("cpu")

# 创建模型实例
config = Config()
model = Qwen2MoeSparseMoeBlock(config).to(device)

input_tensor = torch.randn(1,3,2).to(device)

# 前向传播
output = model(input_tensor)
print(output)

注意:1. 常规思路: 每个token 选择2 个专家, 然后每个token 传入2个专家中,进行处理。----->为了加快推理速度----->关注视角由token 转为专家。在这里插入图片描述
便把关注视角从“各个token”变成了“各个专家”,当然,大部分情况下 token数远远不止下图这5个,而是比专家数多很多。总之,这么一转换,最终可以省掉很多循环。
遍历每个专家,对token 对应的信息整体输入专家模块。

# 【代码块A】routing_weights
# 每行对应1个token,第0列为其对应排位第1的expert、第1列为其对应排位第2的expert,元素值为相应权重
[[0.5310, 0.4690],
 [0.5087, 0.4913],

 [0.5014, 0.4986],
 [0.5239, 0.4761],
 [0.5817, 0.4183],
 [0.5126, 0.4874]]
# 【代码块B】expert_mask[expert_idx]
# 下述两行例子的物理含义为:
# 第一行是“该expert作为排位1的exert存在时,需要处理第9个token;
# 第二行是“该expert作为排位2的expert存在时,需要处理第10、11个token”
[[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]]
# 【代码块C】idx, top_x = torch.where(expert_mask[expert_idx])
# 以上述expert_mask[expert_idx]样例为例,对应的torch.where(expert_mask[expert_idx])结果如下
idx: [0, 1, 1]
top_x: [9, 10, 11]
idx对应行索引,top_x对应列索引,例如张量expert_mask[expert_idx]中,出现元素1的索引为(0, 9)(1, 10)(1, 11)
从物理含义来理解,top_x实际上就对应着“关乎当前expert的token索引”,第9、第10、第11个token被“路由”导向了当前所关注的expert,通过top_x可以取到“需要传入该expert的输入”,也即第9、第10、第11个token对应的隐向量

因此top_x将作为索引用于从全部token的隐向量hidden_states中取出对应token的隐向量
而idx和top_x也会组合起来被用于从expert权重张量routing_weights中取出对应的权重
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)             #取出top_x的token信息
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]       #专家输出

# However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here. 使用.index_add_函数后在指定位置(top_x)加上了指定值(current_hidden_states)
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
  1. 开源教程
    https://developer.aliyun.com/article/1471903?spm=a2c6h.28954702.blog-index-detail.67.536b4c2d9ZzdBw

  2. Mixture of Experts (MoE) 模型在Transformer结构中如何实现,Gate的实现一般采用什么函数? Sparse MoE的优势有哪些?MoE是如何提高模型容量而不显著增加计算负担的?

self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1718642.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【智能算法】三角拓扑聚合优化算法(TTAO)原理及实现

目录 1.背景2.算法原理2.1算法思想2.2算法过程 3.结果展示4.参考文献5.代码获取 1.背景 2024年,S Zhao受到数学相似三角形拓扑结构启发,提出了三角拓扑聚合优化算法(Triangulation Topology Aggregation Optimizer, TTAO)。 2.算…

Unity中的MVC框架

基本概念 MVC全名是Model View Controller 是模型(model)-视图(view)-控制器(controller)的缩写 是一种软件设计规范,用一种业务逻辑、数据、界面显示 分离的方法组织代码 将业务逻辑聚集到一个部件里面,在改进和个性化定制界面及用户交互的同时&#x…

ch4网络层---计算机网络期末复习(持续更新中)

网络层概述 将分组从发送方主机传送到接收方主机 发送方将运输层数据段封装成分组 接收方将分组解封装后将数据段递交给运输层网络层协议存在于每台主机和路由器上 路由器检查所有经过它的IP分组的分组头 注意路由器只有3层(网络层、链路层、物理层) 网络层提供的服务 一…

无人售货机零售业务成功指南:从市场分析到创新策略

在科技驱动的零售新时代,无人售货机作为一种便捷购物解决方案,正逐步兴起,它不仅优化了消费者体验,还显著降低了人力成本,提升了运营效能。开展这项业务前,深入的市场剖析不可或缺,需聚焦消费者…

命令模式(行为型)

目录 一、前言 二、命令模式 三、总结 一、前言 命令模式(Command Pattern)是一种行为型设计模式,命令模式将一个请求封装为一个对象,从而可以用不同的请求对客户进行参数化;对请求排队或记录请求日志,以…

【C++】C++入门2.0

各位读者老爷好,本鼠最近浅学了一点C的入门知识!利用本博客作为笔记的同时也希望得到各位大佬的垂阅! 目录 1. 引用 1.1.引用的概念 1.2.引用的特性 1.3.引用的使用场景 1.4.引用的易错点 1.5.引用的优势 1.6.引用和指针 2.内联函数 …

B端UI设计,演绎高情逸态之妙

B端UI设计,演绎高情逸态之妙

汽车IVI中控开发入门及进阶(二十三):i.MX8

前言: IVI市场的复杂性急剧增加,而TimeToMarket在几代产品中从5年减少到2-3年。Tier1正在接近开放系统的模型(用户可以安装应用程序),从专有/关闭源代码到标准接口/开放源代码,从软件堆栈对系统体系结构/应用层/系统验证和鉴定的完全所有权,越来越依赖第三方中间件和平…

STM32自己从零开始实操03:输出部分原理图

一、继电器电路 1.1指路 延续使用 JZC-33F-012-ZS3 继电器,设计出以小电流撬动大电流的继电器电路。 (提示)电路需要包含:三极管开关电路、续流二极管、滤波电容、指示灯、输出部分。 1.2数据手册重要信息提炼 联系排列&…

Rainbond 携手 TOPIAM 打造企业级云原生身份管控新体验

TOPIAM 企业数字身份管控平台, 是一个开源的IDaas/IAM平台、用于管理账号、权限、身份认证、应用访问,帮助整合部署在本地或云端的内部办公系统、业务系统及三方 SaaS 系统的所有身份,实现一个账号打通所有应用的服务。 传统企业 IT 采用烟囱…

NSS题目练习5

[NISACTF 2022]babyupload 打开后尝试上传php,jpg,png文件都没成功 查看源代码发现有个/source文件 访问后下载压缩包发现有一个python文件 搜索后知道大致意思是,上传的文件不能有后缀名,上传后生成一个uuid,并将uuid…

redis缓存token设置jwt令牌过期时间

登录接口 在上文中 我们已经设置了自定义登录接口自定义拦截器jwt登录校验接口模拟账号登录_jwt自定义拦截器-CSDN博客https://blog.csdn.net/2202_75352238/article/details/138424691?spm1001.2014.3001.5501 但是上文jwt过期时间是由yml文件中配置的,比较不优雅…

Amis源码构建 sdk版本

建议在linux环境下构建(mac环境下也可以),需要用到sh脚本(amis/build.sh)。 Js sdk打包是基于fis进行编译打包的,具体可见fis-conf.js: amis-master源码下载:https://github.com/baidu/amis g…

【OceanBase诊断调优】—— obdiag 工具助力OceanBase数据库诊断调优(DBA 从入门到实践第八期)

1. 前言 昨天给大家分享了【DBA从入门到实践】第八期:OceanBase数据库诊断调优、认证体系和用户实践 中obdiag的部分,今天将其中的内容以博客的形式给大家展开一下,方便大家阅读。 2. 正文 在介绍敏捷诊断工具之前,先说说OceanBa…

【C语言】常见的动态内存的错误

前言 在动态内存函数的使用过程中我们可能会遇到一些错误,这里将常见的错误进行总结。 对NULL解引用 请看以下代码: 可以看到,这时我们的malloc开辟是失败的,所以返回的是空指针NULL,而我们却没有进行检查&#xff0…

使用PNP管控制MCU是否需要复位

这两台用到一款芯片带电池,希望电池还有电芯片在工作的时候插入电源不要给芯片复位,当电池没电,芯片不在工作的时候,插入电源给芯片复位所以使用一个PNP三极管,通过芯片IO控制是否打开复位,当芯片正常工作的…

反激电源压敏电阻设计

压敏电阻的作用:浪涌防护。在电源出现浪涌冲击时,保护核心器件不受到损坏。其实类似于稳压二极管 瞬间的瞬态波 1 压敏电压 单位是,虽然压敏电阻可以吸收很大的浪涌能量,但是不能承受mA以上的持续电流。压敏电压计算公式 2 通流容…

(函数)字符串拼接(C语言)

一、运行结果&#xff1b; 二、源代码&#xff1b; # define _CRT_SECURE_NO_WARNINGS # include <stdio.h> # include <string.h>//声明字符串拼接函数&#xff1b; void splice(char a[100], char b[100]);int main() {//初始化变量值&#xff1b;char a[100] …

unity打包的WebGL部署到IIS问题

部署之后会出错&#xff0c;我遇到的有以下几种&#xff1b; 进度条卡住不动 明明已经部署到了IIS上&#xff0c;为什么浏览网页的时候还是过不去或者直接报错。 进度条卡住不动的问题其实就是wasm和data的错误。 此时在浏览器上按F12进入开发者模式查看错误&#xff08;下图…

【前端】Vuex笔记(超详细!!)

最近花了两周时间&#xff0c;完完全全的跟着Vuex官方的视频学完了Vuex并且详详细细的做了笔记&#xff0c;其中总结部分是我对于整个视频课程的总结&#xff0c;视频部分是跟着视频做的笔记&#xff0c;如果总结部分有不懂的话&#xff0c;直接去视频部分查找对应的笔记即可&a…