【深度学习入门篇 ⑪】自注意力机制

news2024/9/24 6:26:40

【🍊易编橙:一个帮助编程小伙伴少走弯路的终身成长社群🍊】

大家好,我是小森( ﹡ˆoˆ﹡ ) ! 易编橙·终身成长社群创始团队嘉宾,橙似锦计划领衔成员、阿里云专家博主、腾讯云内容共创官、CSDN人工智能领域优质创作者 。


自注意力背景

NLP领域中, 当前的注意力机制大多数应用于seq2seq架构, 即编码器和解码器模型。

  • encoder-decoder 结构 : Encoder将输入编码成上下文向量,Decoder进行解码;解码过程顺序进行,每次仅解码出一个单词。

RNN存在一些问题:

  1. 输 入 输 出 存 在 序 列 关 系 , b 4 的 输 出 需 要 先 依 赖 于 b 3 , … ,
    一 次 输 出 , 无 法 进 行 并 行 化
  2. 不论输入和输出的语句长度是什么,中间的上下文向量长度都是
    固定的
  3. 仅仅利用上下文向量解码,会有信息瓶颈,长度过长时候信息可
    能会丢失

可以对对seq2seq结构改进,使 用 C N N 来 进 行 并 行 化。

通过堆叠多层CNN,提高感受野,使上层输出可以捕获长程时序关系。

自注意力

语言的含义是极度依赖上下文的

  • 机器人第二法则:机器人必须遵守人类给的命令,除非该命令违背了第一法则

这句话中高亮表示了三个地方,这三处单词指代的是其它单词。除非我们知道这些词
指代的上下文联系起来,否则根本不可能理解或处理这些词语的意思。当模型处理这
句话的时候,它必须知道:

  •  「它」指代机器人
  • 「命令」指代前半句话中人类给机器人下的命令,即「人类给它的命令」
  • 「第一法则」指机器人第一法则的完整内容

自注意力机制(self-Attention):

 3个人工定义的重要概念,查询向量,键向量,值向量

① 查询向量(Query向量):被用来和其它单词的键向量相乘,从而得到其它词相对于当前词的注意力得分。
② 键向量(Key向量):序列中每个单词的标签,是我们搜索相关单词时用来匹配的对象。
③ 值向量(Value向量):单词真正的表征,使用值向量基于注意力得分进行加权求和。

 

查询向量就像一张便利贴,键向量像是档案柜中文件夹上贴的标签。当找到和便利贴上所写相匹
配的文件夹时,文件夹里的东西便是值向量。

自注意力实现

q u e r y , ke y , va l u e 向 量 的 定 义

使用每一个q对每一个k做attention :

将Query和Key分别计算相似性,然后经过softmax得到相似性概率权重,即注意力,再乘以Value,最后相加即可得到包含注意力的输出 。

常见注意力机制代码

import torch
import torch.nn as nn
import torch.nn.functional as F

class Attn(nn.Module):
    def __init__(self, query_size, key_size, value_size1, value_size2, output_size):
        super(Attn, self).__init__()
        self.query_size = query_size
        self.key_size = key_size
        self.value_size1 = value_size1
        self.value_size2 = value_size2
        self.output_size = output_size

        # 第一步中需要的线性层
        self.attn = nn.Linear(self.query_size + self.key_size, value_size1)

        # 第三步中需要的线性层
        self.attn_combine = nn.Linear(self.query_size + value_size2, output_size)


    def forward(self, Q, K, V):
        attn_weights = F.softmax(self.attn(torch.cat((Q[0], K[0]), 1)), dim=1)

        attn_applied = torch.bmm(attn_weights.unsqueeze(0), V)

        output = torch.cat((Q[0], attn_applied[0]), 1)

        # 使用线性层作用在第三步的结果上做一个线性变换并扩展维度
        output = self.attn_combine(output).unsqueeze(0)
        return output, attn_weights


query_size = 32
key_size = 32
value_size1 = 32
value_size2 = 64
output_size = 64
attn = Attn(query_size, key_size, value_size1, value_size2, output_size)
Q = torch.randn(1,1,32)
K = torch.randn(1,1,32)
V = torch.randn(1,32,64)
out = attn(Q, K ,V)
print(out[0])
print(out[1])

输出:

tensor([[[-0.3390,  0.3021, -0.1952, -0.0400,  0.5597, -0.3745, -0.2216,
          -0.3438, -0.2086, -0.1554, -0.2502,  0.0486,  1.0381, -0.1030,
           0.7277,  0.0592, -0.9172, -0.3736, -0.2285, -0.0148, -0.3319,
           0.0620, -0.6006,  0.1346, -0.1530,  0.0336,  0.3269, -0.2511,
          -0.1209,  0.4153,  0.3519,  0.3344, -0.0496, -0.2759, -0.2080,
          -0.1669,  0.7263, -0.0893,  0.0298, -0.1326,  0.6898, -0.3864,
          -0.0884, -0.2329, -0.2338,  0.1920,  0.2625,  0.0396, -0.3101,
          -0.2299, -0.1226, -0.5915,  0.2620,  0.2462,  0.4123, -0.6733,
          -0.2091,  0.6727,  0.3754, -0.1620, -0.8333,  0.2066,  0.3082,
          -0.5225]]], grad_fn=<UnsqueezeBackward0>)
tensor([[0.0187, 0.0492, 0.0259, 0.0293, 0.0151, 0.0104, 0.0127, 0.0122, 0.0546,
         0.0141, 0.0170, 0.0277, 0.0284, 0.0807, 0.0228, 0.0099, 0.0327, 0.0585,
         0.0102, 0.0106, 0.0598, 0.0208, 0.0403, 0.0241, 0.0896, 0.0230, 0.0371,
         0.0316, 0.0091, 0.0242, 0.0553, 0.0447]], grad_fn=<SoftmaxBackward0>)

Self-attention就本质上是一种特殊的attention。Self-attention向对于attention的变化,就是寻找权重值的𝑤𝑖过程不同。

Self-attention和Attention使用方法

  • Attention (AT) 经常被应用在从编码器(encoder)转换到解码器(decoder)。
  • SA可以在一个模型当中被多次的、独立的使用(比如说在Transformer中,使用了18次;在Bert当中使用12次)。
  • SA比较擅长在一个序列当中,寻找不同部分之间的关系,AT却更擅长寻找两个序列之间的关系

Transformer模型

Encoder由N个相同结构的编码模块堆积而成,每一个编码模块由Multi-Head Attention, Add &
Norm, Feed Forward, Add & Norm 组成的。 

编码器结构

第一层的激活函数为 ReLU,第二层不使用激活函数。X是输入,全连接层的输入和输出都是512维,中间隐层维度为2048。 

解码器结构

通过输入矩阵X计算得到Q, K, V 矩阵,然后计算 Q 和 KT 的乘积 QKT。

计算注意力分数,在 Softmax 之前需要使用 Mask矩阵遮挡住每一个单词之后的信息。

  

文本嵌入层:无论是源文本嵌入还是目标文本嵌入,都是为了将文本中词汇的数字表示转变为向量表示, 希望在这样的高维空间捕捉词汇间的关系

import torch
import torch.nn as nn
import math
from torch.autograd import Variable

class Embeddings(nn.Module):
    def __init__(self, d_model, vocab):
        
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        return self.lut(x) * math.sqrt(self.d_model)

 输出:

embedding = nn.Embedding(10, 3)
input = torch.LongTensor([[1,2,3,4],[6,3,2,9]])
print(embedding(input))

#
tensor([[[ 1.8450,  1.9222,  0.1577],
         [-0.7341,  0.3091,  0.7592],
         [-0.4300,  0.9030, -0.3533],
         [ 1.1873,  0.9349, -1.0567]],

        [[ 0.4812, -0.1072,  0.4980],
         [-0.4300,  0.9030, -0.3533],
         [-0.7341,  0.3091,  0.7592],
         [-2.1227, -0.3621,  0.7383]]], grad_fn=<EmbeddingBackward0>)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1942671.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于微信小程序+SpringBoot+Vue的大学生科技竞赛管理系统(带1w+文档)

基于微信小程序SpringBootVue的大学生科技竞赛管理系统(带1w文档) 基于微信小程序SpringBootVue的大学生科技竞赛管理系统(带1w文档) 本系统中采用的开发工具包括软件工具和硬件工具&#xff0c;软件采用了Java语言和MySQL数据库&#xff0c;利用微信小程序技术&#xff0c;框架…

从零训练一个多模态LLM:预训练+指令微调+对齐+融合多模态+链接外部系统

本文尝试梳理一个完整的多模态LLM的训练流程。包括模型结构选择、数据预处理、模型预训练、指令微调、对齐、融合多模态以及链接外部系统等环节。 01 准备阶段 1 模型结构 目前主要有三种模型架构&#xff0c;基于Transformer解码器&#xff0c;基于General Language Model&a…

51单片机嵌入式开发:16、STC89C52RC 嵌入式之 步进电机28BYJ48、四拍八拍操作

STC89C52RC 嵌入式之 步进电机28BYJ48、四拍八拍操作 STC89C52RC 之 步进电机28BYJ48操作1 概述1.1 步进电机概述1.2 28BYJ48概述 2 步进电机工作原理2.1 基本原理2.2 28BYJ48工作原理2.3 28BYJ48控制原理 3 电路及软件代码实现4 步进电机市场价值 STC89C52RC 之 步进电机28BYJ…

英语(二)-我的学习方式

章节章节汇总我的学习方式历年真题作文&范文 目录 1、背单词 2、学语法 3、做真题 4、胶囊助学计划 写在最前&#xff1a;我是零基础&#xff0c;初二就听天书的那种。 本专栏持续更新学习资料 1、背单词 单词是基础&#xff0c;一定要背单词&#xff01;考纲要求要…

瑞吉外卖学习(一)

pom文件的导入中 <parent><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-parent</artifactId><version>2.6.6</version><relativePath/> <!-- lookup parent from repository --></…

【STM32 HAL库】DMA+串口

DMA 直接存储器访问 DMA传输&#xff0c;将数据从一个地址空间复制到另一个地址空间。-----“数据搬运工”。 DMA传输无需CPU直接控制传输&#xff0c;也没有中断处理方式那样保留现场和恢复现场&#xff0c;它是通过硬件为RAM和IO设备开辟一条直接传输数据的通道&#xff0c…

构建网络安全之盾:应对“微软蓝屏”教训的全面策略

✨✨ 欢迎大家来访Srlua的博文&#xff08;づ&#xffe3;3&#xffe3;&#xff09;づ╭❤&#xff5e;✨✨ &#x1f31f;&#x1f31f; 欢迎各位亲爱的读者&#xff0c;感谢你们抽出宝贵的时间来阅读我的文章。 我是Srlua小谢&#xff0c;在这里我会分享我的知识和经验。&am…

【算法】一致性哈希

一、引言 在分布式系统中&#xff0c;数据存储和访问的均匀性、高可用性以及可扩展性一直是核心问题。一致性哈希算法&#xff08;Consistent Hashing&#xff09;是一种分布式算法&#xff0c;因其出色的分布式数据存储特性&#xff0c;被广泛应用于缓存、负载均衡、数据库分片…

【Django5】模板引擎

系列文章目录 第一章 Django使用的基础知识 第二章 setting.py文件的配置 第三章 路由的定义与使用 第四章 视图的定义与使用 第五章 二进制文件下载响应 第六章 Http请求&HttpRequest请求类 第七章 会话管理&#xff08;Cookies&Session&#xff09; 第八章 文件上传…

如何检查我的网站是否支持HTTPS

HTTPS是一种用于安全通信的协议&#xff0c;是HTTP的安全版本。HTTPS的主要作用在于为互联网上的数据传输提供安全性和隐私保护。通常是需要在网站安装部署SSL证书来实现网络数据加密传输&#xff0c;安全加密功能。 那么如果要检查你的网站是否支持HTTPS&#xff0c;可以看下…

培训第十一天(nfs与samba共享文件)

上午 1、环境准备 &#xff08;1&#xff09;yum源 &#xff08;一个云仓库pepl仓库&#xff09; [rootweb ~]# vim /etc/yum.repos.d/hh.repo [a]nameabaseurlfile:///mntgpgcheck0[rootweb ~]# vim /etc/fstab /dev/cdrom /mnt iso9660 defaults 0 0[rootweb ~]# mount -a[…

软件测试09 自动化测试技术(Selenium)

重点/难点 重点&#xff1a;理解自动化测试的原理及其流程难点&#xff1a;Selinum自动化测试工具的使用 目录 系统测试 什么是系统测试什么是功能测试什么是性能测试常见的性能指标有哪些 自动化测试概述 测试面临的问题 测试用例数量增多&#xff0c;工作量增大&#xff…

数据结构初阶(C语言)-二叉树

一&#xff0c;树的概念与结构 树是⼀种非线性的数据结构&#xff0c;它是由 n&#xff08;n>0&#xff09; 个有限结点组成⼀个具有层次关系的集合。把它叫做树是因为它看起来像⼀棵倒挂的树&#xff0c;也就是说它是根朝上&#xff0c;而叶朝下的。 1.有⼀个特殊的结点&a…

ubuntu22安装拼音输入法

专栏总目录 一、安装命令&#xff1a; sudo apt update sudo apt install fcitx sudo apt install fcitx-pinyin 二、切换输入法

吴恩达深度学习笔记1 Neural Networks and Deep Learning

参考视频&#xff1a;(超爽中英!) 2024公认最好的【吴恩达深度学习】教程&#xff01;附课件代码 Professionalization of Deep Learning_哔哩哔哩_bilibili Neural Networks and Deep Learning 1. 深度学习引言(Introduction to Deep Learning) 2. 神 经 网 络 的 编 程 基 础…

数据库安全:MySQL安全配置,MySQL安全基线检查加固

「作者简介」&#xff1a;冬奥会网络安全中国代表队&#xff0c;CSDN Top100&#xff0c;就职奇安信多年&#xff0c;以实战工作为基础著作 《网络安全自学教程》&#xff0c;适合基础薄弱的同学系统化的学习网络安全&#xff0c;用最短的时间掌握最核心的技术。 这一章节我们需…

【目标检测】Anaconda+PyTorch(GPU)+PyCharm(Yolo5)配置

前言 本文主要介绍在windows系统上的Anaconda、PyTorch、PyCharm、Yolov5关键步骤安装&#xff0c;为使用yolo所需的环境配置完善。同时也算是记录下我的配置流程&#xff0c;为以后用到的时候能笔记查阅。 Anaconda 软件安装 Anaconda官网&#xff1a;https://www.anaconda…

微软蓝屏事件:网络安全与系统稳定性的反思与前瞻

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

MFC:以消息为基础的事件驱动系统和消息映射机制

以消息为基础的事件驱动系统和消息映射机制 (1)消息 A.What&#xff08;什么是消息&#xff09; 本质是一个数据结构&#xff0c;用于应用程序不同部分之间进行通信和交互 typedef struct tagMSG {HWND hwnd; // 接收该消息的窗口句柄UINT message; // 消息标…

二分查找的实现

前提&#xff1a;数组是有序的 #include <stdio.h>//作用&#xff1a;利用二分查找法查找数据 //返回值&#xff1a;数据在数组中的索引 //找到了&#xff1a;真实索引 没找到&#xff1a;返回-1 int search(int arr[], int num, int len) {//查找范围int min 0;int …