注意力机制简介

news2024/11/17 16:22:35

为了减少计算复杂度,通过借鉴生物神经网络的一些机制,我们引入了局部连接、权重共享以及汇聚操作来简化神经网络结构。神经网络中可以存储的信息量称为网络容量。一般来讲,利用一组神经元来存储信息的容量和神经元的数量以及网络的复杂度成正比。如果要存储越多的信息,神经元数量就要越多或者网络要越复杂,进而导致神经网络的参数成倍的增加。我们人脑的生物神经网络同样存在着容量问题,人脑中的工作记忆大概只有几秒钟的时间,类似于循环神经网络中的隐状态。人脑在有限的资源下,并不能同时处理这些过载的信息。大脑就是通过两个重要机制:注意力和记忆机制,来解决信息过载问题的。

基于显著性注意力

聚焦式注意力

注意力一般分为两种:一种是自上而下的有意识的注意力,称为聚焦式(focus)注意力,如上图,我们如果主观上要去看一本书,那在当下的场景中,可能会直接去搜寻书籍;另一种是自下而上的无意识的注意力,称为基于显著性的注意力,如上图,如果我们没有任何目的,随意的一看,可能最容易吸引我们注意的就是红色的茶杯。

在神经网络中,我们可以把最大池化(max pooling),门控等机制看作是自下而上的基于显著性的注意力机制,因为这些操作并没有主动去搜索信息。而我们这里讨论的注意力机制可以看作是一种自上而下的聚焦式注意力机制。用X=[x1,x2,...,xn]表示n个输入信息,为了节省计算资源,不需要将所有的n个输入信息都输入到神经网络进行计算,只需要从X中选择一些和任务相关的信息输入给神经网络。注意力机制的计算可以分为两步:一是在所有输入信息上计算注意力分布,二是根据注意力分布来计算输入信息的加权平均。

注意力机制其实可以理解为求解相似度。这点网上有一个视频讲解的很好。

假设现在有一个从腰围到体重的映射,我们成腰围为key,体重为value。对应效果如下:

key:51——》value:40

key:56——》value:43

key:58——》value:48

那么,如果现在有一个query=57,value该怎么求?

一个最自然的想法就是,57是56和58的平均数,所以对应的value也应该是43和48的平均数,f(q)=(v2+v3)/2,这里因为57距离56和58非常近,我们会非常“注意”它们,所以我们分给56和58的注意力权重为0.5,相当于我们假设这里面存在一个映射,也就是函数f(q),假设用α(q,ki)来表示q与对应ki的注意力权重,那么value的预测值展开来就是f(q)=α(q,k1)v1 + α(q,k2)v2 + α(q,k3)v3,这里我们认为α(q,k1)=0,因为距离较远,所以我们没有考虑,而α(q,k2)和α(q,k3)都为0.5,这就得到了我的结果。

不过这种算法没有考虑到其它数据可能带来的影响,实际的情况可能远比求平均数复杂。所以,更一般的,我们应该来计算注意力权重α(q,ki),关键是怎么计算。一般来说,我们会设置一个注意力打分函数,然后对注意力打分函数的结果来进行softmax操作,得到注意力权重,用公式表示就是:

这里面的a(q,ki)就是注意力打分函数,一般有加性模型、点积模型、缩放点击模型和双线性模型,他们的公式如下:

实际应用中,我们一般选择缩放点击模型来作为打分函数。其中W,U,v为可学习的参数,d为输入信息的维度。注意力分布αi可以解释为在给定任务相关的查询q时,第i个信息受关注的程度。

如果q,k,v是多维的也是一样的,我们可以用矩阵来表示,并以Q,K,V来命名查询和键值对,采用缩放点积模型,公式如下:

当Q,K,V是同一个矩阵时,这就是自注意力机制了,Q=K=V=X,写成公式就是:

这里面有三个可以训练的参数矩阵WQ,WK,WV,写成公式就是:

这就是自注意力机制的公式。

李沐的《动手学深度学习》中有个例子比较好的说明了注意力机制作用,就是Nadaraya-Waston回归,下面我们也来实现一下:

给定成对的输入输出数据集{(x1,y1),...,(xn,yn)},学习一个函数f来预测任意新输入x的输出y_hat = f(x)

假设真实函数是:

我们给数据加一个随机扰动,生成训练数据:

n_train = 50 # 训练样本数
x_train, _ = torch.sort(torch.rand(n_train)*5) # 排序后的训练样本,生成50个[0,1)的数据再乘以5,等于生成50个[0,5)之间的数据

# 定义一个真实函数
def f(x):
    return 2*torch.sin(x)+x**0.8
    
y_train = f(x_train) + torch.normal(0.0,0.5,(n_train,)) # 训练样本的输出,加上了一个扰动
x_test = torch.arange(0,5,0.1) # 测试样本
y_truth = f(x_test) # 测试样本的真实输出

定义一个画图函数:

def plot_kernel_reg(y_hat):
    plt.plot(x_test, y_truth)
    plt.plot(x_test, y_hat)
    plt.xlabel('x')
    plt.ylabel('y')
    plt.legend(['Truth', 'Pred'])
    plt.plot(x_train, y_train, 'o', alpha=0.5)
    plt.show()

1.直接取训练数据的平均值来预测:

y_hat = torch.repeat_interleave(y_train.mean(), n_test)
plot_kernel_reg(y_hat)
plt.show()

结果肯定是不理想的,就是一条直线。

2.非参数注意力:

非参数注意力其实就是不需要训练的注意力计算,直接训练数据和测试数据直接的差距来计算注意力。

#每⼀⾏都包含着相同的测试输⼊(例如:同样的查询)
x_repeat =x_test.repeat_interleave(n_train).reshape((-1,n_train))
print(x_repeat)
print(x_repeat.shape)

# 输出
tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.1000, 0.1000, 0.1000,  ..., 0.1000, 0.1000, 0.1000],
        [0.2000, 0.2000, 0.2000,  ..., 0.2000, 0.2000, 0.2000],
        ...,
        [4.7000, 4.7000, 4.7000,  ..., 4.7000, 4.7000, 4.7000],
        [4.8000, 4.8000, 4.8000,  ..., 4.8000, 4.8000, 4.8000],
        [4.9000, 4.9000, 4.9000,  ..., 4.9000, 4.9000, 4.9000]])
torch.Size([50, 50])

x_repeat就是把x_test复制了50份,从一个50个数据的1维数组变成了50X50的矩阵。把x_repeat看成是query,x_train看成是key,y_train看出是value,计算注意力权重,根据注意力权重来计算输出。

#x_train包含着键。attention_weights的形状:(n_test,n_train),
#每⼀⾏都包含着要在给定的每个查询的值(y_train)之间分配的注意⼒权重
attention_weights = nn.functional.softmax(-(x_repeat-x_train)**2 / 2, dim=1)
#y_hat的每个元素都是值的加权平均值,其中的权重是注意力权重
y_hat = torch.matmul(attention_weights, y_train)
plot_kernel_reg(y_hat)

3.带参数的注意力机制

这就是正常的注意力机制了。

# X_tile的形状:(n_train,n_train),每⼀⾏都包含着相同的训练输⼊
x_tile = x_train.repeat((n_train,1))
# Y_tile的形状:(n_train,n_train),每⼀⾏都包含着相同的训练输出
y_tile = y_train.repeat((n_train,1))
# keys的形状:('n_train','n_train'-1)
keys = x_tile[(1-torch.eye(n_train)).type(torch.bool)].reshape((n_train,-1))
# values的形状:('n_train','n_train'-1)
values = y_tile[(1-torch.eye(n_train)).type(torch.bool)].reshape((n_train,-1))

这里有个小技巧,用来生成keys和values。

print((1-torch.eye(n_train)).type(torch.bool).shape) # 50X50的矩阵,包含True和False,每一行有49个True和1个False
x_tile[(1-torch.eye(n_train)).type(torch.bool)].shape # 用这个值作为索引后,就得到50行49列的数据,因为每一行都有一个值是False

# 输出:
torch.Size([50, 50])
torch.Size([2450])

这里torch.eye(n_train)生成的是50X50的单位矩阵,1-torch.eye(n_train)生成的是50X50的矩阵:

tensor([[0., 1., 1.,  ..., 1., 1., 1.],
        [1., 0., 1.,  ..., 1., 1., 1.],
        [1., 1., 0.,  ..., 1., 1., 1.],
        ...,
        [1., 1., 1.,  ..., 0., 1., 1.],
        [1., 1., 1.,  ..., 1., 0., 1.],
        [1., 1., 1.,  ..., 1., 1., 0.]])

(1-torch.eye(n_train)).type(torch.bool)生成的是50X50的布尔矩阵:

tensor([[False,  True,  True,  ...,  True,  True,  True],
        [ True, False,  True,  ...,  True,  True,  True],
        [ True,  True, False,  ...,  True,  True,  True],
        ...,
        [ True,  True,  True,  ..., False,  True,  True],
        [ True,  True,  True,  ...,  True, False,  True],
        [ True,  True,  True,  ...,  True,  True, False]])

x_tile[(1-torch.eye(n_train)).type(torch.bool)],代表的是根据bool矩阵取出tensor中对应位置元素,如果是False的元素就不取,取出对应位置为True的元素。

所以keys和values的形状就是[50,49]。下面定义Nadaraya-Waston模型:

class NWKernelRegression(nn.Module):
    def __init__(self,**kwargs):
        super().__init__(**kwargs)
        self.w = nn.Parameter(torch.rand((1,), requires_grad=True)) # 初始化权重参数
        
    def forward(self, queries, keys, values):
        # queries和attenion_weights的形状为(查询个数,“键值对”个数)
        queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))
        self.attention_weights = nn.functional.softmax(-((queries-keys)*self.w)**2/2,dim=1)
        return torch.bmm(self.attention_weights.unsqueeze(1), values.unsqueeze(-1)).reshape(-1)

模型训练:

net = NWKernelRegression()
loss = nn.MSELoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=0.5)

for epoch in range(5):
    trainer.zero_grad()
    l = loss(net(x_train, keys, values), y_train)
    l.sum().backward()
    trainer.step()
    print(f'epoch {epoch+1}, loss {float(l.sum()):.6f}')
# 输出:
epoch 1, loss 27.291159
epoch 2, loss 6.903591
epoch 3, loss 6.815420
epoch 4, loss 6.729923
epoch 5, loss 6.646971

可以看到,预测结果已经越来越接近真实数据了,只是曲线还不是很平滑。其实我还尝试了一下用缩放点积模型来构建模型,但是效果并不好,下一篇来看一下多头注意力和自注意力机制。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1838961.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

重生奇迹MU 探秘奇幻世界

"探秘奇幻世界,成就无尽荣耀!欢迎来到重生奇迹MU,一个永不落幕的游戏乐园。在这里,你可以尽情挑战各种困难,发掘神秘宝藏,还可与来自世界各地的玩家一起创造无尽的历史。为了帮助你更好地探索游戏世界…

用Vue3和ApexCharts打造交互式3D折线图

本文由ScriptEcho平台提供技术支持 项目地址:传送门 Vue.js 中使用 ApexCharts 构建交互式折线图 应用场景 ApexCharts 是一个功能强大的 JavaScript 库,用于创建交互式、可定制的图表。在 Vue.js 中,它可以通过 vue3-apexcharts 插件轻松…

IPython大揭秘:神奇技巧让你掌握无敌编程力量!

IPython技巧 基础技巧文件操作技巧输入输出技巧魔术命令技巧调试技巧程序性能优化技巧输入输出重定向技巧魔术命令控制技巧自定义显示格式技巧多线程多进程技巧异常处理技巧数据可视化技巧自定义魔术命令技巧安装扩展包技巧Jupyter Notebook集成技巧文档显示技巧代码块执行技巧…

2024 年 Python 基于 Kimi 智能助手 Moonshot Ai 模型搭建微信机器人(更新中)

注册 Kimi 开放平台 Kimi:https://www.moonshot.cn/ Kimi智能助手是北京月之暗面科技有限公司(Moonshot AI)于2023年10月9日推出的一款人工智能助手,主要为用户提供高效、便捷的信息服务。它具备多项强大功能,包括多…

【Orange Pi 5与Linux内核编程】-理解Linux内核中的container_of宏

理解Linux内核中的container_of宏 文章目录 理解Linux内核中的container_of宏1、了解C语言中的struct内存表示2、Linux内核的container_of宏实现理解3、Linux内核的container_of使用 Linux 内核包含一个名为 container_of 的非常有用的宏。本文介绍了解 Linux 内核中的 contain…

【软件工程】【22.10】p2

关键字: 软件开发基本途径、初始需求发现技术、UML表达事物之间关系、RUP需求获取基本步骤、项目过程建立涉及工作、项目规划过程域的意图和专用目标 判定表、分支覆盖、条件覆盖 三、简答 四、应用 这里条件覆盖有待商榷

ultralytics 8.2.35增加YOLOv9t/s/m模型全过程

yolov9的小模型开源也有两周左右了,ultralytics两天前新版本已经可以支持使用了。 过一段时间,Yolov10估计也快了。 yolov9的作者代码有一些部分本身就是从yolov5里“借鉴”而来,性能提高没提高见仁见智吧。 yolov10的nms free方式倒是比较…

有了MES、ERP,质量管理为什么还需要QMS?

在制造业,质量管理始终是企业管理中永恒的主题。品质管理要想做得更好,企业必须掌握足够多、足够有用的数据和信息,实现质量管理信息化。很多中小企业也很困惑,是否有必要上线QMS质量管理系统? 一、为什么企业需要QMS的…

C语言实现五子棋教程

Hi~!这里是奋斗的小羊,很荣幸您能阅读我的文章,诚请评论指点,欢迎欢迎 ~~ 💥💥个人主页:奋斗的小羊 💥💥所属专栏:C语言 🚀本系列文章为个人学习…

【Redis技术进阶之路】「底层源码解析」揭秘高效存储模型与数据结构底层实现(链表)

揭秘高效存储模型与数据结构底层实现 【专栏简介】【技术大纲】【专栏目标】【目标人群】1. Redis爱好者与社区成员2. 后端开发和系统架构师3. 计算机专业的本科生及研究生 链表使用场景List(列表)和 链表的关系链表的实现链表的节点list的源码实现结构模…

申万宏源:消费税改或是财政改革第一枪

消费税征收环节后移可能带来年化千亿的税收收入增长,地方财政压力的缓和程度取决于中央确定保留的消费税基数。申万宏源认为,财政改革不仅仅只涉及消费税和央地分配,而稳定扩大需求才是下一步改革核心。 主要内容 财政现实呼唤改革。紧迫性…

SmartEDA革新电路设计:告别繁琐,轻松步入智能时代!

在数字化浪潮席卷而来的今天,电路设计的复杂性和繁琐性一直是工程师们面临的难题。然而,随着科技的进步,一款名为SmartEDA的电路设计工具应运而生,它以智能化、高效化的特点,彻底颠覆了传统电路设计的方式,…

9.华为交换机telnet远程管理配置aaa认证

目的:telnet远程管理设备 LSW1配置 [Huawei]int Vlanif 1 [Huawei-Vlanif1]ip add 1.1.1.1 24 [Huawei-Vlanif1]q [Huawei]user-interface vty 0 4 [Huawei-ui-vty0-4]authentication-mode aaa [Huawei-ui-vty0-4]q [Huawei]aaa [Huawei-aaa]local-user admin pass…

视频智能分析平台智能边缘分析一体机视频监控业务平台区域人数不足检测算法

智能边缘分析一体机区域人数不足检测算法是一种集成了先进图像处理、目标检测、跟踪和计数等功能的算法,专门用于实时监测和统计指定区域内的人数,并在人数不足时发出警报。以下是对该算法的详细介绍: 一、算法概述 智能边缘分析一体机区域…

编写C语言程序解决多个数学问题及修正斐波那契数列递归函数

目录 请按下列要求编写程序:(三个函数均在一个C语言源程序) 有一个四位整数,它的9倍恰好是其反序数(反序数例:1234与4321互为反序数)。 有3个非零十进制数字,用它们可以组合出6个不同的三位数&#xff0…

Python8 使用结巴(jieba)分词并展示词云

Python的结巴(jieba)库是一个中文分词工具,主要用于对中文文本进行分词处理。它可以将输入的中文文本切分成一个个独立的词语,为后续的文本处理、分析、挖掘等任务提供基础支持。结巴库具有以下功能和特点: 中文分词&a…

【原创】springboot+mysql小区用水监控管理系统设计与实现

个人主页:程序猿小小杨 个人简介:从事开发多年,Java、Php、Python、前端开发均有涉猎 博客内容:Java项目实战、项目演示、技术分享 文末有作者名片,希望和大家一起共同进步,你只管努力,剩下的交…

视频怎么旋转方向?3种旋转视频方法分享

视频怎么旋转方向?视频旋转方向,在视频编辑过程中,扮演着至关重要的角色。这一操作不仅能有效调整视频的视觉呈现,使之更加符合我们的预期,还能解决由于拍摄角度不当导致的画面倾斜问题。通过简单的旋转调整&#xff0…

从网络配置文件中提取PEAP凭据

我的一位同事最近遇到了这样一种情况:他可以物理访问使用802.1X连接到有线网络的Windows计算机,同时保存了用于身份验证的用户凭据,随后他想提取这些凭据,您可能认为这没什么特别的,但是事情却有点崎岖波折…… 如何开…

shell脚本监控docker容器和supervisor 运行情况

1.ASR服务 需求: 在ASR服务器中 docker 以下操作中 忽略容器名字叫 nls-cloud-mongodb 的容器 在ASR服务器中 docker ps 查看正在运行的容器 docker stats -a --no-stream 可以监控容器所占资源 确认是否有pid且不等于0 docker inspect -f “{{.RestartCount}}” 容器名称 可…