用PyTorch从零开始编写DeepSeek-V2

news2024/11/24 20:39:19

DeepSeek-V2是一个强大的开源混合专家(MoE)语言模型,通过创新的Transformer架构实现了经济高效的训练和推理。该模型总共拥有2360亿参数,其中每个令牌激活21亿参数,支持最大128K令牌的上下文长度。

在开源模型中,DeepSeek-V2实现了顶级性能,成为最强大的开源MoE语言模型。在MMLU(多模态机器学习)上,DeepSeek-V2以较少的激活参数实现了顶尖的性能。与DeepSeek 67B相比,DeepSeek-V2显著提升了性能,降低了42.5%的训练成本,减少了93.3%的KV缓存,并将最大生成吞吐量提高了5.76倍。

我们这里主要实现DeepSeek的主要改进:多头隐性注意力、细粒度专家分割和共享的专家隔离

架构细节

DeepSeek-V2整合了两种创新架构,我们将详细讨论:

  1. 用于前馈网络(FFNs)的DeepSeekMoE架构。
  2. 用于注意力机制的多头隐性注意力(MLA)。

DeepSeekMoE

在标准的MoE架构中,每个令牌被分配给一个(或两个)专家,每个MoE层都有多个在结构上与标准前馈网络(FFN)相同的专家。这种设置带来了两个问题:指定给令牌的专家将试图在其参数中聚集不同类型的知识,但这些知识很难同时利用;其次,被分配给不同专家的令牌可能需要共同的知识,导致多个专家在各自的参数中趋向于收敛,获取共享知识。

为了应对这两个问题,DeepSeekMoE引入了两种策略来增强专家的专业化:

  1. 细粒度专家分割:为了在每个专家中更有针对性地获取知识,通过切分FFN中的中间隐藏维度,将所有专家分割成更细的粒度。
  2. 共享专家隔离:隔离某些专家作为始终被激活的共享专家,旨在捕获不同上下文中的共同知识,并通过将共同知识压缩到这些共享专家中,减少其他路由专家之间的冗余。

让我们来定义DeepSeekMoE中第t个令牌的专家分配。如果u_t是该令牌的FFN输入,其输出h`_t将会是:

其中𝑁𝑠和𝑁𝑟分别是共享专家和路由专家的数量;FFN(𝑠)*𝑖和FFN(𝑟)*𝑖分别表示𝑖-th共享专家和𝑖-th路由专家。

对于路由专家而言,g_i,t 是第i个路由专家的门控值,s_i,t 是令牌到专家的亲和分数,Topk(., Kr) 包含了Kr个最高的亲和分数,其中Kr是活跃的路由专家的数量。

有了以上的公式,我们就来使用代码实现

门控模型实现:

 classMoEGate(torch.nn.Module):
     def__init__(self, num_experts_per_tok: int, n_routed_experts: int, routed_scaling_factor: int, topk_method: str, n_group: int, topk_group: int, hidden_size: int):
         super().__init__()
         self.top_k=num_experts_per_tok
         self.n_routed_experts=n_routed_experts
         self.routed_scaling_factor=routed_scaling_factor
         self.topk_method=topk_method
         self.n_group=n_group
         self.topk_group=topk_group
         self.weight=torch.nn.Parameter(torch.empty((self.n_routed_experts, hidden_size)))
         torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
     defforward(self, x: torch.Tensor):
         batch, seq_len, h=x.shape
         hidden_states=x.view(-1, h)
         logits=torch.nn.functional.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32), None)
         scores=logits.softmax(dim=-1, dtype=torch.float32)
         ifself.topk_method=="greedy":
             topk_weight, topk_idx=torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
         elifself.topk_method=="group_limited_greedy":
             group_scores= (scores.view(batch*seq_len, self.n_group, -1).max(dim=-1).values)
             group_idx=torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]  # [n, top_k_group]
             group_mask=torch.zeros_like(group_scores)  # [n, n_group]
             group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
             score_mask= (
                 group_mask.unsqueeze(-1)
                 .expand(
                     batch*seq_len, self.n_group, self.n_routed_experts//self.n_group
                 )
                 .reshape(batch*seq_len, -1)
             )  # [n, e]
             tmp_scores=scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
             topk_weight, topk_idx=torch.topk(
                 tmp_scores, k=self.top_k, dim=-1, sorted=False
             )
         returntopk_idx, topk_weight

MoE

 classMoE(torch.nn.Module):
     def__init__(self, dim: int, routed_scaling_factor: int, topk_method: str, n_group: int, topk_group: int, hidden_dim: int|None=None, n_routed_experts: int=12, num_experts_per_tok: int=4, n_shared_experts: int=2, mlp: str="swiglu"):
         super().__init__()
         self.experts_per_rank=n_routed_experts
         self.num_experts_per_tok=num_experts_per_tok
         self.n_shared_experts=n_shared_experts
         mlp_block=SwiGLU
         self.experts=torch.nn.ModuleList([mlp_block(dim, hidden_dim) foriinrange(n_routed_experts)])
         self.gate=MoEGate(num_experts_per_tok, n_routed_experts, routed_scaling_factor, topk_method, n_group, topk_group, dim)
         self.shared_experts=mlp_block(dim, hidden_dim*n_shared_experts)
         
     defforward(self, x: torch.Tensor):
         identity=x
         orig_shape=x.shape
         topk_idx, topk_weight=self.gate(x)
         x=x.view(-1, x.shape[-1])
         flat_topk_idx=topk_idx.view(-1)
         x=x.repeat_interleave(self.num_experts_per_tok, dim=0)
         y=torch.empty_like(x)
         y=y.type(x.dtype)
         fori, expertinenumerate(self.experts):
             y[flat_topk_idx==i] =expert(x[flat_topk_idx==i]).to(dtype=x.dtype)
         y= (y.view(*topk_weight.shape, -1) *topk_weight.unsqueeze(-1)).sum(dim=1)
         
         y=y.view(*orig_shape)
         output=y+self.shared_experts(identity)
         returnoutput

多头隐性注意力(MLA)

多头隐性注意力(MLA)相较于标准的多头注意力(MHA)实现了更优的性能,并且显著减少了KV缓存,提高了推理效率。与多查询注意力(MQA)和分组查询注意力(GQA)中减少KV头的方法不同,MLA将键(Key)和值(Value)共同压缩成一个潜在向量。

MLA不是缓存键(Key)和值(Value)矩阵,而是将它们联合压缩成一个低秩向量,这使得缓存的项目数量更少,因为压缩维度远小于多头注意力(MHA)中输出投影矩阵的维度。

标准的RoPE(旋转位置嵌入)与上述的低秩KV压缩不兼容。解耦RoPE策略使用额外的多头查询q_t和共享键k_t来实现RoPE。

下面总结了完整的MLA计算过程:

MLA实现

 classMLA(torch.nn.Module):
     def__init__(self, model_args: DeepseekConfig):
         super().__init__()
         d_model=model_args.d_model
         self.num_heads=model_args.num_heads
         self.head_dim=model_args.d_model//model_args.num_heads
         self.attn_dropout=torch.nn.Dropout(model_args.dropout)
         self.res_dropout=torch.nn.Dropout(model_args.dropout)
         self.flash_attn=hasattr(torch.nn.functional, "scaled_dot_product_attention")
         
         self.q_lora_rank=model_args.q_lora_rank
         self.qk_rope_head_dim=model_args.qk_rope_head_dim
         self.kv_lora_rank=model_args.kv_lora_rank
         self.v_head_dim=model_args.v_head_dim
         self.qk_nope_head_dim=model_args.qk_nope_head_dim
         self.q_head_dim=model_args.qk_nope_head_dim+model_args.qk_rope_head_dim
         self.q_a_proj=torch.nn.Linear(d_model, model_args.q_lora_rank, bias=False)
         self.q_a_layernorm=RMSNorm(model_args.q_lora_rank)
         self.q_b_proj=torch.nn.Linear(model_args.q_lora_rank, self.num_heads*self.q_head_dim, bias=False)
         self.kv_a_proj_with_mqa=torch.nn.Linear(d_model,model_args.kv_lora_rank+model_args.qk_rope_head_dim,bias=False,)
         self.kv_a_layernorm=RMSNorm(model_args.kv_lora_rank)
         self.kv_b_proj=torch.nn.Linear(model_args.kv_lora_rank,self.num_heads* (self.q_head_dim-self.qk_rope_head_dim+
             self.v_head_dim),bias=False,)
         self.o_proj=torch.nn.Linear(self.num_heads*self.v_head_dim,d_model, bias=False,)
 
     defforward(self, x: torch.Tensor, mask: torch.Tensor, freqs_cis) ->torch.Tensor:
         batch, seq_len, d_model=x.shape
         q=self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))
         q=q.view(batch, seq_len, self.num_heads, self.q_head_dim).transpose(1, 2)
         q_nope, q_pe=torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
         compressed_kv=self.kv_a_proj_with_mqa(x)
         compressed_kv, k_pe=torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         k_pe=k_pe.view(batch, seq_len, 1, self.qk_rope_head_dim).transpose(1, 2)
         kv= (self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
             .view(batch, seq_len, self.num_heads, self.qk_nope_head_dim+self.v_head_dim)
             .transpose(1, 2))
         k_nope, value_states=torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
         q_pe, k_pe=apply_rope(q_pe, k_pe, freqs_cis)
         k_pe=k_pe.transpose(2, 1)
         q_pe=q_pe.transpose(2, 1)
         query_states=k_pe.new_empty(batch, self.num_heads, seq_len, self.q_head_dim)
         query_states[:, :, :, : self.qk_nope_head_dim] =q_nope
         query_states[:, :, :, self.qk_nope_head_dim :] =q_pe
         key_states=k_pe.new_empty(batch, self.num_heads, seq_len, self.q_head_dim)
         key_states[:, :, :, : self.qk_nope_head_dim] =k_nope
         key_states[:, :, :, self.qk_nope_head_dim :] =k_pe
         attn_mtx=torch.matmul(query_states, key_states.transpose(2, 3)) /math.sqrt(self.head_dim)
         attn_mtx=attn_mtx+mask[:, :, :seq_len, :seq_len]
         attn_mtx=torch.nn.functional.softmax(attn_mtx.float(), dim=-1).type_as(key_states)
         attn_mtx=self.attn_dropout(attn_mtx)
         output=torch.matmul(attn_mtx, value_states)  # (batch, n_head, seq_len, head_dim)
         output=output.transpose(1, 2).contiguous().view(batch, seq_len, self.num_heads*self.v_head_dim)
         output=self.o_proj(output)
         output=self.res_dropout(output)
         returnoutput

总结

本文详细介绍了DeepSeek-V2语言模型,这是一个强大的开源混合专家(MoE)语言模型,采用创新的架构来提高训练和推理的经济性和效率。DeepSeek-V2采用了两种核心技术:细粒度专家分割和共享专家隔离,这两种策略显著提高了专家的专业化水平。此外,文章还介绍了多头隐性注意力(MLA),这是一种改进的注意力机制,通过低秩键值联合压缩和解耦旋转位置嵌入,优化了模型的存储和计算效率。

除了理论探讨,我们通过编写代码实现DeepSeek-V2,可以更深入地理解其架构和工作原理。可以帮助你账务如何实现先进的混合专家(MoE)模型,还能深化对多头隐性注意力(MLA)和低秩键值压缩等关键技术的理解。通过实践,读者将能够验证理论的有效性,并对模型的性能和效率有直观的认识。

https://avoid.overfit.cn/post/317a967c8dac42ee98f96d8390851476

作者:Zain ul Abideen

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1953015.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C嘎嘎浅谈模板

这篇文章给大家介绍一下c嘎嘎内存管理和模板,那么我们直接进入正题 c/c的程序内存分布 这里的了解一下即可 new和delete的定义和操作 格式:类型* 对象名 new 类型; 数组(对象)定义格式:类型* 对象名 new 类型[元素个数]&…

【安卓开发】【Android】如何进行真机调试【注意事项】

一、所需原料 1、电脑(安装有Android Studio开发工具): 2、安卓操作系统手机:笔者演示所用机型为huawei rongyao50,型号为NTU-AN00: 3、数据线(usb-typeA)一根。 二、操作步骤 1…

支持向量机回归及其应用(附Python 案例代码)

使用支持向量机回归估计房价 让我们看看如何使用支持向量机(SVM)的概念构建一个回归器来估计房价。我们将使用sklearn中提供的数据集,其中每个数据点由13个属性定义。我们的目标是根据这些属性估计房价。 引言 支持向量回归(SV…

IndexError: index 0 is out of bounds for axis 1 with size 0

IndexError: index 0 is out of bounds for axis 1 with size 0 目录 IndexError: index 0 is out of bounds for axis 1 with size 0 【常见模块错误】 【解决方案】 欢迎来到我的主页,我是博主英杰,211科班出身,就职于医疗科技公司&#…

Linux网络工具“瑞士军刀“集合

一、背景 平常我们在进行Linux服务器相关运维的时候,总会遇到一些网络相关的问题。我们可以借助这些小巧、功能强悍的工具帮助我们排查问题、解决问题。 下面结合之前的一些使用经验为大家介绍一下一些经典应用场景下,这个网络命令工具如何使用的。例如怎…

关于Redis(热点数据缓存,分布式锁,缓存安全(穿透,击穿,雪崩));

热点数据缓存: 为了把一些经常访问的数据,放入缓存中以减少对数据库的访问频率。从而减少数据库的压力,提高程序的性能。【内存中存储】成为缓存; 缓存适合存放的数据: 查询频率高且修改频率低 数据安全性低 作为缓存的组件: redis组件 memory组件 e…

大模型算法备案流程最详细说明【流程+附件】

文章目录 一、语料安全评估 二、黑盒测试 三、模型安全措施评估 四、性能评估 五、性能评估 六、安全性评估 七、可解释性评估 八、法律和合规性评估 九、应急管理措施 十、材料准备 十一、【线下流程】大模型备案线下详细步骤说明 十二、【线上流程】算法备案填报…

手把手教小白微信小程序开发(超详细保姆式教程)

注册:微信公众平台 -> 立即注册 ->小程序 AppID(小程序ID) wx05c13b331acc9d01 AppSecret(小程序密钥) 4f8232c7bbd4801e58a166d72e92e529 安装 微信开发者工具 ,扫描就可以登录 设置:右上角设置 ->外观浅色,代理&am…

C++通过进程句柄、进程id或进程名去杀掉进程(附完整源码)

目录 1、通过进程句柄去杀进程 2、通过进程id去杀进程 3、通过进程名去杀进程 C++软件异常排查从入门到精通系列教程(专栏文章列表,欢迎订阅,持续更新...)https://blog.csdn.net/chenlycly/article/details/125529931Windows C++ 软件开发从入门到精通(专栏文章,持续更…

普中51单片机:蜂鸣器的简单使用(十一)

文章目录 引言蜂鸣器的分类工作原理无源蜂鸣器压电式蜂鸣器:电磁式蜂鸣器: 电路符号及应用代码演示——无源蜂鸣器 引言 蜂鸣器是一种常见的电子音响器件,广泛应用于各种电子产品中。它们能够发出不同频率的声音,用于警报、提醒、…

AI软件测试|人工智能测试中对抗样本生成攻略

从医疗诊断、自动驾驶到智能家居,人工智能技术为各个行业领域带来无限可能的同时,挑战也日益显现。特别是在人工智能安全领域,随着恶意攻击和数据欺骗的不断演变,确保AI系统的安全性和可靠性成为亟需解决的重要问题,对…

【游戏制作】使用Python创建一个完整的2048游戏项目

目录 项目运行展示 项目概述 项目目标 项目结构 安装依赖 代码实现 1. 导入库 2. 创建 Game2048 类 3. 设置UI界面 4. 加载二维码图片 5. 创建菜单 6. 游戏逻辑和功能 7. 运行应用 总结 创建一个完整的2048游戏项目 项目运行展示 项目概述 在这个项目中&#xff…

常用sql:删除表中重复的数据

在平常的开发工作中,我们可能经常需要对表进行操作。比如某些数据重复了,那么可能需要删除掉重复的数据,保证数据根据业务字段属性相同的数据只有一条,那么应该如何做呢? 1:新建表:用户详情表 …

for循环计算1~100之间3的倍数的数字之和

你要计算1~100之间的数字先得打印出来1~100之间的数字然后在判断是不是3的倍数然后在打印出数字&#xff0c;代码如下 #include<stdio.h> int main() {int i 0;for (i 1; i < 100; i){if (i % 3 0){printf("%d ", i);}}return 0; }

Intellij IDEA多模块分组 实现move to group

新版本idea&#xff0c;没有了move to group的功能&#xff0c;导致模块很多的时候不能分组。2018版本有。 这个分组是虚拟的&#xff0c;不会在磁盘中实际存在。 要实现这个功能&#xff0c;只需要改modules.xml即可。 步骤 1. 找到配置文件 .idea目录下的moudules.xml 2.…

GeoServer GIS 服务器(geoServer离线地图服务器搭建)

文章目录 引言I GeoServer 安装部署版本选择基于war包进行部署II geoServer配置2.1 geoServer新建工作区2.2 geoServer 新建数据源2.3 geoServer图层发布和图层编辑2.4 指定存储层的坐标系2.5 geoServer图层样式2.6 图层组的创建GIS基础知识GeoServerWMTSEPSGEPSG3857相关的数据…

Cadence学习笔记(十三)--设置边框与异形铺铜

直接导入板框用小眼睛可以看到所有的都是线的属性&#xff1a; 那么如何让它变成板框呢&#xff1f;这里先跳转到下图中的层&#xff1a; 将Z--CPOY这一层变成shape区&#xff1a; 之后用Z--copy: Z--COPY设置如下参数&#xff0c;铺铜内缩20mil: 之后选择长方形铺铜就可以了&…

快醒醒,别睡了!...讲《数据分析pandas库》了—/—<5>

一、 1、修改替换变量值 本质上是如何直接指定单元格的问题&#xff0c;只要能准确定位单元地址&#xff0c;就能够做到准确替换。 1.1 对应数值的替换 具体用法如下&#xff1a; replace方法&#xff1a; df.replace(to_replace None :将被替换的原数值&#xff0c;所有…

matlab6.5免安装版,解压即可用【亲测win10可用】

这个版本是咱第一次学matlab的时候用的处女版&#xff0c;如今看着这个界面依然恍如昨日。为甚要分享这种老掉牙古董matlab版本呢&#xff1f;原因在于一款老古董工具箱 —— geatbx。 这款工具箱采用了古老pcode的加密系统加密&#xff0c;而matlab的pcode加密经过几次迭代&a…

数据库开发:MySQL基础(二)

MySQL基础&#xff08;二&#xff09; 一、表的关联关系 在关系型数据库中&#xff0c;表之间可以通过关联关系进行连接和查询。关联关系是指两个或多个表之间的关系&#xff0c;通过共享相同的列或键来建立连接。常见的关联关系有三种类型&#xff1a;一对多关系&#xff0c;…