边写代码边学习之TF Attention

news2025/1/30 16:40:01

1. 什么是Attention

注意力机制(Attention Mechanism)是机器学习和人工智能领域中的一个重要概念,用于模拟人类视觉或听觉等感知过程中的关注机制。注意力机制的目标是让模型能够在处理信息时,更加关注与任务相关的部分,忽略与任务无关的信息。这种机制最初是受到人类大脑对信息处理的启发而提出的。

注意力机制的基本原理如下:

  1. 输入信息:首先,注意力机制接收输入信息,这可以是序列数据、图像、语音等。

  2. 查询、键和值:对于每个输入,注意力机制引入了三个部分:查询(query)、键(key)、值(value)。这些部分通常是通过神经网络学习得到的。查询用于表示要关注的内容,键用于表示输入信息中的特征,值则是与每个键相关的信息。

  3. 权重分配:注意力机制根据查询和键之间的关系来计算权重,这些权重决定了每个值在最终输出中的贡献程度。通常使用某种形式的相似度度量(如点积、缩放点积等)来计算权重。

  4. 加权求和:将计算得到的权重与对应的值相乘,然后将它们加权求和,得到最终的输出。这个输出通常包含了模型在处理输入信息时关注的部分。

  5. 重复:上述过程通常会被重复多次,以便模型可以在不同的上下文中动态地调整注意力。

注意力机制的核心思想是让模型能够自动地确定在处理输入信息时要关注哪些部分,从而提高了模型在各种任务中的性能。它在自然语言处理、计算机视觉和语音处理等领域都有广泛的应用,如在机器翻译中的Transformer模型、图像分割中的U-Net模型以及语音识别中的Listen, Attend and Spell(LAS)模型等。

总的来说,注意力机制可以帮助模型更好地理解和利用输入信息,提高了模型的表现和泛化能力。

2. Why Attention

由于LSTM和GRU只在一定程度上改进了循环神经网络的长句子依赖问题,并且信息的记忆能力也不是很强和计算能力有限。如果模型要记住很多信息,不得不设计的更复杂,为了解决这些问题,注意力机制出现了,它即能从大量信息中选择重要的信息来缓解神经网络模型的复杂度,而且能高效的并行运算。注意力机制的计算是一个匹配的过程,即通过一个查询(Query)向量到键(Key)和值(Value)对数据对来映射输出值.

注意力的计算一般有三个阶段。第一阶段是计算查询向量Q和每个输入的K的相关性或相似度,得到注意力权重系数S_i :

S_i=f(Q,K_i)

第二阶段是使用SoftMax函数对第一阶段得出的权重系数进行尺度缩放,即把它归一化为概率分布 ai ,分子是把神经元的当前输出映射到(0,+∞),分母是所有输出结果值的总和,公式如下:

a _i=softmax (S_i ) = e^{S_i }/(\sum e^{S_j})

第三阶段:将第二阶段得出的权重与value值加权求和,得到最终需要的Attention数值:

Attention(Q,K,V)=\sum a_i V_i

3. TF attention api 介绍

Attention class

tf.keras.layers.Attention(use_scale=False, score_mode="dot", **kwargs)

Dot-product attention layer, a.k.a. Luong-style attention.

Inputs are query tensor of shape [batch_size, Tq, dim]value tensor of shape [batch_size, Tv, dim] and key tensor of shape [batch_size, Tv, dim]. The calculation follows the steps:

  1. Calculate scores with shape [batch_size, Tq, Tv] as a query-key dot product: scores = tf.matmul(query, key, transpose_b=True).
  2. Use scores to calculate a distribution with shape [batch_size, Tq, Tv]distribution = tf.nn.softmax(scores).
  3. Use distribution to create a linear combination of value with shape [batch_size, Tq, dim]return tf.matmul(distribution, value).

4. 实验代码

4.1.  验证并理解TF attention方法,只输入query和value矩阵。

def softmax(t):
    s_value = np.exp(t) / np.sum(np.exp(t), axis=-1, keepdims=True)
    # print('softmax value: ', s_value)
    return s_value

def numpy_attention(inputs,
        mask=None,
        training=None,
        return_attention_scores=False,
        use_causal_mask=False):

    query = inputs[0]
    value = inputs[1]
    key = inputs[2] if len(inputs) > 2 else value

    score = np.matmul(query, key.transpose())
    attention_score_np = softmax(score)
    result = np.matmul(attention_score_np, value)
    print('attention score in numpy =', attention_score_np)
    print('result in numpy = ', result)


def verify_logic_in_attention_with_query_value():
    query_data = np.array(
        [[1, 0.0, 1],[2, 3, 1]]
    )
    value_data = np.array(
        [[2, 1.0, 1],[1, 4, 2 ]]
    )
    print(query_data.shape)

    numpy_attention([query_data, value_data], return_attention_scores=True)
    print("=============following is keras attention output================")

    attention_layer= tf.keras.layers.Attention()

    result, attention_scores = attention_layer([query_data, value_data], return_attention_scores=True)

    print('attention_scores = ', attention_scores)
    print('result=', result);
if __name__ == '__main__':
    verify_logic_in_attention_with_query_value()

运行结果

(2, 3)
attention score in numpy = [[5.0000000e-01 5.0000000e-01]
 [3.3535013e-04 9.9966465e-01]]
result in numpy =  [[1.5        2.5        1.5       ]
 [1.00033535 3.99899395 1.99966465]]
=============following is keras attention output================
attention_scores =  tf.Tensor(
[[5.0000000e-01 5.0000000e-01]
 [3.3535014e-04 9.9966466e-01]], shape=(2, 2), dtype=float32)
result= tf.Tensor(
[[1.5       2.5       1.5      ]
 [1.0003353 3.998994  1.9996647]], shape=(2, 3), dtype=float32)

4.2.  验证并理解TF attention方法,输入query, key, value矩阵。

def verify_logic_in_attention_with_query_key_value():
    query_data = np.array(
        [[1, 0.0, 1],[2, 3, 1]]
    )
    value_data = np.array(
        [[2, 1.0, 1],[1, 4, 2 ]]
    )
    key_data = np.array(
        [[1, 2.0, 2], [3, 1, 0.1]]
    )
    print(query_data.shape)

    numpy_attention([query_data, value_data, key_data], return_attention_scores=True)
    print("=============following is keras attention output================")

    attention_layer= tf.keras.layers.Attention()

    result, attention_scores = attention_layer([query_data, value_data, key_data], return_attention_scores=True)

    print(attention_layer.get_weights())
    print('attention_scores = ', attention_scores)
    print('result=', result);
if __name__ == '__main__':
    verify_logic_in_attention_with_query_key_value()

结果

(2, 3)
attention score in numpy = [[0.47502081 0.52497919]
 [0.7109495  0.2890505 ]]
result in numpy =  [[1.47502081 2.57493756 1.52497919]
 [1.7109495  1.86715149 1.2890505 ]]
=============following is keras attention output================
[]
attention_scores =  tf.Tensor(
[[0.47502086 0.52497923]
 [0.7109495  0.28905058]], shape=(2, 2), dtype=float32)
result= tf.Tensor(
[[1.4750209 2.5749378 1.5249794]
 [1.7109495 1.8671517 1.2890506]], shape=(2, 3), dtype=float32)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/965098.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

TDengine(2):wsl2+ubuntu20.04+TDengine安装

一、ubuntu系统下提供了三种安装TDengine的方式: 二、通过 apt 指令安装失败 因为是linux初学者,对apt 指令较为熟悉,因此首先使用了该方式进行安装。 wget -qO - http://repos.taosdata.com/tdengine.key | sudo apt-key add -echo "…

windows使用-设置windows的远程访问用户数量

文章目录 前言相关操作总结前言 作为IT工程师,使用服务器做相应的软件操作时常有的事。最近一段时间,我们的团队多个成员都需要远程登录到一台windows2003Server的服务器处理相应的业务。而默认情况下,Windows系统只允许一名用户远程到服务器上,这给小伙伴的工作造成一些不…

React-native环境配置与项目搭建

基础环境搭建 安装 node.js (版本>12 ,推荐安装LTS稳定版本) 安装 Yarn (npm install -g yarn) 安装 react native 脚手架 (npm install -g react-native-cli) windows 只能搭建Android 开发环境 Mac 下既能搭建Android 环境&…

斥资4亿,收购这家WLAN厂商,结果……

晚上好,我的网工朋友 不少朋友可能有隐形,2019年的时候,Juniper花费4.05亿美元,收购WiFi初创公司Mist Systems。 Mist Systems是一家买无线产品起家的公司,由前思科高管创建的。主打的产品是“AI-Driven WLAN”&…

linux安装firefox

1.下载对应包 https://www.mozilla.org/en-US/firefox/all/#product-desktop-release 2. 挂载桌面链接(如果/usr/bin/firefox下有的话,先删除) ln -s /opt/firefox/firefox /usr/bin/firefox 3.执行以下命令,即可启动Firefox客户端: firefox

TCP协议报文

前言 TCP/IP协议簇——打开虚拟世界大门中,已经给大家大致介绍了TCP/IP协议簇的分层。 TCP (Transmission Control Protocol)传输控制协议,在TCP/IP协议簇中,处于传输层。是为了在不可靠的互联网络(IP协议)中&#x…

LangChain学习笔记;给老师的ChatGPT使用指南;中国大模型顶级闭门会交流笔记;飞桨开源任务挑战大赛 | ShowMeAI日报

👀日报&周刊合集 | 🎡生产力工具与行业应用大全 | 🧡 点赞关注评论拜托啦! 🤖 飞桨PaddlePaddle开源任务挑战大赛,首届「开放原子开源大赛」等你参与 官网:https://competition.atomgit.com…

redis未授权访问

文章目录 搭建环境漏洞复现安装Exlopit并使用 前提条件: 1.安装docker docker pull medicean/vulapps:j_joomla_22.安装docker-compose docker run -d -p 8000:80 medicean/vulapps:j_joomla_23.下载vulhub 搭建环境 输入下面命令,来到Redis的路径下&am…

基于Open3D的点云处理16-特征点匹配

点云配准 将点云数据统一到一个世界坐标系的过程称之为点云配准或者点云拼接。(registration/align) 点云配准的过程其实就是找到同名点对;即找到在点云中处在真实世界同一位置的点。 常见的点云配准算法: ICP、Color ICP、Trimed-ICP 算法…

深入探讨梯度下降:优化机器学习的关键步骤(一)

文章目录 🍀引言🍀什么是梯度下降?🍀损失函数🍀梯度(gradient)🍀梯度下降的工作原理🍀梯度下降的变种🍀随机梯度下降(SGD)🍀批量梯度下降&#xf…

添加YDNS免费的ipv6动态域名解析

背景 又到了一年一度的dns域名到期,寻找替代了,前几年用了阿里、华为的免费域名,支持了几个搭建在NAS上的微服务;一旦涉及到域名续费,价格就比首年上去了不少,所以,打算找个长期的免费域名。 搜…

在Windows 10上部署ChatGLM2-6B:掌握信息时代的智能对话

在Windows 10上部署ChatGLM2-6B:掌握信息时代的智能对话 硬件环境ChatGLM2-6B的量化模型最低GPU配置说明准备工作ChatGLM2-6B安装部署ChatGLM2-6B运行模式解决问题总结 随着当代科技的快速发展,我们进入了一个数字化时代,其中信息以前所未有的…

python数据分析基础—pandas中set_index()、reset_index()的使用

文章目录 一、索引是什么?二、set_index()三、reset_index() 一、索引是什么? 在进行数据分析时,通常我们要根据业务情况进行数据筛选,要求筛选特定情况的行或列,这时就要根据数据类型(Series或者DataFrame)的索引情况…

小苹果他爹V5.8版本最强小苹果影视盒子增加46条内置优质单仓线路

这款软件直接使用了俊版的小苹果接口,并且许多资源似乎都是直接调用的小苹果官方资源。这样一来,小苹果的作者可能会面临版权方面的问题,而且也让更多的用户对小苹果的收费模式产生质疑。在这个信息传播如此快速的时代,开发者们应…

816. 模糊坐标

816. 模糊坐标 原题链接:完成情况:解题思路:参考代码:错误经验吸取 原题链接: 模糊坐标 完成情况: 解题思路: 参考代码: package 西湖算法题解___中等题;import java.util.Arra…

公司文件防泄密系统——「天锐绿盾透明加密系统」

「天锐绿盾透明加密系统」是一种公司文件防泄密系统,从源头上保障数据安全和使用安全。该系统采用文件过滤驱动实现透明加解密,对用户完全透明,不影响用户操作习惯。 PC访问地址: isite.baidu.com/site/wjz012xr/2eae091d-1b97-4…

贝叶斯神经网络 - 捕捉现实世界的不确定性

贝叶斯神经网络 - 捕捉现实世界的不确定性 Bayesian Neural Networks 生活本质上是不确定性和概率性的,贝叶斯神经网络 (BNN) 旨在捕获和量化这种不确定性 在许多现实世界的应用中,仅仅做出预测是不够的;您还想知道您对该预测的信心有多大。例…

ARM Cortex-M 的 SP

文章目录 1、栈2、栈操作3、Cortex-M中的栈4、MDK中的SP操作流程5、Micro-Lib的SP差别1. 使用 Micro-Lib2. 未使用 Micro-Lib 在嵌入式开发中,堆栈是一个很基础,同时也是非常重要的名词,堆栈可分为堆 (Heap) 和栈 (Stack) 。 栈(Stack): 一种…

2010-2021年上市公司和讯网社会责任评级CSR数据/和讯网上市公司社会责任数据

2010-2021年上市公司和讯网社会责任评级CSR数据 1、时间:2010-2021年 2、指标:股票名称、股票代码、年份、总得分、等级、股东责任、员工责任、供应商、客户和消费者权益责任、环境责任、社会责任、所属年份 3、样本量:4万 4、来源&#…

网工内推 | 上市公司,IT工程师、服务器工程师,IP以上优先

01 烟台睿创微纳技术股份有限公司 招聘岗位:IT工程师 职责描述: 1、负责网络及安全架构的规划、设计、性能优化; 2、负责网络设备的安装、配置、管理、排错、维护,提供网络设备维护方案; 3、负责防火墙、上网行为管理…