Transformer自注意力机制(Self-Attention)模型

news2024/10/5 14:09:21


上一篇我们介绍了transform专题一:Seq2seq model,也知道了transfrom属于seq2seq模型,这一排篇咱们接着介绍另外几种seq2seq架构的模型。)RNN(循环神经网络)CNN(卷积神经网络),最后我们还会介绍一种transformer采用的self-attention模型

RNN(Recurrent Neural Network)和CNN(Convolutional Neural Network)是两种不同类型的神经网络架构,分别适用于不同类型的数据和任务。下面我们来详细介绍这两种网络的原理、特点和应用场景。

RNN(循环神经网络)

原理

RNN是一类用于处理序列数据的神经网络,其结构允许信息在序列的各个时间步之间传递。与传统的前馈神经网络不同,RNN具有循环连接,使其能够处理时间序列数据或其他顺序相关的数据。

RNN的基本单元是一个神经元,它在每个时间步都接收当前输入和前一时间步的隐藏状态,然后计算出当前的隐藏状态。这个隐藏状态不仅包含当前输入的信息,还包含前一时间步的信息。

数学表达

RNN的核心公式如下:

[ h_t = \sigma(W_h \cdot h_{t-1} + W_x \cdot x_t + b) ]

其中:

  • ( h_t ) 是当前时间步的隐藏状态
  • ( h_{t-1} ) 是前一时间步的隐藏状态
  • ( x_t ) 是当前时间步的输入
  • ( W_h ) 和 ( W_x ) 是权重矩阵
  • ( b ) 是偏置向量
  • ( \sigma ) 是激活函数(如tanh或ReLU)

特点

  1. 处理序列数据:RNN擅长处理和预测序列数据,能够捕捉时间步之间的依赖关系。
  2. 记忆能力:通过循环连接,RNN能够记忆序列中的信息,适用于需要上下文信息的任务。
  3. 梯度消失和爆炸:RNN在处理长序列时可能会遇到梯度消失或梯度爆炸问题,影响训练效果。

变种

为了克服RNN的缺点,研究人员提出了多种RNN的变种:

  1. LSTM(长短期记忆网络):通过引入门机制(输入门、遗忘门和输出门),LSTM能够更好地捕捉长距离依赖关系。
  2. GRU(门控循环单元):GRU是LSTM的简化版,通过减少门的数量,保留了LSTM的优势,同时提高了计算效率。

应用场景

  1. 自然语言处理(NLP):如机器翻译、文本生成、语音识别等。
  2. 时间序列预测:如股票价格预测、天气预报等。
  3. 语音和视频处理:如语音识别、视频分析等。

CNN(卷积神经网络)

原理

CNN是一类专门用于处理图像数据的神经网络,通过卷积操作提取图像的空间特征。CNN利用局部连接和共享权重的机制,能够有效地处理高维图像数据。

CNN的基本组成部分包括卷积层(Convolutional Layer)、池化层(Pooling Layer)和全连接层(Fully Connected Layer)。

卷积层

卷积层通过卷积核(也称为滤波器)在输入图像上进行滑动窗口操作,提取局部特征。每个卷积核可以检测不同的特征,如边缘、角点等。

卷积操作的公式如下:

[ (I * K)(i, j) = \sum_m \sum_n I(i+m, j+n) \cdot K(m, n) ]

其中:

  • ( I ) 是输入图像
  • ( K ) 是卷积核
  • ( (i, j) ) 是卷积操作的位置
池化层

池化层通过对卷积层的输出进行下采样,减少数据量和计算量,同时保留重要的特征。常见的池化操作包括最大池化(Max Pooling)和平均池化(Average Pooling)。

全连接层

全连接层将前面的特征图展平,并进行分类或回归任务。全连接层与传统的前馈神经网络类似,所有的输入节点与输出节点相连。

特点

  1. 局部连接:卷积操作只在局部区域内进行,有效减少了参数数量。
  2. 共享权重:同一卷积核在整个图像上共享参数,减少了过拟合的风险。
  3. 平移不变性:卷积操作能够有效识别图像中的特征位置,使模型对图像的平移具有鲁棒性。

应用场景

  1. 图像分类:如手写数字识别(MNIST)、物体识别(ImageNet)等。
  2. 目标检测:如人脸检测、车辆检测等。
  3. 图像分割:如语义分割、实例分割等。
  4. 医学影像分析:如病灶检测、图像重建等。

RNN与CNN的对比

特点RNNCNN
主要应用领域序列数据处理,如NLP、时间序列预测图像数据处理,如图像分类、目标检测
数据类型时间序列、文本、语音等图像、视频等
核心操作循环连接卷积操作
处理长序列的能力可能遇到梯度消失或爆炸问题不适用于处理长序列
参数量参数量较多,尤其是长序列通过局部连接和共享权重减少参数量
平移不变性不具备平移不变性具有平移不变性
记忆能力能够记忆和处理序列中的上下文信息主要用于捕捉图像中的局部特征
变种LSTM、GRU卷积神经网络变种,如ResNet、DenseNet

总之,RNN和CNN是两种重要的神经网络架构,分别擅长处理不同类型的数据。RNN适用于处理序列数据,能够捕捉时间步之间的依赖关系,而CNN则适用于处理图像数据,能够有效提取图像的空间特征。在实际应用中,选择哪种网络架构取决于具体的任务和数据类型。

自注意力机制(Self-Attention)是Transformer模型的核心部分,它在自然语言处理(NLP)任务中取得了显著的成功。与传统的RNN和CNN不同,自注意力机制能够在序列的所有位置之间直接建立联系,使模型能够捕捉长距离的依赖关系。以下是对自注意力机制的详细介绍。

自注意力机制的基本原理

自注意力机制通过对输入序列中的每个元素计算其与其他元素的相关性(注意力分数),从而生成新的表示。这个过程可以分为以下几个步骤:

1. 计算查询(Query)、键(Key)和值(Value)

对于输入序列中的每个元素,我们首先计算三个向量:查询向量 (Q)、键向量 (K) 和值向量 (V)。这些向量通过与可训练的权重矩阵进行线性变换得到:

[ Q = XW_Q ]
[ K = XW_K ]
[ V = XW_V ]

其中,(X) 是输入序列,(W_Q)、(W_K) 和 (W_V) 是可训练的权重矩阵。

2. 计算注意力分数

接下来,我们计算查询向量 (Q) 与键向量 (K) 的点积,并进行缩放(除以 (\sqrt{d_k})),然后通过Softmax函数得到注意力分数:

[ \text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V ]

其中,(\sqrt{d_k}) 是缩放因子, (d_k) 是键向量的维度。

3. 计算加权和

使用注意力分数对值向量 (V) 进行加权求和,得到最终的输出向量。

4. 多头注意力机制

为了使模型能够捕捉更多的特征,自注意力机制通常会使用多头注意力(Multi-Head Attention)。多头注意力机制将查询、键和值向量分为多个子空间,并在每个子空间上独立计算注意力,最后将这些结果拼接在一起:

[ \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \text{head}_2, \ldots, \text{head}_h)W_O ]

其中,每个头的计算方式为:

[ \text{head}i = \text{Attention}(QW{Q_i}, KW_{K_i}, VW_{V_i}) ]

自注意力机制的优势

  1. 并行化:与RNN不同,自注意力机制可以在计算时并行处理输入序列中的所有位置,提高了训练效率。
  2. 长距离依赖:自注意力机制能够直接在序列的所有位置之间建立联系,更好地捕捉长距离依赖关系。
  3. 灵活性:自注意力机制可以用于各种类型的输入数据,不仅限于时间序列。

Transformer模型

自注意力机制是Transformer模型的核心组件。Transformer模型由编码器和解码器组成,每个编码器和解码器包含多个堆叠的自注意力层和前馈神经网络(Feed-Forward Neural Network)层。

编码器(Encoder)

每个编码器层包含两个主要部分:

  1. 多头自注意力层(Multi-Head Self-Attention Layer):计算输入序列中每个元素之间的注意力分数。
  2. 前馈神经网络层(Feed-Forward Neural Network Layer):对自注意力层的输出进行进一步处理。通常包含两个全连接层和一个ReLU激活函数。

此外,每个层还有残差连接(Residual Connection)和层归一化(Layer Normalization),以提高训练稳定性和模型性能。

解码器(Decoder)

解码器的结构与编码器类似,但包含额外的注意力层,用于计算解码器输入与编码器输出之间的注意力分数。每个解码器层包含三个主要部分:

  1. 掩码多头自注意力层(Masked Multi-Head Self-Attention Layer):类似于编码器的自注意力层,但对未来的时间步进行掩码,以防止信息泄露。
  2. 多头注意力层(Multi-Head Attention Layer):计算解码器输入与编码器输出之间的注意力分数。
  3. 前馈神经网络层(Feed-Forward Neural Network Layer):与编码器中的前馈神经网络层类似。

Transformer模型的应用

Transformer模型在各种NLP任务中取得了显著的成功,包括:

  1. 机器翻译:如Google的翻译服务,使用Transformer模型实现高质量的翻译。
  2. 文本生成:如OpenAI的GPT系列模型,可以生成连贯且有意义的文本。
  3. 文本分类:如BERT模型,通过预训练和微调,实现了各种文本分类任务的高性能。
  4. 问答系统:如BERT和GPT,可以用于构建智能问答系统,回答用户的问题。

代码示例

以下是使用PyTorch实现自注意力机制的简化示例:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into self.heads different pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys]) / (self.embed_size ** (1 / 2))

        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        attention = torch.softmax(energy, dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        out = self.fc_out(out)
        return out

# 示例输入
embed_size = 256
heads = 8
seq_length = 10
batch_size = 32

values = torch.randn(batch_size, seq_length, embed_size)
keys = torch.randn(batch_size, seq_length, embed_size)
query = torch.randn(batch_size, seq_length, embed_size)
mask = None

attention = SelfAttention(embed_size, heads)
out = attention(values, keys, query, mask)
print(out.shape)  # 输出的形状应该是 (batch_size, seq_length, embed_size)

总结

自注意力机制是现代自然语言处理模型(如Transformer)的核心,具有处理并行化、捕捉长距离依赖和灵活性强等优点。通过多头注意力机制,自注意力模型能够有效地捕捉序列中的多种特征,广泛应用于机器翻译、文本生成、文本分类和问答系统等任务。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1901286.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

图像处理调试软件推荐

对于图像处理的调试,使用具有图形用户界面(GUI)且支持实时调整和预览的图像处理软件,可以大大提高工作效率。以下是几款常用且功能强大的图像处理调试软件推荐: ImageJ/FijiMATLABOpenCV with GUI LibrariesNI Vision …

Stable Diffusion:最全详细图解

Stable Diffusion,作为一种革命性的图像生成模型,自发布以来便因其卓越的生成质量和高效的计算性能而受到广泛关注。不同于以往的生成模型,Stable Diffusion在生成图像的过程中,采用了独特的扩散过程,结合深度学习技术…

WAIC:生成式 AI 时代的到来,高通创新未来!

目录 01 在终端侧算力上,动作最快的就是高通 02 模型优化,完成最后一块拼图 在WAIC上,高通展示的生成式AI创新让我们看到了未来的曙光。 生成式 AI 的爆发带来了意想不到的产业格局变化,其速度之快令人惊叹。 仅在一个月前&…

考研必备~总结严蔚敏教授《数据结构》课程的重要知识点及考点

作者主页:知孤云出岫 目录 1. 基本概念1.1 数据结构的定义1.2 抽象数据类型 (ADT) 2. 线性表2.1 顺序表2.2 链表 3. 栈和队列3.1 栈3.2 队列 4. 树和二叉树4.1 树的基本概念4.2 二叉树 5. 图5.1 图的基本概念5.2 图的遍历 6. 查找和排序6.1 查找6.2 排序 7. 重点考…

[图解]SysML和EA建模住宅安全系统-11-接口块

1 00:00:00,660 --> 00:00:04,480 接下来的步骤是定义系统上下文 2 00:00:04,960 --> 00:00:07,750 首先是图17.17 3 00:00:09,000 --> 00:00:10,510 系统上下文展示了 4 00:00:10,520 --> 00:00:12,510 ESS和外部系统、用户 5 00:00:12,520 --> 00:00:14,1…

简介时间复杂度

好了,今天我们来了解一下,我们在做练习题中常出现的一个名词。时间复杂度。我相信大家如果有在练习过题目的话。对这个名词应该都不陌生吧。但是可能很少的去思考它是干什么的代表的什么意思。反正我以前练习的时候就是这样。我只知道有这么一个名词在题…

DevOps实战:使用GitLab+Jenkins+Kubernetes(k8s)建立CI_CD解决方案

一.系统环境 本文主要基于Kubernetes1.21.9和Linux操作系统CentOS7.4。 服务器版本docker软件版本Kubernetes(k8s)集群版本CPU架构CentOS Linux release 7.4.1708 (Core)Docker version 20.10.12v1.21.9x86_64CI/CD解决方案架构图:CI/CD解决方案架构图描述:程序员写好代码之…

测试几个 ocr 对日语的识别情况

测试几个 ocr 对日语的识别情况 1. EasyOCR2. PaddleOCR3. Deepdoc(识别pdf中图片)4. Deepdoc(识别pdf中文字)5. Nvidia neva-22b6. Claude 3.5 sonnet 识别图片中的文字7. Claude 3.5 sonnet 识别 pdf 中表格8. OpenAI gpt-4o 识…

CMS Made Simple v2.2.15 远程命令执行漏洞(CVE-2022-23906)

前言 CVE-2022-23906 是一个远程命令执行(RCE)漏洞,存在于 CMS Made Simple v2.2.15 中。该漏洞通过上传头像功能进行利用,攻击者可以上传一个经过特殊构造的图片文件来触发漏洞。 漏洞详情 CMS Made Simple v2.2.15 中的头像上…

verilog读写文件注意事项

想要的16进制数是文本格式提供的文件,想将16进制数提取到变量内, 可以使用 f s c a n f ( f d 1 , " 也可以使用 fscanf(fd1,"%h",rd_byte);实现 也可以使用 fscanf(fd1,"也可以使用readmemh(“./FILE/1.txt”,mem);//fe放在mem[0…

alphazero学习

AlphaGoZero是AlphaGo算法的升级版本。不需要像训练AlphaGo那样,不需要用人类棋局这些先验知识训练,用MCTS自我博弈产生实时动态产生训练样本。用MCTS来创建训练集,然后训练nnet建模的策略网络和价值网络。就是用MCTSPlayer产生的数据来训练和…

VRPTW(MATLAB):常春藤算法(IVY)求解带时间窗的车辆路径问题VRPTW,MATLAB代码

详细介绍 VRPTW(MATLAB):常春藤算法(Ivy algorithm,IVY)求解带时间窗的车辆路径问题VRPTW(提供MATLAB代码)-CSDN博客 ********************************求解结果******************…

web零碎知识2

不知道我的这个axios的包导进去没。 找一下关键词: http请求协议:就是进行交互式的格式 需要定义好 这个式一发一收短连接 而且没有记忆 这个分为三个部分 第一个式请求行,第二个就是请求头 第三个就是请求体 以get方式进行请求的失手请求…

C语言 -- 深入理解指针(一)

C语言 -- 深入理解指针(一) 1.内存和地址1.1 内存1.2 究竟该如何理解编址 2. 指针变量和地址2.1 取地址操作符(&)​2.2 指针变量和解引用操作符(*)​​2.2.1 指针变量2.2.2 如何拆解指针类型2.2.3 解引…

Java语言程序设计基础篇(第10版)编程练习题13.18(使用 Rational 类)

第十三章第十八题(使用 Rational 类) 题目要求: 编写程序,使用 Rational 类计算下面的求和数列: 你将会发现输出是不正确的 ,因为整数溢出(太大了)。为了解决这个问题 ,参见编程练习題13.15。代码参考: package cha…

羊大师:小暑至,热浪涌,三伏悠长防暑忙

随着夏日的脚步悄然加速,我们迎来了小暑节气。小暑,一个预示着盛夏正式拉开序幕的时节,它携带着滚滚热浪,让大地仿佛置身于火炉之中。而随之而来的三伏天,更是长达40天的酷热考验,让人不禁感叹夏日的漫长与…

文件、文本阅读与重定向、路径与理解指令——linux指令学习(一)

前言:本节内容标题虽然为指令,但是并不只是讲指令, 更多的是和指令相关的一些原理性的东西。 如果友友只想要查一查某个指令的用法, 很抱歉, 本节不是那种带有字典性质的文章。但是如果友友是想要来学习的,…

记录第一次使用air热更新golang项目

下载 go install github.com/cosmtrek/airlatest 下载时提示: module declares its path as: github.com/air-verse/air but was required as: github.com/cosmtrek/air 此时,需要在go.mod中加上这么一句: replace github.com/cosmtrek/air &…

VitePress美化

参考资料: https://blog.csdn.net/weixin_44803753/article/details/130903396 https://blog.csdn.net/qq_30678861/category_12467776.html 站点信息修改 首页部分的修改基本都在.vitepress/config.mts,这个文件内修改。 title 站点名称 description 描述 top…

轻松快速上手Thekey库,实现数据加密无忧

Thekey的概述: Thekey库是一个Python库,旨在简化数据加密、解密、签名和验证的过程。它提供了一套简洁易用的接口,用于处理各种加密任务,适合需要在应用程序中实现安全数据处理的开发人员. 安装Thekey库 pip install thekey使用Thekey库进行基本加密和解密操作的…