Code
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
class BasicExpert(nn.Module):
    """A single expert network.

    The simplest expert is just one linear layer, as here. An expert could
    also be an MLP, or a more elaborate MLP whose activation is SwiGLU.
    """

    def __init__(self, feature_in, feature_out):
        super().__init__()
        # One affine map: (*, feature_in) -> (*, feature_out).
        self.linear = nn.Linear(feature_in, feature_out)

    def forward(self, x):
        return self.linear(x)
class BasicMOE(nn.Module):
    """Dense mixture-of-experts: every sample is routed through ALL experts.

    The gate produces one score per expert; the output is the gate-weighted
    combination of the expert outputs.

    Args:
        feature_in: input feature dimension.
        feature_out: output feature dimension of each expert.
        expert_number: number of experts.
    """

    def __init__(self, feature_in, feature_out, expert_number):
        super().__init__()
        self.experts = nn.ModuleList(
            [
                BasicExpert(feature_in, feature_out)
                for _ in range(expert_number)
            ]
        )
        # The gate scores each expert for each sample.
        self.gate = nn.Linear(feature_in, expert_number)

    def forward(self, x):
        """Map x of shape (batch, feature_in) to (batch, feature_out)."""
        # (batch, expert_number) gate scores.
        # NOTE(review): no softmax is applied, so these are unnormalized
        # logits rather than convex combination weights — kept as-is to
        # preserve the original behavior.
        expert_weight = self.gate(x)
        # Each expert yields (batch, 1, feature_out); concatenating along
        # dim=1 gives (batch, expert_number, feature_out).
        expert_output = torch.cat(
            [expert(x).unsqueeze(1) for expert in self.experts], dim=1
        )
        # (batch, 1, expert_number) @ (batch, expert_number, feature_out)
        # -> (batch, 1, feature_out)
        output = expert_weight.unsqueeze(1) @ expert_output
        # Fix: squeeze only dim 1. The original `squeeze()` removed every
        # size-1 dim, so a batch of 1 collapsed to shape (feature_out,).
        return output.squeeze(1)
def test_basic_moe():
    """Smoke test: push a (2, 6) batch through a 2-expert MoE."""
    # x: 2 samples, 6 features each.
    x = torch.rand(2, 6)
    # MoE with feature_in=6, feature_out=3, expert_number=2.
    basic_moe = BasicMOE(6, 3, 2)
    out = basic_moe(x)
    # out has shape (2, 3): 2 samples, 3 output features. (The original
    # comment claimed "2 samples, 2 experts, 3 features", but the expert
    # dimension is contracted away by the gate-weighted sum.)
    print(out)


# Guard the script entry point so importing this module has no side effects.
if __name__ == "__main__":
    test_basic_moe()
Diagram explanation corresponding to the code: