Talking-Heads Attention

news2025/1/10 13:50:06

paper:Talking-Heads Attention

在CaiT这篇文章中,作用采用了talking-heads attention,这里做一下解释。

在原始multi-head self-attention中,各个head的计算是独立进行的,多个head的输出最后concat到一起,然后再经过一个线性变换得到最终的输出。

本文提出了在softmax操作的前后引入跨注意力头维度的线性变换,从而使每个self-attention函数依赖于所有的key和query。

下面分别是timm中普通Attention和TalkingHeadAttention的实现

# class Attention
def forward(self, x: torch.Tensor) -> torch.Tensor:  # (1,197,192)
    B, N, C = x.shape
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
    # (1,197,576)->(1,197,3,3,64)->(3,1,3,197,64), (3, batch_size, num_heads, seq_len, head_dim), 3表示qkv
    q, k, v = qkv.unbind(0)  # (1,3,197,64)
    q, k = self.q_norm(q), self.k_norm(k)

    if self.fused_attn:  # False
        x = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.attn_drop.p if self.training else 0.,
        )
    else:
        # attn=softmax(qk)
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)  # (1,3,197,197)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = attn @ v  # (1,3,197,64)

    x = x.transpose(1, 2).reshape(B, N, C)  # (1,197,3,64)->(1,197,192)
    x = self.proj(x)  # (1,197,192)
    x = self.proj_drop(x)
    return x

# class TalkingHeadAttn
def forward(self, x):
    B, N, C = x.shape  # (1,196,384)
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)  # (1,196,1152)->(1,196,3,8,48)->(3,1,8,196,48)
    q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]  # (1,8,196,48)

    attn = q @ k.transpose(-2, -1)  # (1,8,196,196)

    attn = self.proj_l(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # (1,196,196,8)->(1,196,196,8)->(1,8,196,196)

    attn = attn.softmax(dim=-1)  # (1,8,196,196)

    attn = self.proj_w(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # (1,196,196,8)->(1,196,8,8)->(1,8,196,196)
    attn = self.attn_drop(attn)

    x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # (1,8,196,48)->(1,196,8,48)->(1,196,384)
    x = self.proj(x)  # (1,196,384)
    x = self.proj_drop(x)
    return x

从下图的对比看的更加清楚,左边是普通的attention,右边是talking-heads attention。左边的输入shape为(1, 197, 192),其中197=196+1是添加了class token,192是特征维度。右边的输入shape为(1, 196, 384),特征维度为384。左边num_heads=3,右边num_heads=8。因为左边的代码来自vision transformer,右边的代码来自CaiT,选择的具体模型variant不同,所以特征维度和head数量也不一样,但不影响。

可以看到,TalkingHeadAttention在计算softmax前后分别引入了一个线性变换self.proj_lself.proj_w,定义分别为self.proj_l = nn.Linear(num_heads, num_heads)self.proj_w = nn.Linear(num_heads, num_heads)。在线性变换前先对输入进行维度变换通过.permute(0, 2, 3 ,1)将num_head维度放到最后,因此线性变换是针对num_head维度的,从而实现跨head的交互,最后再permute回去。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1846873.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

idea插件开发之如何获取用户输入的变量名(类变量,局部变量等)

写在前面 比如我们要开发一个变量名称补全功能的插件,此时就需要在用户输入时获取当前的最新输入内容,本文就来看下如何来做。 1:开发 首先我们需要创建一个CompletionContributor的子类,还需要一个CompletionProvider的子类来…

【文档+源码+调试讲解】国风彩妆网站springboot

摘 要 二十一世纪我们的社会进入了信息时代,信息管理系统的建立,大大提高了人们信息化水平。传统的管理方式对时间、地点的限制太多,而在线管理系统刚好能满足这些需求,在线管理系统突破了传统管理方式的局限性。于是本文针对这一…

基于STM32的智能温室控制系统

目录 引言环境准备智能温室控制系统基础代码实现:实现智能温室控制系统 4.1 温湿度传感器数据采集4.2 光照传感器数据采集4.3 控制系统实现4.4 用户界面与数据可视化应用场景:智能温室管理与优化问题解决方案与优化收尾与总结 1. 引言 智能温室控制系…

百度文库AI产品“橙篇”:支持10万字长文生成,开启AI创作新篇章

6月19日,百度文库发布了一款创新产品「橙篇」,这一行业首创的产品集成了10万字长文生成及多模态编辑能力,成为首个实现「查阅创编」一站式AI自由创作平台的里程碑。 百度“橙篇”官网: 地址:橙篇AI - 用橙篇&#xf…

iOS政策解读之一丨App提交审核前注意事项必知

大家好,我是小编阿文。欢迎您关注我们,经常分享有关Android出海,iOS出海,App市场政策实时更新,互金市场投放策略,最新互金新闻资讯等文章,期待与您共航世界之海。 iOS企业出海所面临的主要挑战…

nodejs从基础到实战学习笔记-模块化、包

二、模块化 2.1 什么是模块化 模块化是指解决一个复杂问题时,自顶向下逐层把系统划分成若干模块的过程。对于整个系统来说,模块是可组合、分解和更换的单元。 2.1.1 把代码进行模块化拆分的好处 提高了代码的复用性提高了代码的可维护性可以实现按需…

“神刊”CA再回巅峰!2024年JCR正式发布,共21848本期刊,附完整版EXCEL版下载!

2024 年 6 月 20 日,科睿唯安(Clarivate Analytics)发布了最新的《期刊引证报告》(Journal Citation Reports,JCR),以下简要介绍最新影响因子(IF)情况: 2023年完整版JCR…

华为手机数据恢复,2个技巧介绍,误删文件后的紧急处理

对于华为手机用户来说,有时候我们会因为误操作或意外情况导致手机数据丢失,这无疑是棘手的。但是别担心,本文将为您推荐一些华为手机数据恢复的实用技巧,帮助您在误删文件后迅速找回丢失的数据,最大程度地减少损失。让…

MS17-010(Eternal blue永恒之蓝)漏洞利用+修复方法

目录 一、漏洞简介 漏洞原理 影响版本 二、漏洞复现 三、复现过程 1、扫描局域网内的C段主机(主机发现) 扫描结果: 2.使用MSF的永恒之蓝漏洞模块 3.对主机进行扫描,查看其是否有永恒之蓝漏洞 4.准备攻击 四、漏洞利用 …

多目标跟踪 距离的可视化(有动图)

多目标跟踪 距离的可视化(有动图) flyfish 马氏距离的计算涉及到协方差矩阵的逆,而协方差矩阵的特征值和特征向量决定了数据分布的形状。椭圆的中心是数据的均值向量,椭圆的形状和方向由协方差矩阵的特征向量和特征值决定。椭圆…

中石化加油卡有什么用?

对于有车一族来说,有一张加油卡真的可以省下不少钱 但是像我们这种没车的人,即使得到加油卡也毫无用武之地 久而久之,难免会造成卡过期的情况出现 还好,前两天把我手上堆积了好久的加油卡在收卡云上卖出去了,99折真…

录的视频太大怎么压缩?这几款软件真的很不错!

在数字化时代,视频已成为我们日常生活和工作中不可或缺的一部分。无论是记录生活点滴,还是制作工作汇报,视频都以其直观、生动的特点赢得了我们的青睐。然而,随着视频质量的提升,视频文件的大小也在不断增加&#xff0…

「51媒体」活动会议,展览展会,直播曝光的一种方法

传媒如春雨,润物细无声,大家好,我是51媒体网胡老师。 我们在做活动会议,或者参加展览展会,需要进行直播的时候,可以通过一键同步多个媒体平台的方法,来扩大曝光,比如一场直播我们可…

【Python/Pytorch 】-- K-means聚类算法

文章目录 文章目录 00 写在前面01 基于Python版本的K-means代码02 X-means方法03 最小二乘法简单理解04 贝叶斯信息准则 00 写在前面 时间演变聚类算法:将时间演变聚类算法用在去噪上,基本思想是,具有相似信号演化的体素具有相似的模型参数…

Redis-事务-基本操作-在执行阶段出错不会回滚

文章目录 1、Redis事务控制命令2、Redis事务错误处理3、Redis事务错误处理,在执行阶段出错不会回滚 1、Redis事务控制命令 127.0.0.1:6379> keys * (empty array) 127.0.0.1:6379> multi OK 127.0.0.1:6379(TX)> set a1 v1 QUEUED 127.0.0.1:6379(TX)>…

Steam怎么购买黄金树之影 购买了黄金树之影怎么下载DLC教程

《艾尔登法环》大型DLC“黄金树幽影”将于6月21日正式上线,为广大玩家带来全新的冒险与挑战。在“黄金树幽影”中,玩家将拥有专属的强化系统。通过收集探索幽影之地获得的“幽影树的碎片”和“灵灰的加护”,不仅可以大幅度提升玩家的攻击力与…

SD卡上的文件删除不了?试试这6种方法!

用户案例 “我需要往32GB的三星Micro SD卡里复制文件,在此之前需要在电脑上删除一些SD卡上的数据来释放空间。但当我尝试按‘Ctrl Delete’删除文件时,文件无法从SD卡上删除。当我尝试格式化SD卡时,Windows提示该磁盘已写保护。这是怎么回事…

数字样机:飞行器状态控制系统仿真

引言:数字样机起源于20世纪90年代,是一种用数字化模型代替实际物理样机进行仿真分析的技术。 传统的飞行器研发流程往往遵循一套特定的循环结构:在设计初期,工程人员需要对飞行器提供一个综合的设计思路(初期蓝图&…

名校介绍|英国六所红砖大学

​近年来由于美国的拒签率增加,很多公派申请者,尤其是CSC资助的访问学者、公派联合培养学生及博士后研究学者,把出国目标改为其它发达国家,尤以英国居多,本文知识人网小编就重点介绍六所英国红砖大学。 我们在“英国大…

1panel + Pbootcms 设置伪静态规则

这里确保我们引用的样式路径是否是这样的&#xff0c;&#xff08;不然可能会设置了伪静态无法加载样式&#xff09; //这种格式在不开起伪静态是可以引入的&#xff0c;一旦开启就不行了,一定要在static 前面加上反斜杠 /<link rel"stylesheet" href"{pbo…