【深度学习02】注意力机制

news2024/11/24 3:05:58

1.自注意力机制

自注意力机制(Self-Attention Mechanism)是深度学习中的一种方法,广泛应用于自然语言处理和其他领域。为了更好地理解它,可以用一个简单的类比来解释。

类比:学生在课堂上做笔记

假设你是一个学生,正在听老师讲课,你需要做笔记。你要注意老师现在说的内容,并且回想之前老师讲过的内容来理解当前的信息。这个过程就类似于自注意力机制。

符号和原理

在自注意力机制中,有几个关键的符号和步骤:

  1. 输入向量(X)

    • 类比:老师在某一时刻说的一句话(当前的讲课内容)。
    • 输入向量就是你当前需要处理的信息。
  2. 查询(Query,Q)

    • 类比:你脑中用来提问和寻找信息的一个参考点,比如“老师现在说的和之前哪些内容相关?”。
    • 查询向量是用来找到与当前信息相关的其他信息的工具。
  3. 键(Key,K)

    • 类比:你脑中储存的所有之前听到的内容(所有你已经记住的信息的索引)。
    • 键向量是用来匹配查询向量的索引。
  4. 值(Value,V)

    • 类比:所有你已经记住的具体内容(所有详细信息)。
    • 值向量是实际的内容,存储了所有你需要记住的具体信息。

计算过程

  1. 计算相似度(Attention Score)

    • 类比:你用查询向量(当前问题)去和所有键向量(之前的索引)做对比,找到哪些之前的信息与当前内容相关。
    • 在数学上,这是通过点积来计算查询和键之间的相似度。
  2. 权重(Attention Weights)

    • 类比:根据相似度,你分配不同的注意力(权重)给相关的信息。例如,之前某句话和当前内容非常相关,那么你会给它更多的注意力。
    • 权重是通过相似度计算后归一化(通常用softmax函数)得到的。
  3. 加权求和(Weighted Sum)

    • 类比:根据分配的权重,你整合所有相关的信息,形成一个综合的理解。例如,你把所有相关的笔记内容加权求和,得到一个新的理解。
    • 最终的输出是所有值向量根据权重加权求和的结果。

实际意义

自注意力机制的实际意义在于它能够让模型在处理每个输入时,同时关注所有其他的输入。这样,模型不仅能处理当前的信息,还能结合上下文进行理解。例如,在翻译一段文本时,自注意力机制可以让模型在翻译每个词时,同时参考整段文本,从而生成更准确的译文。

总结

自注意力机制可以理解为学生在课堂上做笔记时,不仅关注当前的讲课内容,还结合之前的所有笔记来理解新的信息。通过查询、键和值向量的计算,模型能够灵活地整合和理解复杂的信息。

2.训练流程

权重矩阵的训练过程可以通过以下步骤来理解和实现。为了便于理解,我们可以用一个简单的例子来说明这个过程。

简单的类比:学习新单词

假设你在学习一门新语言,每天都会学到新的单词。你希望记住这些单词,并能够在不同的句子中正确使用它们。这个过程类似于训练神经网络中的权重矩阵。

训练权重矩阵的步骤

  1. 初始化权重矩阵

    • 在一开始,权重矩阵中的值是随机初始化的。就像你第一次听到一个新单词时,你对它的理解是模糊的。
  2. 输入数据(Input Data)

    • 你有一组输入数据,比如句子或单词列表。对于神经网络来说,输入数据是特征向量。
  3. 前向传播(Forward Propagation)

    • 类比:你读了一句话并试图理解它。
    • 计算输入数据通过网络各层的输出,直到得到最终输出。在自注意力机制中,这包括计算查询、键和值向量,以及它们之间的点积和加权求和。
  4. 计算损失(Loss Calculation)

    • 类比:你回顾这句话的含义,发现自己对某些单词的理解有误。
    • 计算网络输出与实际目标之间的误差,这称为损失。常见的损失函数有均方误差(MSE)、交叉熵损失等。
  5. 反向传播(Backpropagation)

    • 类比:你通过反复记忆和练习,纠正对单词的理解。
    • 计算损失相对于网络中每个权重的梯度。这一步通过链式法则来实现,逐层计算误差的传播。
  6. 更新权重(Weight Update)

    • 类比:你根据自己的错误,调整对单词的记忆。
    • 使用优化算法(如梯度下降)根据计算出的梯度调整权重。更新公式一般为:
      W new = W old − η ⋅ ∇ L W_{\text{new}} = W_{\text{old}} - \eta \cdot \nabla L Wnew=WoldηL
      其中, W new W_{\text{new}} Wnew 是更新后的权重, W old W_{\text{old}} Wold 是当前权重, η \eta η 是学习率, ∇ L \nabla L L 是损失函数对权重的梯度。
  7. 重复训练(Iterative Training)

    • 类比:你不断重复学习和记忆过程,直到能够熟练掌握新单词。
    • 这个过程会重复多次,通常使用多个训练样本进行多个迭代(epochs),直到模型的性能达到满意的水平。

示例:自注意力机制中的权重矩阵训练

假设我们有一个简单的输入句子“机器学习很有趣”。

  1. 输入向量

    • 每个词会被转换为一个向量,比如通过词嵌入(word embeddings)。
  2. 计算查询(Q)、键(K)和值(V)向量

    • 对每个输入词向量进行线性变换,得到 Q、K、V 向量:
      Q = X W Q , K = X W K , V = X W V Q = XW_Q, \quad K = XW_K, \quad V = XW_V Q=XWQ,K=XWK,V=XWV
      其中 X X X 是输入词向量, W Q , W K , W V W_Q, W_K, W_V WQ,WK,WV 是需要训练的权重矩阵。
  3. 计算注意力得分(Attention Scores)

    • 计算查询和键的点积,并进行缩放和归一化:
      Scores = softmax ( Q K T d k ) \text{Scores} = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) Scores=softmax(dk QKT)
  4. 计算加权求和(Weighted Sum)

    • 用得分矩阵对值向量进行加权求和:
      Output = Scores ⋅ V \text{Output} = \text{Scores} \cdot V Output=ScoresV
  5. 计算损失并反向传播

    • 根据模型的输出和目标进行损失计算,并通过反向传播更新 W Q , W K , W V W_Q, W_K, W_V WQ,WK,WV 等权重矩阵。
  6. 优化器更新权重

    • 使用优化器(如Adam、SGD)更新权重矩阵。

通过反复的前向传播、损失计算、反向传播和权重更新,模型逐渐学会在不同的上下文中正确使用输入信息,从而提升其表现。


3.多头注意力机制

多头注意力机制(Multi-Head Attention)是自注意力机制的扩展和增强版本,它在Transformer模型中扮演了重要角色。多头注意力通过将注意力机制并行执行多次,使模型能够捕捉到不同的上下文信息和特征。让我们通过一个详细的解释和类比来理解多头注意力。

类比:多个侦探同时调查案件

假设你是一名侦探,正在调查一个复杂的案件。为了更好地收集线索,你邀请了几个同事(其他侦探)一起来帮忙。每个侦探都有不同的视角和擅长领域,他们会独立地调查,然后将各自的发现汇总。这类似于多头注意力机制。

多头注意力的步骤

  1. 输入分割

    • 类比:每个侦探从相同的信息开始,但各自独立调查。
    • 输入向量被分成多个子向量,每个子向量对应一个注意力头(Head)。
  2. 独立计算注意力

    • 类比:每个侦探独立地分析他们分配到的信息。
    • 每个头独立地执行自注意力机制,计算查询(Q)、键(K)和值(V)向量,生成不同的注意力输出。
  3. 汇总结果

    • 类比:每个侦探将他们的发现汇总。
    • 各个头的输出被连接(Concat),然后通过一个线性变换得到最终的输出。

数学表示

设有一个输入矩阵 X X X,其维度为 ( N , T , D ) (N, T, D) (N,T,D),其中 N N N 是批量大小(batch size), T T T 是序列长度, D D D 是特征维度。多头注意力机制可以表示为:

  1. 线性变换
    对于每个头 ( i ),我们有不同的线性变换矩阵 W i Q W_i^Q WiQ W i K W_i^K WiK W i V W_i^V WiV,将输入矩阵 X X X 转换为查询、键和值向量:
    Q i = X W i Q , K i = X W i K , V i = X W i V Q_i = XW_i^Q, \quad K_i = XW_i^K, \quad V_i = XW_i^V Qi=XWiQ,Ki=XWiK,Vi=XWiV

  2. 计算注意力得分
    计算查询和键的点积,并进行缩放和归一化:
    Attention i = softmax ( Q i K i T d k ) V i \text{Attention}_i = \text{softmax}\left(\frac{Q_i K_i^T}{\sqrt{d_k}}\right) V_i Attentioni=softmax(dk QiKiT)Vi
    其中 d k d_k dk 是键向量的维度。

  3. 连接和线性变换
    将所有头的输出连接在一起,形成一个新的矩阵,然后通过一个线性变换:
    MultiHead ( Q , K , V ) = Concat ( head 1 , head 2 , … , head h ) W O \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \text{head}_2, \ldots, \text{head}_h)W^O MultiHead(Q,K,V)=Concat(head1,head2,,headh)WO
    其中 W O W^O WO 是一个线性变换矩阵,用于整合多头的输出。

代码示例

import torch
import torch.nn as nn
import torch.optim as optim

# 定义自注意力机制
class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        
        assert self.head_dim * heads == embed_size, "Embedding size needs to be divisible by heads"
        
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
    
    def forward(self, values, keys, query, mask):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
        
        # Split the embedding into self.heads different pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)
        
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)
        
        # Compute the dot product attention scores
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))
        
        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)
        
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )
        
        out = self.fc_out(out)
        return out

# 定义一个简单的模型
class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size)
        )
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)
        
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out

# 数据准备
sentence = "machine learning is fun"
vocab = list(set(sentence.split()))
vocab_size = len(vocab)
embed_size = 8
heads = 2
dropout = 0.2
forward_expansion = 4

word_to_idx = {word: idx for idx, word in enumerate(vocab)}
idx_to_word = {idx: word for word, idx in word_to_idx.items()}

# 生成简单的词嵌入
embedding = nn.Embedding(vocab_size, embed_size)
input_indices = torch.tensor([word_to_idx[word] for word in sentence.split()]).unsqueeze(0)
input_embeds = embedding(input_indices)

# 定义模型、损失函数和优化器
model = TransformerBlock(embed_size, heads, dropout, forward_expansion)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 简单的目标:输出应与输入相同(自监督学习任务)
target_embeds = input_embeds.clone()

# 训练模型
num_epochs = 1000

for epoch in range(num_epochs):
    model.train() # Sets the module in training mode
    optimizer.zero_grad()
    
    # 前向传播
    output = model(input_embeds, input_embeds, input_embeds, mask=None)
    
    # 计算损失
    loss = criterion(output, target_embeds)
    
    # 反向传播和优化
    loss.backward(retain_graph=True)
    optimizer.step()
    
    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

print("Training complete.")


import numpy as np
# 获取最近邻词索引
def get_nearest_neighbors(embeddings, target_embeds):
    embeddings = embeddings.weight.data.cpu().numpy()
    target_embeds = target_embeds.cpu().detach().numpy()
    nearest_neighbors = []
    
    for target in target_embeds[0]:
        distances = np.linalg.norm(embeddings - target, axis=1)
        nearest_idx = np.argmin(distances)
        nearest_neighbors.append(nearest_idx)
    
    return nearest_neighbors

# 将输出嵌入转换回词索引
nearest_neighbors = get_nearest_neighbors(embedding, output)

# 将词索引转换回单词
output_sentence = " ".join([idx_to_word[idx] for idx in nearest_neighbors])
print("Output Sentence:", output_sentence)

报错解决:http://t.csdnimg.cn/dKpGY

4.框架流程图

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1690105.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

合约开发的基本结构剖析及前置知识梳理

前置知识点 上下文变量初步 合约函数的背后是transaction,上下文变量访问的是transaction中的信息两个上下文变量:tx和msg ERC20 规范代码实现Metamask测试 ganache-cli的安装 安装 npm install -g ganache-cli启动 ganache-cli如果出现以下这种…

ZooKeeper系列之ZAB协议

概述 ZooKeeper Atomic Broadcast,ZooKeeper原子消息广播协议。ZAB协议是为分布式协调服务ZK专门设计的一种支持崩溃恢复的原子广播协议。ZK主要依赖ZAB协议来实现分布式数据的最终一致性,基于该协议,ZK实现一种主备模式的系统架构来保持集群…

【EI会议】2024年雷达、电子与通信工程国际会议(ICREC 2024)

2024年雷达、电子与通信工程国际会议 2024 International Conference on Radar, Electronics and Communication Engineering 【1】会议简介 2024年雷达、电子与通信工程国际会议即将在深圳隆重召开。深圳,这座充满活力的现代化都市,以其卓越的科技创新…

后端之路第二站(正片)——SprintBoot之:设置请求接口

这一篇讲怎么简单结合模拟云接口,尝试简单的后端接接口、接受并传数据 一、下载Apifox接口文档软件 目前的企业都是采用前后端分离开发的,在开发阶段前后端需要统一发送请求的接口,前端也需要在等待后端把数据存到数据库之前,自己…

微信H5跳小程序 wx-open-launch-weapp ios显示且正常跳转,安卓不显示不报错解决方案

前提:在一切都正常(无报错,没有写法错误等)的情况下,出现这个问题: 去你的h5项目,用浏览器打开,在network随便找一个静态文件,在response响应标头中找找,是否有Content-Security-Policy这个头&…

vue2流星雨(可调角度)

新建StarBackground.vue组件 打开组件注释部分可以随机颜色 <template><div class"rain"><divv-for"(item,index) in rainNumber":key"index"class"rain-item"ref"rain-item":style"transform:rotate(…

【MySQL进阶之路 | 基础篇】触发器

1. 为什么要使用触发器 我们可能会遇到如下场景.我们有两个相互关联的表&#xff0c;如商品信息表与库存信息表.当我们向商品信息表添加一条记录时&#xff0c;为了保证数据完整性&#xff0c;也必须向库存信息表添加一条数据.我们就必须把这两个关联的操作写在程序里&#xf…

【APKtool】APKtool实现某瓣APP重签名

APP name 重打包 重打包完成 开始签名 apktool签名 使用 APKtool 或其他工具生成的签名文件与原始签名文件的区别主要在于它们使用的密钥和证书可能不同。当你使用 APKtool 对 APK 文件进行反编译、修改后再重新打包时&#xff0c;你通常需要使用一个新的密钥和证书对修改后…

机器人非线性控制方法——线性化与解耦

机器人非线性控制方法是针对具有非线性特性的机器人系统所设计的一系列控制策略。其中&#xff0c;精确线性化控制和反演控制是两种重要的方法。 1. 非线性反馈控制 该控制律采用非线性反馈控制的方法&#xff0c;将控制输入 u 分解为两个部分&#xff1a; α(x): 这是一个与…

计算机毕业设计 | springboot养老院管理系统 老人社区管理(附源码)

1&#xff0c;绪论 1.1 背景调研 养老院是集医疗、护理、康复、膳食、社工等服务服务于一体的综合行养老院&#xff0c;经过我们前期的调查&#xff0c;院方大部分工作采用手工操作方式,会带来工作效率过低&#xff0c;运营成本过大的问题。 院方可用合理的较少投入取得更好…

HTML5 + CSS3模拟庆余年中“五竹”的镭射眼动画特效

庆余年2已经火热开播了&#xff0c;据说反响强烈啊&#xff0c;不知道这一部里面&#xff0c;五竹的镭射眼会不会表现出来&#xff0c;我还挺想看看他的镭射眼的&#xff0c;我看到底有没有杀死剧中的庆帝。 回想第一部&#xff0c;我都快记不清那是几年前开播的了&#xff0c;…

Ubuntu 安装 LibreOffice

1. 删除预安装的LibreOffice Ubuntu 和其他的 Linux 发行版带有预安装的 LibreOffice。这可能不是最新的&#xff0c;这是因为发行版有特定的发行周期。在进行新安装之前&#xff0c;你可以通过以下命令删除 Ubuntu 及其衍生发行版中的的旧版本。 sudo apt remove –purge li…

VScode SSH连接远程服务器报错

一、报错 通过VScode SSH插件远程连接服务器&#xff0c;输入密码后没有连接成功&#xff0c;一直跳出输入密码界面&#xff0c;在输出界面里&#xff0c;一直是Waiting for server log或者是显示Cannot not find minimist 二、处理 &#x1f431;&#xff1a; 这个时候应该…

安全工程师考试摸拟试题

安全工程师考试摸拟试题安全工程师是指在工程项目中负责安全管理和安全技术服务的专业人员。他们需要具备扎实的理论知识和丰富的实践经验&#xff0c;能够有效预防和控制各类安全风险… 1 安全工程师考试摸拟试题 安全工程师是指在工程项目中负责安全管理和安全技术服务的专业…

Vue开发实例(十三)用户登录功能

使用Vue实现登录具有以下几个好处&#xff1a; 响应式界面&#xff1a;Vue框架的响应式特性可以帮助开发者轻松地实现用户登录界面的交互效果&#xff0c;包括表单验证、实时错误提示等&#xff0c;从而提升用户体验。组件化开发&#xff1a;Vue框架支持组件化开发&#xff0c;…

pillow学习3

Pillow库中&#xff0c;图像的模式代表了图像的颜色空间。以下是一些常见的图像模式及其含义&#xff1a; L&#xff08;灰度图&#xff09;&#xff1a;L模式表示图像是灰度图像&#xff0c;每个像素用8位表示&#xff08;范围为0-255&#xff09;&#xff0c;0表示黑色&#…

国家开放大学-实验3:类、对象、方法和修饰符的使用

作业答案 联系QQ:1603277115 实验目的 通过本实验&#xff0c;了解和掌握类、方法以及各个修饰符的使用。 问题描述 基于面向对象思想和类的方式&#xff0c;创建一个计算金额的程序。 啤酒 3.5元/罐&#xff0c; 方便面 4.5元/包&#xff0c; 矿泉水 2.0 元/瓶。 优惠规…

【Linux】信号之信号的产生详解

&#x1f916;个人主页&#xff1a;晚风相伴-CSDN博客 &#x1f496;如果觉得内容对你有帮助的话&#xff0c;还请给博主一键三连&#xff08;点赞&#x1f49c;、收藏&#x1f9e1;、关注&#x1f49a;&#xff09;吧 &#x1f64f;如果内容有误的话&#xff0c;还望指出&…

Java入门基础学习笔记50——ATM系统

1、项目演示&#xff1b; 2、项目技术实现&#xff1b; 1&#xff09;面向对象编程&#xff1a; 每个账户都是一个对象&#xff0c;所以要设计账户类Account&#xff0c;用于创建账户对象封装账户信息。ATM同样是一个对象&#xff0c;需要设计ATM类&#xff0c;代表ATM管理系…

打破壁垒,实现多引擎3D内容轻量化交付|点量云流

随着应用场景的不断拓展&#xff0c;传统的视频流技术已难以满足日益复杂的需求。当前市场上的视频流解决方案支持的引擎基本是UE、Unitiy输出的exe3D应用&#xff0c;在处理WebGL等3D内容时&#xff0c;也存在诸多局限性&#xff0c;例如性能限制、跨平台兼容性问题、无法直接…