【模型】Encoder-Decoder模型中Attention机制的引入

news2025/1/13 9:42:21

Encoder-Decoder 模型中引入 Attention 机制,是为了改善基本Seq2Seq模型的性能,特别是当处理长序列时,传统的Encoder-Decoder模型容易面临信息压缩的困难。Attention机制可以帮助模型动态地选择源序列中相关的信息,从而提高翻译等任务的质量。

一、为什么需要Attention机制?

在基本的 Encoder-Decoder 模型中,Encoder将整个源句子的所有信息压缩成一个固定大小的向量(上下文向量),然后Decoder使用这个向量来生成目标序列。这个单一的上下文向量对于较短的句子可能足够,但对于较长的句子,模型可能无法有效捕捉到整个句子中所有重要的信息。这样容易导致信息丢失,尤其是当句子很长时,Decoder在生成目标词时可能无法获取到源句子的细节信息。

二、Attention机制的核心思想

Attention机制的核心思想是:在每个时间步生成目标单词时,Decoder不再依赖于固定的上下文向量,而是能够通过“注意力”权重,动态地从源句子的所有隐状态中选择最相关的部分。这样,Decoder每生成一个目标词时,能够更好地“关注”源句子中与当前生成词最相关的部分。

三、Attention机制的工作流程

在每一步解码时,Attention机制会根据Decoder的当前状态计算出一组权重,表示源句子中各个位置的隐状态对当前解码步骤的重要性。这些权重用于加权源句子的隐状态,以得到一个上下文向量,这个上下文向量会与当前Decoder的隐状态一起用于生成下一个目标词。

Attention的具体步骤如下:

  1. 计算注意力权重

    • 对于Decoder的每一步(生成每个目标词时),通过Decoder的当前隐状态和源句子每个时间步的隐状态来计算注意力权重。
    • 这些权重表示源句子中每个位置的重要性,可以使用加性Attention点积Attention来计算。
  2. 计算上下文向量

    • 通过将注意力权重与源句子的隐状态进行加权平均,得到一个新的上下文向量。
    • 这个上下文向量包含了源句子中当前对Decoder最重要的信息。
  3. 解码下一步

    • 将新的上下文向量与当前Decoder的隐状态结合,用于生成当前的目标词。

四、Attention机制的公式

对于每个时间步 t:

  1. 计算注意力得分:通常使用Decoder当前的隐状态 ht 和源句子每个位置的隐状态 hs 计算注意力得分,可以通过以下公式计算:

在这里插入图片描述

常见的 score 函数有加性(Bahdanau Attention)和点积(Luong Attention):

  • 加性Attention:使用一个简单的前馈网络对 ht 和 hs 进行线性变换并加和。
  • 点积Attention:直接计算 ht 和 hs 的点积。
  1. 计算注意力权重:对得分 et,s​ 进行Softmax操作,得到权重:

在这里插入图片描述

这些权重 αt,s 表示源句子中各个位置对当前解码的影响力。

  1. 计算上下文向量:使用注意力权重对源句子的隐状态进行加权平均,得到上下文向量 ct:

在这里插入图片描述

  1. 生成下一个词:将上下文向量 ct 与Decoder的隐状态 ht 结合,生成下一个词。

五、引入Attention机制的Encoder-Decoder代码实现

以下是一个带有 Attention 机制的 Encoder-Decoder 模型的简化实现,使用 PyTorch 进行构建。

import torch
import torch.nn as nn

# Encoder模型
class Encoder(nn.Module):
    def __init__(self, input_size, embedding_dim, hidden_size):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(input_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_size, batch_first=True)

    def forward(self, src):
        embedded = self.embedding(src)  # [batch_size, src_len, embedding_dim]
        outputs, (hidden, cell) = self.lstm(embedded)  # [batch_size, src_len, hidden_size]
        return outputs, hidden, cell

# Attention模型
class Attention(nn.Module):
    def __init__(self, hidden_size):
        super(Attention, self).__init__()
        self.attn = nn.Linear(hidden_size * 2, hidden_size)
        self.v = nn.Parameter(torch.rand(hidden_size))

    def forward(self, hidden, encoder_outputs):
        src_len = encoder_outputs.shape[1]
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))  # [batch_size, src_len, hidden_size]
        energy = torch.sum(self.v * energy, dim=2)  # [batch_size, src_len]
        return torch.softmax(energy, dim=1)  # [batch_size, src_len]

# Decoder模型
class Decoder(nn.Module):
    def __init__(self, output_size, embedding_dim, hidden_size):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(output_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim + hidden_size, hidden_size, batch_first=True)
        self.fc_out = nn.Linear(hidden_size * 2, output_size)
        self.attention = Attention(hidden_size)

    def forward(self, input_token, hidden, cell, encoder_outputs):
        input_token = input_token.unsqueeze(1)  # [batch_size, 1]
        embedded = self.embedding(input_token)  # [batch_size, 1, embedding_dim]
        
        # 计算注意力权重
        attn_weights = self.attention(hidden[-1], encoder_outputs)  # [batch_size, src_len]
        
        # 使用注意力权重对encoder输出进行加权平均
        attn_applied = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs)  # [batch_size, 1, hidden_size]

        # 将注意力上下文向量和嵌入层输入拼接
        lstm_input = torch.cat((embedded, attn_applied), dim=2)  # [batch_size, 1, embedding_dim + hidden_size]
        
        # 通过LSTM
        output, (hidden, cell) = self.lstm(lstm_input, (hidden, cell))  # [batch_size, 1, hidden_size]

        # 生成最终输出
        output = torch.cat((output.squeeze(1), attn_applied.squeeze(1)), dim=1)  # [batch_size, hidden_size * 2]
        prediction = self.fc_out(output)  # [batch_size, output_size]

        return prediction, hidden, cell

# Seq2Seq模型
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, tgt, teacher_forcing_ratio=0.5):
        batch_size = tgt.shape[0]
        target_len = tgt.shape[1]
        target_vocab_size = self.decoder.fc_out.out_features

        outputs = torch.zeros(batch_size, target_len, target_vocab_size).to(self.device)

        encoder_outputs, hidden, cell = self.encoder(src)

        input_token = tgt[:, 0]

        for t in range(1, target_len):
            output, hidden, cell = self.decoder(input_token, hidden, cell, encoder_outputs)
            outputs[:, t, :] = output
            top1 = output.argmax(1)
            input_token = tgt[:, t] if torch.rand(1).item() < teacher_forcing_ratio else top1

        return outputs
代码说明:
  1. Encoder

    • 编码源句子,生成隐状态和输出序列。
    • 输出序列会在注意力机制中使用。
  2. Attention

    • Attention 模型根据当前隐状态和Encoder输出计算注意力权重。
    • 使用注意力权重对Encoder输出进行加权平均,得到上下文向量。
    • 将输入的词向量和上下文向量连接,输入到LSTM中
  3. Decoder

    • Decoder在当前时间步会将 当前输入(上一个时间步生成的词)、上一个时间步的隐状态 和 注意力上下文向量 拼接起来,输入到LSTM或GRU中,更新隐状态并生成当前时间步的输出。
  4. Seq2Seq

    • 将Encoder和Decoder结合,逐步生成目标序列。
    • 使用了教师强制机制来控制训练时的输入。
Decoder代码详细解释:
  1. attn_weights = self.attention(hidden[-1], encoder_outputs):

    • hidden[-1] 是Decoder当前时间步的最后一层隐状态(对于多层LSTM来说)。encoder_outputs 是Encoder所有时间步的输出。
    • 调用 self.attention 计算当前时间步的注意力权重。
  2. attn_applied = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs):

    • attn_weights 是注意力权重,形状为 [batch_size, src_len]
    • unsqueeze(1) 将其变为 [batch_size, 1, src_len],然后与 encoder_outputs(形状为 [batch_size, src_len, hidden_size])进行批量矩阵乘法(torch.bmm)。
    • 这样得到的结果 attn_applied 是加权后的上下文向量,形状为 [batch_size, 1, hidden_size],表示根据注意力权重加权后的源句子信息。
  3. torch.cat((embedded, attn_applied), dim=2):

    • 将Decoder的当前输入(嵌入表示)和上下文向量拼接在一起,输入到LSTM中。

六、总结:

Attention机制的引入,允许Decoder在生成每个目标词时,能够动态地根据源句子的不同部分调整注意力,使得模型能够处理更长的序列,并提高生成结果的准确性。Attention机制在机器翻译等任务中取得了显著的效果,并且为之后的Transformer等模型的出现奠定了基础。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2208797.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

06-ArcGIS For JavaScript-requestAnimationFrame动画渲染

文章目录 概述setInterval&#xff08;&#xff09;与setTimeout()requestAnimationFrame()requestAnimationFrame在ArcGIS For JavaScript的应用结果 概述 本节主要讲解与时间相关的三个方法setTimeout()、setInterval()和requestAnimationFrame()&#xff0c;这三个方法都属…

怎样提高员工效率与客户转化率?

在当今商业竞争日益激烈的市场环境中&#xff0c;企业越来越注重提升员工的工作效率和客户的转化率。以下是几种高效的策略&#xff0c;可以帮助企业激发新的业务活力&#xff1a; 1、自动快捷回复消息 自动快捷回复功能&#xff0c;是提升客户沟通效率的有力工具。通过设置自…

注册电气工程师印章要求

一、边框 1.尺寸&#xff1a;长63mm、宽28mm、线宽&#xff1a;0.6mm 2.第一格&#xff1a;宽7.25mm 3.第二格&#xff1a;宽19.2mm 二、文字 1.第一行 名称&#xff1a;行长59.50mm 字高5.61mm 字体 宋体 2.第二行 姓名&#xff1a;行长42.00mm 字高5.28mm 字体 姓名 宋体 人名…

GStreamer 简明教程(六):利用 Tee 复制流数据,巧用 Queue 实现多线程

系列文章目录 GStreamer 简明教程&#xff08;一&#xff09;&#xff1a;环境搭建&#xff0c;运行 Basic Tutorial 1 Hello world! GStreamer 简明教程&#xff08;二&#xff09;&#xff1a;基本概念介绍&#xff0c;Element 和 Pipeline GStreamer 简明教程&#xff08;三…

超声波清洗机靠谱吗?适合学生党入手的四款眼镜清洗机品牌推荐!

有没有学生党还不知道双十一买什么&#xff1f;其实可以去看看超声波清洗机&#xff0c;说实话它的实用性真的很高&#xff0c;对于日常用于清洗眼镜真的非常合适&#xff0c;不仅可以帮助大家节约时间而且还能把眼镜清洗的干净透亮&#xff0c;接下来我就来为大家带来四款好用…

浅谈钓鱼攻防之道-制作免杀word文件钓鱼

梦里明明有六趣&#xff0c;觉后空空无大千 1、制作基本的word宏文件 Cobalt Strike生成宏代码 选择监听器 成功生成宏文件 新建word文档&#xff0c;点击视图——宏——查看宏 选择编辑 点击视图中的工程资源管理器 选择本文件中ThisDocument&#xff0c;将cs生成的文件复制…

MySql表结构设计

创建 create table 表名(字段1 字段类型 [约束] [comment 字段1注释],...) [comment 表注释];约束是作用于表中字段上的规则&#xff0c;用于限制存储在表中的数据。它的目的是保证数据库中数据的正确性、有效性和完整性。 约束描述关键字非空约束限制该字段不能为nullnot nu…

某宝228滑块和普通228滑块之间的区别

声明&#xff1a; 该文章为学习使用&#xff0c;严禁用于商业用途和非法用途&#xff0c;违者后果自负&#xff0c;由此产生的一切后果均与作者无关。 本文章未经许可禁止转载&#xff0c;禁止任何修改后二次传播&#xff0c;擅自使用本文讲解的技术而导致的任何意外&#xff…

C++11 新特性 学习笔记

C11 新特性 | 侯捷C11学习笔记 笔者作为侯捷C11新特性课程的笔记进行记录&#xff0c;供自己查阅方便 文章目录 C11 新特性 | 侯捷C11学习笔记1.Variadic TemplatesC11支持函数模板的默认模板参数C11在函数模板和类模板中使用可变参数 可变参数模板1) 可变参数函数模板2) 可变…

Spring Boot 项目中 Redis 与数据库性能对比实战:从缓存配置到时间分析,详解最佳实践

一、前言&#xff1a; 在现代应用中&#xff0c;随着数据量的增大和访问频率的提高&#xff0c;如何提高数据存取的性能变得尤为重要。缓存技术作为一种常见的优化手段&#xff0c;被广泛应用于减少数据库访问压力、提升系统响应速度。Redis 作为一种高效的内存缓存数据库&…

DBeaver连接mysql 9报错:Public Key Retrieval is not allowed

DBeaver连接mysql 9报错&#xff1a;Public Key Retrieval is not allowed 如图&#xff1a; 解决方案 编辑连接属性&#xff1a; 修改 allowPublicKeyRetrieval 的值为 true DBeaver连接mysql数据库执行.sql脚本&#xff0c;Windows_dbeaver执行sql脚本.sql文件-CSDN博客文章…

【Java】 —— 数据结构与集合源码:Vector、LinkedList在JDK8中的源码剖析

目录 7.2.4 Vector部分源码分析 7.3 链表LinkedList 7.3.1 链表与动态数组的区别 7.3.2 LinkedList源码分析 启示与开发建议 7.2.4 Vector部分源码分析 jdk1.8.0_271中&#xff1a; //属性 protected Object[] elementData; protected int elementCount;//构造器 public …

2024ccna考试时间?新手小白看这些就够了

2024年想要考取ccna证书的新手小白们&#xff0c;是不是正在为考试时间而烦恼呀&#xff0c;其实ccna的考试时间其实非常灵活&#xff0c;并不需要像其他考试那样死记硬背固定的日期。那么小编马上就给大家说说2024ccna考试时间&#xff0c;并且附带一些考试内容&#xff0c;让…

LeetCode[简单] 70. 爬楼梯

假设你正在爬楼梯。需要 n 阶你才能到达楼顶。 每次你可以爬 1 或 2 个台阶。你有多少种不同的方法可以爬到楼顶呢&#xff1f; 思路 利用滚动数组 public class Solution {public int ClimbStairs(int n) { //滚动数组int f0 0, f1 0, f2 1;for(int i 1; i < n; i){…

拱坝与重力坝:结构特性与应用差异的深度解析

在水利工程领域&#xff0c;坝体结构的选择对于工程的稳定性、安全性以及经济效益具有至关重要的影响。拱坝与重力坝作为两种主要的坝型&#xff0c;各自具有独特的结构特性和应用场景。本文旨在深入探讨拱坝与重力坝的区别&#xff0c;从工程特点、坝体结构型式、材料选用、水…

国内如何下载谷歌浏览器(chrome浏览器)历史版本和chromedriver驱动,长期更新,建议收藏

众所周知&#xff0c;google是一直被国内屏蔽的&#xff0c;有时候想要下载个chrome浏览器都要去外网&#xff0c;或者到处去搜索才能下载到。因为下载chrome浏览器的这个网址&#xff1a;google.com/chrome/ 在国内是一直被屏蔽掉的。 今天主要讲解的是国内ChromeDriver 的下…

ET实现游戏中的红点提示系统(服务端)

目录 ☝&#x1f913;前言 ☝&#x1f913;一、实现思路 ☝&#x1f913;二、实现 &#x1f920;2.1 定义红点组件 &#x1f920;2.2 定义Proto消息体 &#x1f920;2.3 RedPointComponentSystem ☝&#x1f913;难点 ☝&#x1f913;前言 当我们闲来无事时打开农药想消…

数据结构-5.3.二叉树的定义和基本术语

一.二叉树的基本概念&#xff1a; 树是一种递归定义的数据结构&#xff0c;因此二叉树是递归定义的数据结构。 二.二叉树的五种状态&#xff1a; 三.几个特殊的二叉树&#xff1a; 1.满二叉树&#xff1a;结点总数就是通过等比数列公式求出来的&#xff0c;首项为1即根节点&a…

【网络协议】TCP协议常用机制——延迟应答、捎带应答、面向字节流、异常处理,保姆级详解,建议收藏

&#x1f490;个人主页&#xff1a;初晴~ &#x1f4da;相关专栏&#xff1a;计算机网络那些事 前几篇文章&#xff0c;博主带大家梳理了一下TCP协议的几个核心机制&#xff0c;比如保证可靠性的 确认应答、超时重传 机制&#xff0c;和提高传输效率的 滑动窗口及其相关优化机…

C++ Builder XE12关于KonopkaControls与TMS VCL UI Pack组件的安装

1、先打开open project&#xff0c;选中安装的组件工程&#xff0c;并打开。 2、在option中设置 3、点击编译并进行安装install