深度学习理论基础(六)Transformer多头注意力机制

news2024/11/17 12:36:07

目录

  • 一、自定义多头注意力机制
    • 1. 缩放点积注意力(Scaled Dot-Product Attention)
      • ● 计算公式
      • ● 原理
    • 2. 多头注意力机制框图
      • ● 具体代码
  • 二、pytorch中的子注意力机制模块

  
  深度学习中的注意力机制(Attention Mechanism)是一种模仿人类视觉和认知系统的方法,它允许神经网络在处理输入数据时集中注意力于相关的部分。通过引入注意力机制,神经网络能够自动地学习并选择性地关注输入中的重要信息,提高模型的性能和泛化能力。
  下图 展示了人类在看到一幅图像时如何高效分配有限注意力资源的,其中红色区域表明视觉系统更加关注的目标,从图中可以看出:人们会把注意力更多的投入到人的脸部。文本的标题以及文章的首句等位置。而注意力机制就是通过机器来找到这些重要的部分。
在这里插入图片描述

一、自定义多头注意力机制

1. 缩放点积注意力(Scaled Dot-Product Attention)

  缩放点积注意力(Scaled Dot-Product Attention)是注意力机制的一种形式,通常在自注意力(self-attention)机制或多头注意力机制中使用,用于模型在处理序列数据时关注输入序列中不同位置的信息。这种注意力机制常用于Transformer模型及其变体中,被广泛用于各种自然语言处理任务,如机器翻译、文本生成和问答系统等。
在这里插入图片描述

● 计算公式

在这里插入图片描述

● 原理

假设输入:给定一个查询向量(query)、一组键向量(keys)和一组值向量(values)。

(1)Dot-Product 计算相似度:通过计算查询向量query与键向量keys之间的点积,得到每个查询与所有键的相似度分数。然后将这些分数进行缩放(scale)–除以根号下d_k,以防止点积的值过大,从而导致梯度消失或梯度爆炸。
(2)Mask 可选择性 目的是将 padding的部分 填充负无穷,这样算softmax的时候这里就attention为0,从而避免padding带来的影响.
(3)Softmax归一化:对相似度分数进行softmax归一化,得到每个键的权重,这些权重表示了对应值向量的重要程度。
加权求和:使用这些权重对值向量进行加权求和,得到最终的注意力输出。
在这里插入图片描述

2. 多头注意力机制框图

  多头注意力机制是在 Scaled Dot-Product Attention 的基础上,分成多个头,也就是有多个Q、K、V并行进行计算attention,可能侧重与不同的方面的相似度和权重。
在这里插入图片描述

● 具体代码

import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
 
class MultiHeadAttention(nn.Module):
	#embedding_dim:输入向量的维度,num_heads:注意力机制头数
    def __init__(self, embedding_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads       #总头数
        self.embedding_dim = embedding_dim   #输入向量的维度
        self.d_k= self.embedding_dim// self.num_heads  #每个头 分配的输入向量的维度数
        self.softmax=nn.Softmax(dim=-1)
 
        self.W_query = nn.Linear(in_features=embedding_dim, out_features=embedding_dim, bias=False)
        self.W_key = nn.Linear(in_features=embedding_dim, out_features=embedding_dim, bias=False)
        self.W_value = nn.Linear(in_features=embedding_dim, out_features=embedding_dim, bias=False)
        self.fc_out = nn.Linear(embedding_dim, embedding_dim)
        
   #输入张量 x 中的特征维度分成 self.num_heads 个头,并且每个头的维度为 self.d_k。
	def split_head(self, x, batch_size):
		x = x.reshape(batch_size, -1, self.num_heads, self.d_k)
		return x.permute(0,2,1,3)   #x  (N_size, self.num_heads, -1, self.d_k)
	     
 
    def forward(self, x):
     	batch_size=x.size(0)  #获取输入张量 x 的批量(batch size)大小
        q= self.W_query(x)  
        k= self.W_key(x)  
        v= self.W_value(x)
        
       #使用 split_head 函数对 query、key、value 进行头部切分,将其分割为多个注意力头。
		q= self.split_head(q, batch_size)
		k= self.split_head(k, batch_size)
		v= self.split_head(v, batch_size)
		
		##attention_scorce = q*k的转置/根号d_k
 		attention_scorce=torch.matmul(q, k.transpose(-2,-1))/torch.sqrt(torch.tensor(self.d_k, dtype=torch.float32))
        attention_weight= self.softmax(attention_scorce)
 
        ## output = attention_weight * V
        output = torch.matmul(attention_weight, v)  # [h, N, T_q, num_units/h]
        output  = out.permute(0,2,1,3).contiguous() # [N, T_q, num_units]
 		output  = out.reshape(batch_size,-1, self.embedding_dim)
		output  = self.fc_out(output)
		
        return output

  

二、pytorch中的子注意力机制模块

  nn.MultiheadAttention是PyTorch中用于实现多头注意力机制的模块。它允许你在输入序列之间计算多个注意力头,并且每个头都学习到了不同的注意力权重。
  创建了一些随机的输入数据,包括查询(query)、键(key)、值(value)。接着,我们使用multihead_attention模块来计算多头注意力,得到输出和注意力权重。
  请注意,你可以调整num_heads参数来控制多头注意力的头数,这将会影响到模型的复杂度和表达能力。

import torch
import torch.nn as nn

# 假设我们有一些输入数据
# 输入数据形状:(序列长度, 批量大小, 输入特征维度)
input_seq_length = 10
batch_size = 3
input_features = 32

# 假设我们的输入序列是随机生成的
input_data = torch.randn(input_seq_length, batch_size, input_features)

# 定义多头注意力模块
# 参数说明:
#   - embed_dim: 输入特征维度
#   - num_heads: 多头注意力的头数
#   - dropout: 可选,dropout概率,默认为0.0
#   - bias: 可选,是否在注意力计算中使用偏置,默认为True
#   - add_bias_kv: 可选,是否添加bias到key和value,默认为False
#   - add_zero_attn: 可选,是否在注意力分数中添加0,默认为False
multihead_attention = nn.MultiheadAttention(input_features, num_heads=4)

# 假设我们有一个query,形状为 (查询序列长度, 批量大小, 输入特征维度)
query = torch.randn(input_seq_length, batch_size, input_features)

# 假设我们有一个key和value,形状相同为 (键值序列长度, 批量大小, 输入特征维度)
key = torch.randn(input_seq_length, batch_size, input_features)
value = torch.randn(input_seq_length, batch_size, input_features)

# 计算多头注意力
# 返回值说明:
#   - output: 注意力计算的输出张量,形状为 (序列长度, 批量大小, 输入特征维度)
#   - attention_weights: 注意力权重,形状为 (批量大小, 输出序列长度, 输入序列长度)
output, attention_weights = multihead_attention(query, key, value)

# 输出结果
print("Output shape:", output.shape)
print("Attention weights shape:", attention_weights.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1573838.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

docker基础学习指令

文章目录 [toc] docker基础常用指令一、docker 基础命令二、docker 镜像命令1. docker images2. docker search3. docker pull4. docker system df5. docker rmi1. Commit 命令 三、 docker 容器命令1. docker run2. docker logs3. docker top4. docker inspect5. docker cp6. …

Mybatis报错:Unsupported conversion from LONG to java.sql.Timestamp

Mybatis在封装结果集的时候,如果方法返回的是对象,则会去调用这个对象的无参构造方法。 如果实体类标注了Builder注解,则此注解会把默认的构造方法全部改成私有的,则Mybatis在通过无参构造方法反射创建对象时,就会找不…

#QT项目实战(天气预报)

1.IDE:QTCreator 2.实验: 3.记录: (1)调用API的Url a.调用API获取IP whois.pconline.com.cn/ipJson.jsp?iphttp://whois.pconline.com.cn/ipJson.jsp?ip if(window.IPCallBack) {IPCallBack({"ip":&quo…

Python学习之-魔术方法

前言: Python 中的魔术方法(Magic Methods),也称作特殊方法(Special Methods),是那些被双下划线包围的方法,例如 init。这些方法在 Python 中有特殊的含义,它们并不需要…

ThingsBoard通过MQTT发送遥测数据

MQTT基础 客户端 MQTT连接 遥测上传API 案例 MQTT基础 MQTT是一种轻量级的发布-订阅消息传递协议,它可能最适合各种物联网设备。 你可以在此处找到有关MQTT的更多信息,ThingsBoard服务器支持QoS级别0(最多一次)和QoS级别1&…

程序猿成长之路之数据挖掘篇——频繁项集挖掘介绍

频繁项集挖掘可以说是数据挖掘中的重点,下面我们来分析以下频繁项集挖掘的过程和目标 如果对数据挖掘没有概念的小伙伴可以查看上次的文章 https://blog.csdn.net/qq_31236027/article/details/137046475 什么是频繁项集? 在回答这个问题之前&#xff…

蓝桥杯第101题 拉马车 C++ Java Python

目录 题目 思路和解题方法 复杂度: c 代码 Java 版本(仅供参考) Python 版本(仅供参考) 代码细节 C 版本: Java 版本: Python 版本: 题目 思路和解题方法 这个游戏是一个简单的纸牌游戏,两个玩家轮流出牌&am…

Springboot相关知识-图片描述(学习笔记)

学习java过程中的一些笔记,觉得比较重要就顺手记录下来了~ 目录 一、前后端请求1.前后端交互2.简单传参3.数组集合传参4.日期参数5.Json参数6.路径参数7.响应数据8.解析xml文件9.统一返回类10.三层架构11.分层解耦12.Bean的声明13.组件扫描14.自动注入 一、前后端请…

(免费分享)基于springboot,vue问卷调查系统

用户注册、用户登录、创建调查问卷、编辑问卷问题和选型(支持题型:单选、多选、单行文本、多行文本、数字、评分、日期、文本描述)、保存和发布问卷、停止问卷调查、游客填写调查问卷(一个IP地址只能填写一次) 技术&a…

4.3 IO day5

1&#xff1a;实现文件夹的拷贝功能 注意判断被拷贝的文件夹是否存在&#xff0c;如果不存在则提前创建&#xff0c;创建文件夹的函数为 mkdir 不考虑递归拷贝的问题 #include <stdio.h> #include <string.h> #include <stdlib.h> #include <sys/types.h…

蓝桥杯刷题-09-三国游戏-贪心⭐⭐⭐

蓝桥杯2023年第十四届省赛真题-三国游戏 小蓝正在玩一款游戏。游戏中魏蜀吴三个国家各自拥有一定数量的士兵X, Y, Z (一开始可以认为都为 0 )。游戏有 n 个可能会发生的事件&#xff0c;每个事件之间相互独立且最多只会发生一次&#xff0c;当第 i 个事件发生时会分别让 X, Y,…

清明作业 c++

1.封装一个类&#xff0c;实现对一个数求累和阶乘质数 #include <iostream>using namespace std; int mproduct(int a){if(a>1){return a*mproduct((a-1));}else{return 1;} } class number{int a; public:number():a(5){};number(int a):a(a){}void set(int a){thi…

开源的页面生成器:拖拽即可生成小程序、H5页面和网站

星搭精卫 MtBird 是一款低代码可视化页面生成器&#xff0c;可以帮助用户以可视化的形式搭建网页、小程序和表单等应用。 使用这个生成器&#xff0c;不需要代码就可以生成小程序、H5页面和网站&#xff0c;拖拽操作、样式配置快速生成页面应用&#xff0c;数据可视化接入&…

2024年【道路运输企业安全生产管理人员】找解析及道路运输企业安全生产管理人员作业考试题库

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 道路运输企业安全生产管理人员找解析参考答案及道路运输企业安全生产管理人员考试试题解析是安全生产模拟考试一点通题库老师及道路运输企业安全生产管理人员操作证已考过的学员汇总&#xff0c;相对有效帮助道路运输…

C++高频面试知识总结 part3

哈希 1.哈希表为什么快&#xff1f;2.哈希冲突解决方法3.哈希表扩容流程4.哈希表扩容太多次&#xff0c;需要遍历所有元素&#xff0c;如何优化&#xff1f;5.渐进式扩容为何可以正确访问哈希表&#xff1f; 1.哈希表为什么快&#xff1f; 哈希表&#xff08;Hash table&#…

告别旧IP,更换网络ip地址教程分享

在数字化世界中&#xff0c;IP地址作为每个网络设备的标识符&#xff0c;扮演着至关重要的角色。它不仅是设备在网络中的“门牌号”&#xff0c;还影响着网络连接的稳定性、安全性和数据传输效率。因此&#xff0c;在某些情况下&#xff0c;更换网络IP地址成为必要操作。虎观代…

平台规则的改变会影响低价治理结果吗

许多品牌在做低价链接投诉时&#xff0c;会用一套自己的标准去做&#xff0c;但如果无视平台规则&#xff0c;会出现非常多不好的结果&#xff0c;比如帐号投诉成功率被拉低&#xff0c;会直接影响后续链接的投诉时效和成功率&#xff0c;同时因为不尊重平台规则&#xff0c;而…

美国洛杉矶大带宽服务器带宽堵塞解决方法

随着互联网的快速发展&#xff0c;大带宽服务器成为了现代企业和个人进行数据传输、存储和处理的关键设施。然而&#xff0c;在美国洛杉矶等大城市&#xff0c;由于网络流量的激增、不合理的网络配置以及网络攻击等多种原因&#xff0c;大带宽服务器带宽堵塞问题日益凸显。本文…

【力扣】94. 二叉树的中序遍历、144. 二叉树的前序遍历、145. 二叉树的后序遍历

先序遍历&#xff1a;根-左-右中序遍历&#xff1a;左-根-右后序遍历&#xff1a;左-右-根 94. 二叉树的中序遍历 题目描述 给定一个二叉树的根节点 root &#xff0c;返回 它的 中序 遍历 。 示例 1&#xff1a; 输入&#xff1a;root [1,null,2,3] 输出&#xff1a;[1,3…

基于springboot+vue+Mysql的教学视频点播系统

开发语言&#xff1a;Java框架&#xff1a;springbootJDK版本&#xff1a;JDK1.8服务器&#xff1a;tomcat7数据库&#xff1a;mysql 5.7&#xff08;一定要5.7版本&#xff09;数据库工具&#xff1a;Navicat11开发软件&#xff1a;eclipse/myeclipse/ideaMaven包&#xff1a;…