MOEFeedForward 模块

news2025/3/14 19:36:00

代码

class FeedForward(nn.Module):
    def __init__(self, config: LMConfig):
        super().__init__()
        if config.hidden_dim is None:
            hidden_dim = 4 * config.dim
            hidden_dim = int(2 * hidden_dim / 3)
            config.hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of)
        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))


class MoEGate(nn.Module):
    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts

        self.scoring_func = config.scoring_func
        self.alpha = config.aux_loss_alpha
        self.seq_aux = config.seq_aux

        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.dim
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        import torch.nn.init as init
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        bsz, seq_len, h = hidden_states.shape
        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(hidden_states, self.weight, None)
        if self.scoring_func == 'softmax':
            scores = logits.softmax(dim=-1)
        else:
            raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')

        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator

        if self.training and self.alpha > 0.0:
            scores_for_aux = scores
            aux_topk = self.top_k
            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
            if self.seq_aux:
                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
                ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
                ce.scatter_add_(1, topk_idx_for_aux_loss,
                                torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(
                    seq_len * aux_topk / self.n_routed_experts)
                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
            else:
                mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
                ce = mask_ce.float().mean(0)
                Pi = scores_for_aux.mean(0)
                fi = ce * self.n_routed_experts
                aux_loss = (Pi * fi).sum() * self.alpha
        else:
            aux_loss = 0
        return topk_idx, topk_weight, aux_loss


class MOEFeedForward(nn.Module):
    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.experts = nn.ModuleList([
            FeedForward(config)
            for _ in range(config.n_routed_experts)
        ])
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            self.shared_experts = FeedForward(config)

    def forward(self, x):
        identity = x
        orig_shape = x.shape
        bsz, seq_len, _ = x.shape
        # 使用门控机制选择专家
        topk_idx, topk_weight, aux_loss = self.gate(x)
        x = x.view(-1, x.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        if self.training:
            # 训练模式下,重复输入数据
            x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
            y = torch.empty_like(x, dtype=torch.float16)
            for i, expert in enumerate(self.experts):
                y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype)  # 确保类型一致
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.view(*orig_shape)
        else:
            # 推理模式下,只选择最优专家
            y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        self.aux_loss = aux_loss
        return y

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        expert_cache = torch.zeros_like(x)
        idxs = flat_expert_indices.argsort()
        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
        token_idxs = idxs // self.config.num_experts_per_tok
        # 例如当tokens_per_expert=[6, 15, 20, 26, 33, 38, 46, 52]
        # 当token_idxs=[3, 7, 19, 21, 24, 25,  4,  5,  6, 10, 11, 12...]
        # 意味着当token_idxs[:6] -> [3,  7, 19, 21, 24, 25,  4]位置的token都由专家0处理,token_idxs[6:15]位置的token都由专家1处理......
        for i, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
            if start_idx == end_idx:
                continue
            expert = self.experts[i]
            exp_token_idx = token_idxs[start_idx:end_idx]
            expert_tokens = x[exp_token_idx]
            expert_out = expert(expert_tokens).to(expert_cache.dtype)
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
            # 使用 scatter_add_ 进行 sum 操作
            expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)

        return expert_cache

代码解释

解释一下这段代码的主要组成部分:

  1. FeedForward 类:
  • 实现了一个基础的前馈网络
  • 使用 SwiGLU 激活函数(F.silu(self.w1(x)) * self.w3(x)
  • 包含三个线性层(w1、w2、w3)和一个 dropout 层
  1. MoEGate 类(门控机制):
  • 负责决定每个 token 应该由哪些专家处理
  • 主要步骤:
    1. 计算每个 token 对应每个专家的分数(使用 softmax)
    2. 选择 top-k 个最高分的专家
    3. 计算辅助损失(aux_loss)来平衡专家的使用
  1. MOEFeedForward 类(混合专家系统):
  • 包含多个专家(FeedForward)和一个门控网络(MoEGate)

  • 训练模式:

    1. 使用门控网络选择每个 token 的专家
    2. 将输入数据复制多份,分发给不同专家
    3. 专家并行处理数据
    4. 根据门控权重合并结果
  • 推理模式(moe_infer):

    1. 对专家索引排序,将相同专家的 token 批量处理
    2. 使用 scatter_add_ 将专家输出累加到正确位置
    3. 更高效的推理实现,避免了数据重复
  1. 特殊功能:
  • 支持共享专家(n_shared_experts
  • 实现了专家负载均衡(通过辅助损失)
  • 支持每个 token 选择多个专家(num_experts_per_tok

这是一个典型的 MoE(Mixture of Experts)实现,用于大型语言模型中提高模型容量和计算效率。

示例

# 创建 MoE 实例
dim = 512                    # 输入维度
n_routed_experts = 4         # 专家数量
num_experts_per_tok = 2      # 每个token选择的专家数量

moe = MOEFeedForward(
    dim=dim,
    n_routed_experts=n_routed_experts,
    num_experts_per_tok=num_experts_per_tok,
    hidden_dim=None,         # FFN隐藏层维度,None时自动计算
    dropout=0.1             # dropout比率
)

# 创建示例输入
batch_size = 2
seq_len = 10
x = torch.randn(batch_size, seq_len, dim)  # 形状: [2, 10, 512]

moe(x)

输出

After gate - topk_idx.shape: torch.Size([20, 2]), topk_weight.shape: torch.Size([20, 2])
After view - x.shape: torch.Size([20, 512]), flat_topk_idx.shape: torch.Size([40])
After repeat_interleave - x.shape: torch.Size([40, 512])
Empty y tensor shape: torch.Size([40, 512])
Expert 0 - input shape: torch.Size([9, 512])
Expert 0 - output shape: torch.Size([9, 512])
Expert 1 - input shape: torch.Size([13, 512])
Expert 1 - output shape: torch.Size([13, 512])
Expert 2 - input shape: torch.Size([11, 512])
Expert 2 - output shape: torch.Size([11, 512])
Expert 3 - input shape: torch.Size([7, 512])
Expert 3 - output shape: torch.Size([7, 512])
Before view - y.shape: torch.Size([40, 512])
topk_weight.shape: torch.Size([20, 2])
After view and sum - y.shape: torch.Size([20, 512])
Final y.shape: torch.Size([2, 10, 512])

相应的torch函数

import torch
# empty: 创建未初始化的张量
x = torch.empty((2, 3))  # 创建形状为 2x3 的未初始化张量

# zeros_like: 创建与输入相同形状的全零张量
a = torch.tensor([[1, 2], [3, 4]])
b = torch.zeros_like(a)  # 创建形状为 2x2 的全零张量
print(b)  # tensor([[0, 0], [0, 0]])
tensor([[0, 0],
        [0, 0]])
import torch.nn.functional as F
x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
# view: 改变张量形状
y = x.view(-1)  # 展平为一维
print(y)  # tensor([1, 2, 3, 4, 5, 6, 7, 8])

# -1 表示自动计算该维度大小
z = x.view(-1, 2)  # 重塑为 4x2
print(z)  # tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
tensor([1, 2, 3, 4, 5, 6, 7, 8])
tensor([[1, 2],
        [3, 4],
        [5, 6],
        [7, 8]])
# linear: 线性变换 y = xA^T + b
input = torch.randn(2, 3)  # 2个样本,每个3维
weight = torch.randn(4, 3)  # 输出4维
output = F.linear(input, weight)  # 形状变为 [2, 4]

# softmax: 将数值转换为概率分布
logits = torch.tensor([1.0, 2.0, 3.0])
probs = F.softmax(logits, dim=0)
print(probs)  # tensor([0.0900, 0.2447, 0.6652])
tensor([0.0900, 0.2447, 0.6652])
# 找出最大的k个值及其索引
x = torch.tensor([1, 5, 2, 8, 3])
values, indices = torch.topk(x, k=2)
print(values)   # tensor([8, 5])
print(indices)  # tensor([3, 1])
tensor([8, 5])
tensor([3, 1])
x = torch.tensor([1, 2, 3])
# 每个元素重复2次
y = x.repeat_interleave(2)
print(y)  # tensor([1, 1, 2, 2, 3, 3])
tensor([1, 1, 2, 2, 3, 3])
# 统计每个数字出现的次数
x = torch.tensor([1, 1, 2, 3, 1, 2])
counts = x.bincount()
print(counts)  # tensor([0, 3, 2, 1])  # 0出现0次,1出现3次,2出现2次,3出现1次
tensor([0, 3, 2, 1])
# 在指定位置累加值
src = torch.tensor([[1, 2], [3, 4]], dtype=torch.float)  # 指定数据类型为 float
index = torch.tensor([[0, 1], [0, 1]])
out = torch.zeros(2, 2, dtype=torch.float)  # 确保与 src 的数据类型相同
out.scatter_add_(0, index, src)
print(out) 
tensor([[4., 0.],
        [0., 6.]])
# 返回排序后的索引
x = torch.tensor([3, 1, 4, 1, 5])
indices = x.argsort()
print(indices)  # tensor([1, 3, 0, 2, 4])  # 最小值在位置1和3,然后是0,2,4
tensor([1, 3, 0, 2, 4])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2315039.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

政策助力,3C 数码行业数字化起航

政策引领,数字经济浪潮来袭 在当今时代,数字经济已成为全球经济发展的核心驱动力,引领着新一轮科技革命和产业变革的潮流。我国深刻洞察这一发展趋势,大力推进数字化经济发展战略,为经济的高质量发展注入了强大动力。 …

MySQL数据库复制

文章目录 MySQL数据库复制一、复制的原理二、复制的搭建1.编辑配置文件2.在主库上创建复制的用户3.获取主库的备份4.基于从库的恢复5.建立主从复制6.开启主从复制7.查看主从复制状态 MySQL数据库复制 MySQL作为非常流行的数据库,支撑它如此出彩的因素主要有两个&am…

101.在 Vue 3 + OpenLayers 使用 declutter 避免文字标签重叠

1. 前言 在使用 OpenLayers 进行地图开发时,我们经常需要在地图上添加点、线、区域等图形,并给它们附加文字标签。但当地图上的标注较多时,文字标签可能会发生重叠,导致用户无法清晰地查看地图信息。 幸运的是,OpenL…

uniapp移动端图片比较器组件,仿英伟达官网rtx光追图片比较器功能

组件下载地址:https://ext.dcloud.net.cn/plugin?id22609 已测试h5和微信小程序,理论支持全平台 亮点: 简单易用 使用js计算而不是resize属性,定制化程度更高 组件挂在后可播放指示线动画,提示用户可以拖拽比较图片…

深度学习与大模型-矩阵

矩阵其实在我们的生活中也有很多应用,只是我们没注意罢了。 1. 矩阵是什么? 简单来说,矩阵就是一个长方形的数字表格。比如你有一个2行3列的矩阵,可以写成这样: 这个矩阵有2行3列,每个数字都有一个位置&a…

搭建基于chatgpt的问答系统

一、语言模型,提问范式与 Token 1.语言模型 大语言模型(LLM)是通过预测下一个词的监督学习方式进行训练的,通过预测下一个词为训练目标的方法使得语言模型获得强大的语言生成能力。 a.基础语言模型 (Base LLM&…

LuaJIT 学习(2)—— 使用 FFI 库的几个例子

文章目录 介绍Motivating Example: Calling External C Functions例子:Lua 中调用 C 函数 Motivating Example: Using C Data StructuresAccessing Standard System FunctionsAccessing the zlib Compression LibraryDefining Metamethods for a C Type例子&#xf…

解锁 AI 开发的无限可能:邀请您加入 coze-sharp 开源项目

大家好!今天我要向大家介绍一个充满潜力的开源项目——coze-sharp!这是一个基于 C# 开发的 Coze 客户端,旨在帮助开发者轻松接入 Coze AI 平台,打造智能应用。项目地址在这里:https://github.com/zhulige/coze-sharp&a…

全面解析与实用指南:如何有效解决ffmpeg.dll丢失问题并恢复软件正常运行

在使用多媒体处理软件或进行视频编辑时,你可能会遇到一个常见的问题——ffmpeg.dll文件丢失。这个错误不仅会中断你的工作流程,还可能导致软件无法正常运行。ffmpeg.dll是FFmpeg库中的一个关键动态链接库文件,负责处理视频和音频的编码、解码…

Python----计算机视觉处理(opencv:像素,RGB颜色,图像的存储,opencv安装,代码展示)

一、计算机眼中的图像 像素 像素是图像的基本单元,每个像素存储着图像的颜色、亮度和其他特征。一系列像素组合到一起就形成 了完整的图像,在计算机中,图像以像素的形式存在并采用二进制格式进行存储。根据图像的颜色不 同,每个像…

小米路由器SSH下安装DDNS-GO

文章目录 前言一、下载&安装DDNS-GO二、配置ddns-go设置开机启动 前言 什么是DDNS? DDNS(Dynamic Domain Name Server)是动态域名服务的缩写。 目前路由器拨号上网获得的多半都是动态IP,DDNS可以将路由器变化的外网I…

go语言zero框架拉取内部平台开发的sdk报错的修复与实践

在开发过程中,我们可能会遇到由于认证问题无法拉取私有 SDK 的情况。这种情况常发生在使用 Go 语言以及 Zero 框架时,尤其是在连接到私有平台,如阿里云 Codeup 上托管的 Go SDK。如果你遇到这种错误,通常是因为 Go 没有适当的认证…

手机屏幕摔不显示了,如何用其他屏幕临时显示,用来导出资料或者清理手机

首先准备一个拓展坞 然后 插入一个外接的U盘 插入鼠标 插入有数字小键盘区的键盘 然后准备一根高清线,一端链接电脑显示器,一端插入拓展坞 把拓展坞的连接线,插入手机充电口(可能会需要转接头) 然后确保手机开机 按下键盘…

工业三防平板AORO-P300 Ultra,开创铁路检修与调度数字化新范式

在现代化铁路系统的庞大网络中,其设备维护与运营调度的精准性直接影响着运输效率和公共安全。在昼夜温差大、电磁环境复杂、震动粉尘交织的铁路作业场景中,AORO-P300 Ultra工业三防平板以高防护标准与智能化功能体系,开创了铁路行业移动端数字…

LInux基础--apache部署网站

httpd的安装 yum -y install httpdhttpd的使用 启动httpd systemctl enable --now httpd使用enable --now 进行系统设置时,会将该服务设置为开机自启并且同时开启服务 访问httpd 创建虚拟主机 基于域名 在一台主机上配置两个服务server1和server2,其…

Linux内核套接字以及分层模型

一、套接字通信 内核开发工程师将网络部分的头文件存储到一个专门的目录include/net中,而不是存储到标准位置include/linux。 计算机之间通信是一个非常复杂的问题: 如何建立物理连接?使用什么样的线缆?通信介质有那些限制和特殊…

Linux《基础开发工具(中)》

在之前的Linux《基础开发工具(上)》当中已经了解了Linux当中到的两大基础的开发工具yum与vim;了解了在Linux当中如何进行软件的下载以及实现的基本原理、知道了编辑器vim的基本使用方式,那么接下来在本篇当中将接下去继续来了解另…

使用1Panel一键搭建WordPress网站的详细教程(全)

嘿,各位想搭建自己网站的朋友们!今天我要跟大家分享我用1Panel搭建WordPress网站的全过程。说实话,我之前对服务器运维一窍不通,但通过这次尝试,我发现原来建站可以这么简单!下面是我的亲身经历和一些小技巧…

uni-app学习笔记——自定义模板

一、流程 1.这是一个硬性的流程,只要按照如此程序化就可以实现 二、步骤 1.第一步 2.第二步 3.第三步 4.每一次新建页面,都如第二步一样;可以选择自定义的模版(vue3Setup——这是我自己的模版),第二步的…

数据结构——顺序表seqlist

前言:大家好😍,本文主要介绍了数据结构——顺序表部分的内容 目录 一、线性表的定义 二、线性表的基本操作 三.顺序表 1.定义 2. 存储结构 3. 特点 四 顺序表操作 4.1初始化 4.2 插入 4.2.1头插 4.2.2 尾插 4.2.3 按位置插 4.3 …