差分注意力,负注意力的引入

news2024/10/12 0:22:41

文章目录

  • Differential Transformer差分注意力,负注意力的引入
    • 相关链接
    • 介绍
      • 初始化函数
      • 多头差分注意力

Differential Transformer差分注意力,负注意力的引入

相关链接

ai-algorithms/README.md at main · Jaykef/ai-algorithms (github.com)

unilm/Diff-Transformer at master · microsoft/unilm (github.com)

介绍

在这里插入图片描述

注意力是非负的,导致在长序列时,有效信息淹没在无关信息的海洋中,因此引入负注意力,着重关注序列中的有效部分。因此一半的注意力头用作负注意力头,注意力权重由这两部分的注意力权重的加权差决定,加权系数可学习。加权系数的初始化值和层数有关。加权系数是通过四个可学习参数重参数化而来

lambda_q1, lambda_k1, lambda_q2, lambda_k2

参数维度

d i m _ h e a d ∗ n u m _ h e a d ∗ 2 = e m b e d _ d i m dim\_head * num\_head *2 = embed\_dim dim_headnum_head2=embed_dim

名称定义举例
dim_headembed // num_heads //232//4//2
proj_q(embed_dim, embed_dim)(32, 32)
proj_k(embed_dim,embed_dim)(32, 32)
proj_v(embed_dim, embed_dim)(32, 32)
proj_out(embed_dim, embed_dim)(32, 32)
Q[N, L, C].view(N, L, 2 *num_heads,dim_head)(1024, 256, 2 *4 , 4)
K[N, L, C].view(N, L, 2 *num_heads, dim_head )(1024, 256, 2 *4 , 4)
V[N, L, C].view(N, L, num_heads, 2*dim_head )(1024, 256, 4 , 2 * 4)
attn_weights[N, 2*num_heads, L, L].view(N, 2,num_heads, L, L ) -> [N, num_heads, L, L](1024, 2 , 4 , 256 , 256 ) ->(1024, 4 , 256 , 256 )
attn[N, num_heads, L, 2*dim_heads]->[N, L, C](1024, 4, 256, 8) -> (1024, 256, 32)

初始化函数

def lambda_init_fn(depth):
    return 0.8 - 0.6 * math.exp(-0.3 * depth)

多头差分注意力

class MultiheadDiffAttn(nn.Module):
    def __init__(
        self,
        embed_dim = 32,
        depth = 0,
        num_heads = 8,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        # num_heads set to half of Transformer's #heads
        self.num_heads = num_heads 
        self.head_dim = embed_dim // num_heads // 2
        self.scaling = self.head_dim ** -0.5
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

        self.lambda_init = lambda_init_fn(depth)
        self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
        self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
        self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
        self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))

    
    def forward(
        self,
        x,
    ):
        bsz, tgt_len, embed_dim = x.size()
        src_len = tgt_len

        q = self.q_proj(x) #[bsz, tgt_len, embed_dim ]
        k = self.k_proj(x) #[bsz, tgt_len, embed_dim]
        v = self.v_proj(x) #[bsz, tgt_len, embed_dim]

        q = q.view(bsz, tgt_len, 2 * self.num_heads, self.head_dim) #[bsz, tgt_len, 2 * num_heads, head_dim]  embed_dim = head_dim * num_heads
        k = k.view(bsz, src_len, 2 * self.num_heads, self.head_dim) #[bsz, src_len, 2 * num_heads, head_dim]
        v = v.view(bsz, src_len, self.num_heads, 2 * self.head_dim) #[131072, 2, 8, 8] [bsz, tgt_len, num_heads, 2 * head_dim]

        q = q.transpose(1, 2) #[bsz, 2 * num_heads, tgt_len, head_dim] [131072, 16, 2, 4] 
        q *= self.scaling 

        k = k.transpose(1, 2) #[131072, 16, 2, 4]
        v = v.transpose(1, 2) #[131072, 8, 2, 8]
        attn_weights = torch.matmul(q, k.transpose(-1, -2)) #[131072, 16, 2, 2] [bsz, 2 * num_heads, tgt_len, src_len]

        attn_weights = torch.nan_to_num(attn_weights)
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
            attn_weights
        )

        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
        lambda_full = lambda_1 - lambda_2 + self.lambda_init
        #[bsz, 2 * num_heads, tgt_len, src_len] 每一个注意力还是 [bsz, num_heads, tgt_len, src_len]
        attn_weights = attn_weights.view(bsz, self.num_heads, 2, tgt_len, src_len) #[131072, 8, 2, 2, 2] 第一个2是两个差分 
        attn_weights = attn_weights[:, :, 0] - lambda_full * attn_weights[:, :, 1] # 第一个注意力减去第二个注意力 [131072, 8, 2, 2]
        
        #[bsz, num_heads, tgt_len, src_len]
        attn = torch.matmul(attn_weights, v) # [131072, 8, 2, 8]
        attn = attn * (1 - self.lambda_init)
        attn = attn.transpose(1, 2).reshape(bsz, tgt_len, self.num_heads * 2 * self.head_dim) #[131072, 2, 32]

        attn = self.out_proj(attn)
        return attn

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2206332.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

response和验证码、文件下载操作

目录 Response对象 案例: 1、完成重定向 2、服务器输出字符输出流到浏览器 3、服务器输出字节输出流到浏览器 4、验证码 ServletContext对象 Response对象 功能:设置响应消息 1、设置响应行 格式:HTTP/1.1 200 ok 设置状态码 se…

RabbitMQ 高级特性——死信队列

文章目录 前言死信队列什么是死信常见面试题死信队列的概念:死信的来源(造成死信的原因有哪些)死信队列的应用场景 前言 前面我们学习了为消息和队列设置 TTL 过期时间,这样可以保证消息的积压,那么对于这些过期了的消…

【更新】上市公司企业机构投资者实地调研数据(2013-2023年)

一、测算方式: 参考《会计研究》逯东(2019)老师的做法,考虑投资者实地调研的频率和可能性,设立了下述变量来衡量上市公司接待投资者调研情况: 首先,使用年度范围内接待投资者调研的总次数 ( Visitnmb) 作为…

卸载PLSQL及标准卸载流程

目录 1. 卸载PLSQL2. 删除注册表3. 删除数据信息 1. 卸载PLSQL 等待进度条走完 2. 删除注册表 regedit 右击删除 3. 删除数据信息 由于AppData是隐藏文件,需要勾选隐藏的项目。 重启电脑,PLSQL就卸载成功了。

低代码工单管理app评测,功能与效率解析

预计到2030年,低代码平台市场将达1870亿美元。ZohoCreator助力企业构建定制化软件应用,以建筑行业工作订单管理app为例,简化流程,提升管理效率,降低成本。其用户友好界面、自动化管理、跨平台使用及全面报告功能受企业…

项目优化内容及实战

文章目录 事前思考Prometheus 普罗米修斯概述架构安装及使用 Grafana可视化数据库读写分离实战1-PrometheusGrafanaspringboot 事前思考 需要了解清楚:需要从哪些角度去分析实现?使用了缓存,就需要把缓存命中率数据进行收集;使用…

企业在隔离网环境下如何进行安全又稳定的跨网文件交换?

在数字化时代,企业的数据流通如同血液一般重要。然而,当企业内部实施了隔离网环境,跨网文件交换就成了一个棘手的问题。今天我们将探讨在隔离网环境下,企业面临的跨网文件交换挑战,以及如何通过合规的跨网文件交换系统…

数字电路——触发器1(RS和钟控触发器)

触发器:能够存储一位二进制信息的基本单元电路称触发器(Flip-Flop) 特点: 具有两个能自行保持的稳定状态,用来表示逻辑状态的“0”或“1”。具有一对互补输出。有一组控制(激励、驱动)输入。或许有定时(时钟)端CP(Clock Pulse)。在输入信号…

PostgreSQL 16.4安装以及集群部署

1. 环境准备 1.1 主机环境 主机 IP: 192.24.215.121操作系统: CentOS 9PostgreSQL 版本: 16.4 1.2 从机环境 从机 IP: 192.24.215.122操作系统: CentOS 9PostgreSQL 版本: 16.4 2. 安装 PostgreSQL 16.4 在主从两台机器上都需要安装 PostgreSQL 16.4。 2.1 添加 Postgre…

银行卡基础信息查询 API 对接说明

本文将介绍一种 银行卡基础信息查询 API 对接说明,它可用于银行卡基础信息查询。 接下来介绍下 银行卡基础信息查询 API 的对接说明。 申请流程 要使用 API,需要先到 银行卡基础信息查询 API 对应页面申请对应的服务,进入页面之后&#xf…

Python自定义异常类:实际应用示例之最佳实践

Python自定义异常类:实际应用示例之最佳实践 前言 在软件开发中,合理处理异常是保证程序稳定性的重要环节。虽然 Python 内置了丰富的异常类型,但在处理复杂业务逻辑时,自定义异常类能够使代码更加清晰且具备可扩展性。 本文将…

一个架构师的职业素养:四种常用的权限模型

你好,我是看山。 本文收录在《一个架构师的职业素养》专栏。日拱一卒,功不唐捐。 今天咱们一起聊聊权限系统。 以大家熟知的电商场景举例: 用户可以分为普通用户、VIP用户:我们需要控制不同角色用户的访问范围。比如,京东的PLUS会员,可以进入会员专区,而且能够使用礼金…

ESP32接入扣子(Coze) API使用自定义智能体

使用ESP32接入Coze API实现聊天机器人的教程 本示例将使用ESP32开发板通过WiFi接入 Coze API,实现一个简单的聊天机器人功能。用户可以通过串口向机器人输入问题,ESP32将通过Coze API与智能体进行通信,并返回对应的回复。本文将详细介绍了如…

OpenGL 进阶系列03 - OpenGL实例化渲染来提高性能

目录 一:概述 二:实例化渲染的优点: 三:OpenGL实例化渲染的例子: 一:概述 OpenGL 实例化渲染(Instanced Rendering)是一种渲染技术,可以有效地绘制多个相同对象,而不需要为每个对象单独提交绘制调用。通过这种方式,可以显著提高渲染性能,尤其是在需要绘制大量相…

【每日刷题】Day137

【每日刷题】Day137 🥕个人主页:开敲🍉 🔥所属专栏:每日刷题🍍 🌼文章目录🌼 1. 1576. 替换所有的问号 - 力扣(LeetCode) 2. 495. 提莫攻击 - 力扣&#xf…

【数据结构与算法】线性表顺序存储结构

文章目录 一.顺序表的存储结构定义1.1定义1.2 图示1.3结构代码*C语言的内存动态分配 二.顺序表基本运算*参数传递2.1建立2.2初始化(InitList(&L))2.3销毁(DestroyList(&L))2.4判断线性表是否为空表(ListEmpty(L))2.5求线性表的长度(ListLength(L))2.6输出线性表(DispLi…

基于GoogleNet深度学习网络的手语识别算法matlab仿真

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 5.算法完整程序工程 1.算法运行效果图预览 (完整程序运行后无水印) 手语How are you,测试识别结果如下: 手语I am fine,测试识别结果如下: 手…

java入门和Java语法

Java直接运行源代码文件,不会产生HelloWorld.class第二种方法:把模块放在D盘下,然后导入 第三种方法:新建一个模块,然后把内容复制过去 byte l 12; short m l; System.out.println(m); char n a; int reason mn; Sy…

消息摘要算法

算法特点 a) 消息摘要算法/单向散列函数/哈希函数 b) 不同长度的输入,产生固定长度的输出 c) 散列后的密文不可逆 d) 散列后的结果唯一 e) 哈希碰撞 f) 一般用于校验数据完整性、签名sign 由于密文不可逆,所以服务端也无法解密 想要验证&#xf…

解锁机器人视觉与人工智能的潜力,从“盲人机器”改造成有视觉能力的机器人(下)

机器视觉产业链全景回顾 视觉引导机器人生态系统或产业链分为三个层次。 上游(供应商) 该机器人视觉系统的上游包括使其得以运行的硬件和软件提供商。硬件提供商提供工业相机、图像采集卡、图像处理器、光源设备(LED)、镜头、光…