注意力机制(代码实现案例)

news2024/12/26 11:10:50

学习目标

  • 了解什么是注意力计算规则以及常见的计算规则.
  • 了解什么是注意力机制及其作用.
  • 掌握注意力机制的实现步骤.

1 注意力机制介绍

1.1 注意力概念

  • 我们观察事物时,之所以能够快速判断一种事物(当然允许判断是错误的), 是因为我们大脑能够很快把注意力放在事物最具有辨识度的部分从而作出判断,而并非是从头到尾的观察一遍事物后,才能有判断结果. 正是基于这样的理论,就产生了注意力机制.

1.2 注意力计算规则

  • 它需要三个指定的输入Q(query), K(key), V(value), 然后通过计算公式得到注意力的结果, 这个结果代表query在key和value作用下的注意力表示. 当输入的Q=K=V时, 称作自注意力计算规则.

1.3 常见的注意力计算规则

  • bmm运算演示:

# 如果参数1形状是(b × n × m), 参数2形状是(b × m × p), 则输出为(b × n × p)
>>> input = torch.randn(10, 3, 4)
>>> mat2 = torch.randn(10, 4, 5)
>>> res = torch.bmm(input, mat2)
>>> res.size()
torch.Size([10, 3, 5])

2 什么是注意力机制

  • 注意力机制是注意力计算规则能够应用的深度学习网络的载体, 同时包括一些必要的全连接层以及相关张量处理, 使其与应用网络融为一体. 使用自注意力计算规则的注意力机制称为自注意力机制.
  • 说明: NLP领域中, 当前的注意力机制大多数应用于seq2seq架构, 即编码器和解码器模型.

3 注意力机制的作用

  • 在解码器端的注意力机制: 能够根据模型目标有效的聚焦编码器的输出结果, 当其作为解码器的输入时提升效果. 改善以往编码器输出是单一定长张量, 无法存储过多信息的情况.
  • 在编码器端的注意力机制: 主要解决表征问题, 相当于特征提取过程, 得到输入的注意力表示. 一般使用自注意力(self-attention).

注意力机制在网络中实现的图形表示:

4 注意力机制实现步骤

4.1 步骤

  • 第一步: 根据注意力计算规则, 对Q,K,V进行相应的计算.
  • 第二步: 根据第一步采用的计算方法, 如果是拼接方法,则需要将Q与第二步的计算结果再进行拼接, 如果是转置点积, 一般是自注意力, Q与V相同, 则不需要进行与Q的拼接.
  • 第三步: 最后为了使整个attention机制按照指定尺寸输出, 使用线性层作用在第二步的结果上做一个线性变换, 得到最终对Q的注意力表示.

4.2 代码实现

  • 常见注意力机制的代码分析:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    class Attn(nn.Module):
        def __init__(self, query_size, key_size, value_size1, value_size2, output_size):
            """初始化函数中的参数有5个, query_size代表query的最后一维大小
               key_size代表key的最后一维大小, value_size1代表value的导数第二维大小, 
               value = (1, value_size1, value_size2)
               value_size2代表value的倒数第一维大小, output_size输出的最后一维大小"""
            super(Attn, self).__init__()
            # 将以下参数传入类中
            self.query_size = query_size
            self.key_size = key_size
            self.value_size1 = value_size1
            self.value_size2 = value_size2
            self.output_size = output_size
    
            # 初始化注意力机制实现第一步中需要的线性层.
            self.attn = nn.Linear(self.query_size + self.key_size, value_size1)
    
            # 初始化注意力机制实现第三步中需要的线性层.
            self.attn_combine = nn.Linear(self.query_size + value_size2, output_size)
    
    
        def forward(self, Q, K, V):
            """forward函数的输入参数有三个, 分别是Q, K, V, 根据模型训练常识, 输入给Attion机制的
               张量一般情况都是三维张量, 因此这里也假设Q, K, V都是三维张量"""
    
            # 第一步, 按照计算规则进行计算, 
            # 我们采用常见的第一种计算规则
            # 将Q,K进行纵轴拼接, 做一次线性变化, 最后使用softmax处理获得结果
            attn_weights = F.softmax(
                self.attn(torch.cat((Q[0], K[0]), 1)), dim=1)
    
            # 然后进行第一步的后半部分, 将得到的权重矩阵与V做矩阵乘法计算, 
            # 当二者都是三维张量且第一维代表为batch条数时, 则做bmm运算
            attn_applied = torch.bmm(attn_weights.unsqueeze(0), V)
    
            # 之后进行第二步, 通过取[0]是用来降维, 根据第一步采用的计算方法, 
            # 需要将Q与第一步的计算结果再进行拼接
            output = torch.cat((Q[0], attn_applied[0]), 1)
    
            # 最后是第三步, 使用线性层作用在第三步的结果上做一个线性变换并扩展维度,得到输出
            # 因为要保证输出也是三维张量, 因此使用unsqueeze(0)扩展维度
            output = self.attn_combine(output).unsqueeze(0)
            return output, attn_weights
    

  • 调用:
  • query_size = 32
    key_size = 32
    value_size1 = 32
    value_size2 = 64
    output_size = 64
    attn = Attn(query_size, key_size, value_size1, value_size2, output_size)
    Q = torch.randn(1,1,32)
    K = torch.randn(1,1,32)
    V = torch.randn(1,32,64)
    out = attn(Q, K ,V)
    print(out[0])
    print(out[1])

  • 输出效果:
    tensor([[[ 0.4477, -0.0500, -0.2277, -0.3168, -0.4096, -0.5982,  0.1548,
              -0.0771, -0.0951,  0.1833,  0.3128,  0.1260,  0.4420,  0.0495,
              -0.7774, -0.0995,  0.2629,  0.4957,  1.0922,  0.1428,  0.3024,
              -0.2646, -0.0265,  0.0632,  0.3951,  0.1583,  0.1130,  0.5500,
              -0.1887, -0.2816, -0.3800, -0.5741,  0.1342,  0.0244, -0.2217,
               0.1544,  0.1865, -0.2019,  0.4090, -0.4762,  0.3677, -0.2553,
              -0.5199,  0.2290, -0.4407,  0.0663, -0.0182, -0.2168,  0.0913,
              -0.2340,  0.1924, -0.3687,  0.1508,  0.3618, -0.0113,  0.2864,
              -0.1929, -0.6821,  0.0951,  0.1335,  0.3560, -0.3215,  0.6461,
               0.1532]]], grad_fn=<UnsqueezeBackward0>)
    
    
    tensor([[0.0395, 0.0342, 0.0200, 0.0471, 0.0177, 0.0209, 0.0244, 0.0465, 0.0346,
             0.0378, 0.0282, 0.0214, 0.0135, 0.0419, 0.0926, 0.0123, 0.0177, 0.0187,
             0.0166, 0.0225, 0.0234, 0.0284, 0.0151, 0.0239, 0.0132, 0.0439, 0.0507,
             0.0419, 0.0352, 0.0392, 0.0546, 0.0224]], grad_fn=<SoftmaxBackward>)
    

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1490849.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

STM32(16)使用串口向电脑发送数据

发送字节 发送数组 发送字符和字符串 字符&#xff1a; 字符串&#xff1a; 字符串在电脑中以字符数组的形式存储

最新AI系统ChatGPT网站H5系统源码,支持Midjourney绘画

一、前言 SparkAi创作系统是基于ChatGPT进行开发的Ai智能问答系统和Midjourney绘画系统&#xff0c;支持OpenAI-GPT全模型国内AI全模型。本期针对源码系统整体测试下来非常完美&#xff0c;那么如何搭建部署AI创作ChatGPT&#xff1f;小编这里写一个详细图文教程吧。已支持GPT…

《互联网的世界》第四讲-拥塞控制与编码

需要澄清的一个误区是&#xff0c;拥塞绝不是发送的数据量太大导致&#xff0c;而是数据在极短的时间段内到达了同一个地方以至于超过了网络处理容量导致&#xff0c;拥塞的成因一定要考虑时间因素。换句话说&#xff0c;拥塞由大突发导致。 只要 pacing&#xff0c;再多的数据…

MongoDB开启事务

MongoDB开启事务 配置单节点。到路径C:\Program Files\MongoDB\Server\4.0\bin 使用记事本以管理员权限打开文件mongod.cfg添加如下配置&#xff1a; replication:replSetName: rs02. 重启MongoDB服务 3. 重启后执行命令 rs.initiate()

Python采集学习笔记-request的get请求和post请求

使用http://httpbin.org测试,一个简单的 HTTP 请求和响应服务。(需联网)1.导入requests包 import requests 2.测试get请求 url http://httpbin.org/get par {key1: value1, key2: value2} # 不带参数请求 r1 requests.get(url) # 带参数请求 r2 requests.get(url, paramspa…

数据库管理-第158期 Oracle Vector DB AI-09(20240304)

数据库管理158期 2024-03-04 数据库管理-第158期 Oracle Vector DB & AI-09&#xff08;20240304&#xff09;1 创建示例表2 添加过滤条件的向量近似查询示例1示例2示例3示例4示例5示例6示例7 总结 数据库管理-第158期 Oracle Vector DB & AI-09&#xff08;20240304&a…

怎么采集GBK或GB2312等特殊字符编码的网站数据

如果要采集的网站是GBK或GB2312等特殊字符编码&#xff0c;采集结果可能是一堆看不懂的文字或乱码&#xff0c;无法使用。 通常网页文章采集工具有字符编码选项&#xff0c;默认是UTF-8&#xff08;现在大部分网站都是&#xff09;&#xff0c;改选为GBK或GB2312字符编码即可&…

无人机/飞控--ArduPilot、PX4学习历程记录(1)

本篇博客用来记录个人学习记录&#xff0c;存放各种文章链接、视频链接、学习历程、实验过程和结果等等.... 最近在整无人机项目&#xff0c;接触一下从来没有接触过的飞控...(听着就头晕)&#xff0c;本人纯小白。 目录 PX4、Pixhawk、APM、ArduPilot、Dronecode Dronekit…

【ArcPy】游标访问几何数据

访问质心坐标相关数据 结果展示 代码 import arcpy shppath r"C:\Users\admin\Desktop\excelfile\a2.shp" with arcpy.da.SearchCursor(shppath, ["SHAPE","SHAPEXY","SHAPETRUECENTROID","SHAPEX","SHAPEY",&q…

【Mybatis】批量映射优化 分页插件PageHelper 逆向工程插件MybatisX

文章目录 一、Mapper批量映射优化二、插件和分页插件PageHelper2.1 插件机制和PageHelper插件介绍2.2 PageHelper插件使用 三、逆向工程和MybatisX插件3.1 ORM思维介绍3.2 逆向工程3.3 逆向工程插件MyBatisX使用 总结 一、Mapper批量映射优化 需求: Mapper 配置文件很多时&…

高精准无人机激光雷达标定板

无人机激光雷达标定板是一种用于校准无人机激光雷达系统的工具&#xff0c;它可以帮助无人机获取更准确、更可靠的数据&#xff0c;从而提高无人机的导航精度和自主控制能力。本文将从无人机激光雷达标定板的基本概念、作用、应用领域、市场现状和发展趋势等方面进行介绍。 一…

LeetCode 1976.到达目的地的方案数:单源最短路的Dijkstra算法

【LetMeFly】1976.到达目的地的方案数&#xff1a;单源最短路的Dijkstra算法 力扣题目链接&#xff1a;https://leetcode.cn/problems/number-of-ways-to-arrive-at-destination/ 你在一个城市里&#xff0c;城市由 n 个路口组成&#xff0c;路口编号为 0 到 n - 1 &#xff…

【音视频开发好书推荐1】《RTC程序设计:实时音视频权威指南》

1、WebRTC概述 WebRTC&#xff08;Web Real-Time Communication&#xff09;是一个由Google发起的实时音视频通讯C开源库&#xff0c;其提供了音视频采集、编码、网络传输&#xff0c;解码显示等一整套音视频解决方案&#xff0c;我们可以通过该开源库快速地构建出一个音视频通…

12-Linux部署Zookeeper集群

Linux部署Zookeeper集群 简介 ZooKeeper是一个分布式的&#xff0c;开放源码的分布式应用程序协调服务&#xff0c;是Hadoop和Hbase的重要组件。它是一个为分布式应用提供一致性服务的软件&#xff0c;提供的功能包括&#xff1a;配置维护、域名服务、分布式同步、组服务等。…

设计模式:六大原则 ③

一、六大设计原则 &#x1f360; 开闭原则 (Open Close Principle) &#x1f48c; 对扩展开放&#xff0c;对修改关闭。在程序需要进行拓展的时候&#xff0c;不能去修改原有的代码&#xff0c;实现一个热插拔的效果。简言之&#xff0c;是为了使程序的扩展性好&#xff0c;易…

4步教你完成一篇让人挑不出毛病的产品需求文档!

“需求”这个词是产品经理工作中的常客&#xff0c;产品需求文档也贯穿于整个产品经理的日常工作中&#xff0c;本周小编将通过什么是产品需求文档&#xff0c;产品需求文档的作用、如何写好产品需求文档等方面分享如何写出一篇让你挑不出毛病的PRD&#xff0c;让需求文档助力产…

光分路器概述

光分路器主要有两种 技术&#xff1a; ㈠平面波导型光分路器(PLC Splitter) PLC分路器的封装是指将平面波导分路器上的各个导光通路&#xff08;即波导通路&#xff09;与光纤阵列中的光纤一一对准&#xff0c;然后用特定的胶&#xff08;如环氧胶&#xff09;将其粘合在一起…

计算机系统中的文件系统梳理

看之前&#xff0c;大家动动小手点个关注&#xff0c;谢谢。 原文地址&#xff1a;计算机系统中的文件系统梳理 - Pleasure的博客 下面是正文内容&#xff1a; 前言 这是一篇笔记 我之所以要选择这个话题&#xff0c;是因为前几天在对TF卡进行格式化的时候遇到了问题。有些专…

每日一练 | 华为认证真题练习Day192

1、下面是路由器RTB的部分输出信息&#xff0c;关于输出信息描迷错误的是: A. 接口上动态加入的组播组个数是1。 B. 加入的组播组地址是225.1.1.2 C. DISPLAY IGMP GROUP命令用来查看IGMP组播组信息&#xff0c;包括通过成员报告动态加入的组播组和通过命令行静态加入的组播组…

抖音视频下载软件|视频批量采集工具

便捷操作&#xff0c;高效采集 在快节奏的数字化时代&#xff0c;我们的视频下载软件提供了简单便捷的操作流程&#xff0c;让用户能够高效地采集所需视频内容。用户只需输入关键词并点击开始抓取&#xff0c;系统会自动搜索指定关键词下的抖音视频数据&#xff0c;并将待解析视…