【自然语言处理】Encoder-Decoder模型中Attention机制的引入

news2024/11/24 5:20:25

Encoder-Decoder 模型中引入 Attention 机制,是为了改善基本Seq2Seq模型的性能,特别是当处理长序列时,传统的Encoder-Decoder模型容易面临信息压缩的困难。Attention机制可以帮助模型动态地选择源序列中相关的信息,从而提高翻译等任务的质量。

一、为什么需要Attention机制?

在基本的 Encoder-Decoder 模型中,Encoder将整个源句子的所有信息压缩成一个固定大小的向量(上下文向量),然后Decoder使用这个向量来生成目标序列。这个单一的上下文向量对于较短的句子可能足够,但对于较长的句子,模型可能无法有效捕捉到整个句子中所有重要的信息。这样容易导致信息丢失,尤其是当句子很长时,Decoder在生成目标词时可能无法获取到源句子的细节信息。

二、Attention机制的核心思想

Attention机制的核心思想是:在每个时间步生成目标单词时,Decoder不再依赖于固定的上下文向量,而是能够通过“注意力”权重,动态地从源句子的所有隐状态中选择最相关的部分。这样,Decoder每生成一个目标词时,能够更好地“关注”源句子中与当前生成词最相关的部分。

三、Attention机制的工作流程

在每一步解码时,Attention机制会根据Decoder的当前状态计算出一组权重,表示源句子中各个位置的隐状态对当前解码步骤的重要性。这些权重用于加权源句子的隐状态,以得到一个上下文向量,这个上下文向量会与当前Decoder的隐状态一起用于生成下一个目标词。由于它跨越两个序列:源语言序列(编码器输出)作为 Key 和 Value;目标语言序列(解码器的当前状态)作为 Query,因此也叫交叉注意力

Attention的具体步骤如下:

  1. 计算注意力权重

    • 对于Decoder的每一步(生成每个目标词时),通过Decoder的当前隐状态和源句子每个时间步的隐状态来计算注意力权重。
    • 这些权重表示源句子中每个位置的重要性,可以使用加性Attention点积Attention来计算。
  2. 计算上下文向量

    • 通过将注意力权重与源句子的隐状态进行加权平均,得到一个新的上下文向量。
    • 这个上下文向量包含了源句子中当前对Decoder最重要的信息。
  3. 解码下一步

    • 将新的上下文向量与当前Decoder的隐状态结合,用于生成当前的目标词。

四、Attention机制的公式

对于每个时间步 t:

  1. 计算注意力得分:通常使用Decoder当前的隐状态 ht 和源句子每个位置的隐状态 hs 计算注意力得分,可以通过以下公式计算:

在这里插入图片描述

常见的 score 函数有加性(Bahdanau Attention)和点积(Luong Attention):

  • 加性Attention:使用一个简单的前馈网络对 ht 和 hs 进行线性变换并加和。
  • 点积Attention:直接计算 ht 和 hs 的点积。
  1. 计算注意力权重:对得分 et,s​ 进行Softmax操作,得到权重:

在这里插入图片描述

这些权重 αt,s 表示源句子中各个位置对当前解码的影响力。

  1. 计算上下文向量:使用注意力权重对源句子的隐状态进行加权平均,得到上下文向量 ct:

在这里插入图片描述

  1. 生成下一个词:将上下文向量 ct 与Decoder的隐状态 ht 结合,生成下一个词。

五、引入Attention机制的Encoder-Decoder代码实现

以下是一个带有 Attention 机制的 Encoder-Decoder 模型的简化实现,使用 PyTorch 进行构建。

import torch
import torch.nn as nn

# Encoder模型
class Encoder(nn.Module):
    def __init__(self, input_size, embedding_dim, hidden_size):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(input_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_size, batch_first=True)

    def forward(self, src):
        embedded = self.embedding(src)  # [batch_size, src_len, embedding_dim]
        outputs, (hidden, cell) = self.lstm(embedded)  # [batch_size, src_len, hidden_size]
        return outputs, hidden, cell

# Attention模型
class Attention(nn.Module):
    def __init__(self, hidden_size):
        super(Attention, self).__init__()
        self.attn = nn.Linear(hidden_size * 2, hidden_size)
        self.v = nn.Parameter(torch.rand(hidden_size))

    def forward(self, hidden, encoder_outputs):
        src_len = encoder_outputs.shape[1]
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))  # [batch_size, src_len, hidden_size]
        energy = torch.sum(self.v * energy, dim=2)  # [batch_size, src_len]
        return torch.softmax(energy, dim=1)  # [batch_size, src_len]

# Decoder模型
class Decoder(nn.Module):
    def __init__(self, output_size, embedding_dim, hidden_size):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(output_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim + hidden_size, hidden_size, batch_first=True)
        self.fc_out = nn.Linear(hidden_size * 2, output_size)
        self.attention = Attention(hidden_size)

    def forward(self, input_token, hidden, cell, encoder_outputs):
        input_token = input_token.unsqueeze(1)  # [batch_size, 1]
        embedded = self.embedding(input_token)  # [batch_size, 1, embedding_dim]
        
        # 计算注意力权重
        attn_weights = self.attention(hidden[-1], encoder_outputs)  # [batch_size, src_len]
        
        # 使用注意力权重对encoder输出进行加权平均
        attn_applied = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs)  # [batch_size, 1, hidden_size]

        # 将注意力上下文向量和嵌入层输入拼接
        lstm_input = torch.cat((embedded, attn_applied), dim=2)  # [batch_size, 1, embedding_dim + hidden_size]
        
        # 通过LSTM
        output, (hidden, cell) = self.lstm(lstm_input, (hidden, cell))  # [batch_size, 1, hidden_size]

        # 生成最终输出
        output = torch.cat((output.squeeze(1), attn_applied.squeeze(1)), dim=1)  # [batch_size, hidden_size * 2]
        prediction = self.fc_out(output)  # [batch_size, output_size]

        return prediction, hidden, cell

# Seq2Seq模型
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, tgt, teacher_forcing_ratio=0.5):
        batch_size = tgt.shape[0]
        target_len = tgt.shape[1]
        target_vocab_size = self.decoder.fc_out.out_features

        outputs = torch.zeros(batch_size, target_len, target_vocab_size).to(self.device)

        encoder_outputs, hidden, cell = self.encoder(src)

        input_token = tgt[:, 0]

        for t in range(1, target_len):
            output, hidden, cell = self.decoder(input_token, hidden, cell, encoder_outputs)
            outputs[:, t, :] = output
            top1 = output.argmax(1)
            input_token = tgt[:, t] if torch.rand(1).item() < teacher_forcing_ratio else top1

        return outputs
代码说明:
  1. Encoder

    • 编码源句子,生成隐状态和输出序列。
    • 输出序列会在注意力机制中使用。
  2. Attention

    • Attention 模型根据当前隐状态和Encoder输出计算注意力权重。
  3. Decoder

    • 使用Attention得到的注意力权重对Encoder输出进行加权平均,得到上下文向量。
    • Decoder在当前时间步会将 当前输入(上一个时间步生成的词)、上一个时间步的隐状态 和 注意力上下文向量 拼接起来,输入到LSTM或GRU中,更新隐状态并生成当前时间步的输出。
  4. Seq2Seq

    • 将Encoder和Decoder结合,逐步生成目标序列。
    • 使用了教师强制机制来控制训练时的输入。
Decoder代码详细解释:
  1. attn_weights = self.attention(hidden[-1], encoder_outputs):

    • hidden[-1] 是Decoder当前时间步的最后一层隐状态(对于多层LSTM来说)。encoder_outputs 是Encoder所有时间步的输出。
    • 调用 self.attention 计算当前时间步的注意力权重。
  2. attn_applied = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs):

    • attn_weights 是注意力权重,形状为 [batch_size, src_len]
    • unsqueeze(1) 将其变为 [batch_size, 1, src_len],然后与 encoder_outputs(形状为 [batch_size, src_len, hidden_size])进行批量矩阵乘法(torch.bmm)。
    • 这样得到的结果 attn_applied 是加权后的上下文向量,形状为 [batch_size, 1, hidden_size],表示根据注意力权重加权后的源句子信息。
  3. torch.cat((embedded, attn_applied), dim=2):

    • 将Decoder的当前输入(嵌入表示)和上下文向量拼接在一起,输入到LSTM中。

六、总结:

Attention机制的引入,允许Decoder在生成每个目标词时,能够动态地根据源句子的不同部分调整注意力,使得模型能够处理更长的序列,并提高生成结果的准确性。Attention机制在机器翻译等任务中取得了显著的效果,并且为之后的Transformer等模型的出现奠定了基础。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2216366.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

硬盘文件误删:原因、恢复方案与预防措施

一、硬盘文件误删现象描述 在日常使用电脑的过程中&#xff0c;硬盘文件误删是一个常见且令人头疼的问题。许多用户在进行文件整理、删除无用资料或进行系统清理时&#xff0c;一不小心就可能将重要文件误删。这些误删的文件可能包括工作文档、学习资料、家庭照片、视频等&…

【含文档】基于Springboot+Vue的采购管理系统(含源码+数据库+lw)

1.开发环境 开发系统:Windows10/11 架构模式:MVC/前后端分离 JDK版本: Java JDK1.8 开发工具:IDEA 数据库版本: mysql5.7或8.0 数据库可视化工具: navicat 服务器: SpringBoot自带 apache tomcat 主要技术: Java,Springboot,mybatis,mysql,vue 2.视频演示地址 3.功能 系统定…

SpringBoot实现桂林旅游的智能推荐

3系统分析 3.1可行性分析 通过对本桂林旅游景点导游平台实行的目的初步调查和分析&#xff0c;提出可行性方案并对其一一进行论证。我们在这里主要从技术可行性、经济可行性、操作可行性等方面进行分析。 3.1.1技术可行性 本桂林旅游景点导游平台采用SSM框架&#xff0c;JAVA作…

基于Docker安装Grafana及其基本功能

Grafana是一款用Go语言开发的开源数据可视化工具&#xff0c;可以做数据监控和数据统计&#xff0c;带有告警功能。 拉取Grafana镜像 docker pull grafana/grafana 运行镜像 docker run -d -p 3000:3000 --namegrafana grafana/grafana 打开浏览器&#xff0c;访问 http://l…

【Vue】Vue2(10)

文章目录 1 过度与动画1.1 Test.vue1.2 Test2.vue1.3 Test3.vue1.4 TodoList_动画&#xff1a;MyItem.vue 2 配置代理服务器2.1 方法一2.2 方法二2.3 vue.config.js2.4 App.vue 3 github搜索案例3.1 静态页面3.2 Search.vue3.3 List.vue3.4 App.vue3.5 main.js3.6 github搜索案…

免费插件集-illustrator插件-Ai插件-路径点到点连线

文章目录 1.介绍2.安装3.通过窗口>扩展>知了插件4.功能解释5.总结 1.介绍 本文介绍一款免费插件&#xff0c;加强illustrator使用人员工作效率&#xff0c;实现简单路径内部点到点连线功能。首先从下载网址下载这款插件 https://download.csdn.net/download/m0_67316550…

打造卓越APP体验:13款界面设计软件推荐

你知道如何选择正确的UI设计软件吗&#xff1f;你知道设计美观的用户界面&#xff0c;及带来良好用户体验的APP&#xff0c;需要什么界面设计软件吗&#xff1f;基于APP界面的功能不同&#xff0c;选择的APP界面设计软件也会有所不同。然而&#xff0c;并不是要把所有APP界面设…

1.2.3 TCP IP模型

TCP/IP模型&#xff08;接网叔用&#xff09; 网络接口层 网络层 传输层 应用层 理念&#xff1a;如果某些应用需要“数据格式转换”“会话管理功能”&#xff0c;就交给应用层的特定协议去实现 tip&#xff1a;数据 局部正确不等于全局正确 但是&#xff0c;数据的 全局正…

docker (desktopcompose) download

docker docker-compose download 百度网盘获取离线包链接release-notes 参考dockerdocker-composewlspowershell

基于Spring Boot的大创项目成本控制系统

2相关技术 2.1 MYSQL数据库 MySQL是一个真正的多用户、多线程SQL数据库服务器。 是基于SQL的客户/服务器模式的关系数据库管理系统&#xff0c;它的有点有有功能强大、使用简单、管理方便、安全可靠性高、运行速度快、多线程、跨平台性、完全网络化、稳定性等&#xff0c;非常…

Linux下ClamAV源代码安装与使用说明

Linux下ClamAV源代码安装与使用说明 ClamAV(Clam AntiVirus)是一款开源的防病毒工具,广泛应用于Linux平台上的网络安全领域。它以其高效的性能和灵活的配置选项,成为网络安全从业人员的重要工具。ClamAV支持多线程扫描,可以自动升级病毒库,并且支持多个操作系统,包括Li…

扫普通链接二维码打开小程序

1. 2.新增规则&#xff08;注意下载文件到跟目录下&#xff0c;需要建个文件夹放下载的校验文件&#xff09; 3.发布 ps&#xff1a;发布后&#xff0c;只能访问正式版本。体验版本如果加了 测试链接http://xxx/xsc/10 那么http://xxx/xsc/aa.....应该都能访问 例如aa101 aa…

5 -《本地部署开源大模型》在Ubuntu 22.04系统下ChatGLM3-6B高效微调实战

在Ubuntu 22.04系统下ChatGLM3-6B高效微调实战 无论是在单机单卡&#xff08;一台机器上只有一块GPU&#xff09;还是单机多卡&#xff08;一台机器上有多块GPU&#xff09;的硬件配置上启动ChatGLM3-6B模型&#xff0c;其前置环境配置和项目文件是相同的。如果大家对配置过程还…

前端excel的实现方案Luckysheet

一、介绍 Luckysheet是一款纯前端类似excel的在线表格&#xff0c;功能强大、配置简单、完全开源的插件。目前已暂停维护&#xff0c;但是其已有功能大概能满足常见需求的使用。 二、引入 ①cdn引入&#xff08;目前应该已经不支持&#xff0c;可自行尝试&#xff09; <l…

第二十七篇:传输层讲解,TCP系列一

一、传输层的功能 ① 分割与重组数据 传输层也要做数据分割&#xff0c;所以必然也需要做数据重组。 ② 按端口号寻址 IP只能定位数据哪台主机&#xff0c;无法判断数据报文应该交给哪个应用&#xff0c;传输层给每个应用都设置了一个编号&#xff0c;这个编号就是端口&…

大数据毕业设计选题推荐-电影数据分析系统-电影推荐系统-Python数据可视化-Hive-Hadoop-Spark

✨作者主页&#xff1a;IT研究室✨ 个人简介&#xff1a;曾从事计算机专业培训教学&#xff0c;擅长Java、Python、微信小程序、Golang、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇⬇⬇ Java项目 Python…

大模型应用开发:如何在网页中嵌入3D人物

要实现的效果如图所示&#xff1a; 左侧是插入的3D人物&#xff0c;类似AI智能助手的角色。 我们这里是通过React做的。需要用到以下工具或者网站&#xff1a; readyplayer.me/ 自定义3D人物Blender 3维设计软件&#xff0c;3D文件格式转化&#xff0c;主要是fbx和glb的互转w…

【Docker】安装部署项目流程(Pycharm版)

安装部署步骤 1.准备项目 第一步要准备好你所需要部署的项目&#xff0c;确保在工作目录下所以程序.py文件正常调用并能正确运行 如上&#xff0c;main要在工作目录中能跑通&#xff0c;这里有一点需要注意 在IDE src不要标记为源代码根目录&#xff0c;观察一下是否能跑通代…

React国际化中英文切换实现

目录 概况 安装 文件结构 引入 使用 正常使用 传参使用 概况 react-intl-universal 是一个国际化库&#xff0c;专门为 React 应用提供多语言支持。与 React 原生的 react-intl 相比&#xff0c;react-intl-universal 支持从远程服务器加载语言包&#xff0c;动态切换语…

【途牛旅游网-注册/登录安全分析报告】

前言 由于网站注册入口容易被黑客攻击&#xff0c;存在如下安全问题&#xff1a; 暴力破解密码&#xff0c;造成用户信息泄露短信盗刷的安全问题&#xff0c;影响业务及导致用户投诉带来经济损失&#xff0c;尤其是后付费客户&#xff0c;风险巨大&#xff0c;造成亏损无底洞…