机器翻译之Bahdanau注意力机制在Seq2Seq中的应用

news2024/11/15 17:29:02

目录

1.创建 添加了Bahdanau的decoder 

2. 训练

 3.定义评估函数BLEU

 4.预测

 5.知识点个人理解


1.创建 添加了Bahdanau的decoder 

import torch
from torch import nn
import dltools


#定义注意力解码器基类
class AttentionDecoder(dltools.Decoder):  #继承dltools.Decoder写注意力编码器的基类
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        
    @property    #装饰器, 定义的函数方法可以像类的属性一样被调用
    def attention_weights(self):
        #raise用于引发(或抛出)异常
        raise NotImplementedError  #通常用于抽象基类中,作为占位符,提醒子类必须实现这个方法。 


#创建 添加了Bahdanau的decoder
#继承AttentionDecoder这个基类创建Seq2SeqAttentionDecoder子类, 子类必须实现父类中NotImplementedError占位的方法
class Seq2SeqAttentionDecoder(AttentionDecoder):  
    #初始化属性和方法
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):
        """
        vocab_size:此表大小,  相当于输入数据的特征数features,  也是输出数据的特征数
        embed_size:嵌入层的大小:将输入数据处理成小批量的数据
        num_hiddens:隐藏层神经元的数量
        num_layers:循环网络的层数
        dropout=0:不释放模型的参数(比如:神经元)
        """
        super().__init__(**kwargs)
        
        #初始化注意力机制的评分函数方法
        self.attention = dltools.AdditiveAttention(key_size=num_hiddens,
                                                   query_size=num_hiddens, 
                                                   num_hiddens=num_hiddens,
                                                   dropout=dropout)
        #初始化嵌入层:将输入的数据处理成小批量的tensor数据   (文本--->数值的映射转化)
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size)
        #初始化循环网络
        self.rnn = nn.GRU(embed_size+num_hiddens, num_hiddens, num_layers, dropout=dropout)
        #初始化线性层  (输出层)
        self.dense = nn.Linear(num_hiddens, vocab_size)
        
    #初始化隐藏层的状态state   (计算state,需要编码器的输出结果、序列的有效长度)
    def init_state(self, enc_outputs, enc_valid_lens, *args):
        #enc_outputs是一个元组(输出结果,隐藏状态)
        #outputs的shape=(batch_size, num_steps, num_hiddens)
        #hidden_state的shape=(num_layers, batch_size, num_hiddens)
        outputs, hidden_state = enc_outputs
        #返回一个元组(,),可以用一个变量接收
        #outputs.permute(1, 0, 2)转换数据的维度是因为rnn循环神经网络的输入要求是先num_steps,再batch_size,
        return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)
    
    #定义前向传播   (输入数据X,state)
    def forward(self, X, state):
        #变量赋值:接收编码器encoder的输出结果、隐藏状态、序列有效长度
        #enc_outputs的shape=(batch_size, num_steps, num_hiddens)
        #hidden_state的shape=(num_layers, batch_size, num_hiddens)
        enc_outputs, hidden_state, enc_valid_lens = state
        #X的shape=(batch_size, num_steps, vocab_size)
        X = self.embedding(X)   #将X输入embedding嵌入层后, X的shape=(batch_size, num_steps, embed_size)
        #调换X的0维度和1维度数据
        X = X.permute(1, 0, 2)   #X的shape=(num_steps, batch_size, embed_size)
        outputs, self._attention_weights = [], []  #创建空列表,用于存储数据
        
        for x in X:  #遍历每一批数据
            #获取query
            #hidden_state[-1]表示最后一层循环网络的隐藏层状态  (有两层循环网络)
            #hidden_state[-1]的shape=(batch_size, num_hiddens)    #dim=1表示在原索引1的维度增加一个维度
            query = torch.unsqueeze(hidden_state[-1], dim=1)  
#             print('query的shape:', query.shape)   #query的shape=(batch_size, 1, num_hiddens)
            #通过注意力机制获取上下文序列
            context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)
#             print('context的shape:', context.shape)  #context的shape=(batch_size, 1, num_hiddens)
            #用最后一个维度 拼接context, x 数据
            x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
#             print('x的shape:', x.shape)   #x的shape=(batch_size, 1, num_hiddens+embed_size)
            #将x和hidden_state输入循环神经网络中,获取输出结果和新的hidden_state
            out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
#             print('out的shape:', out.shape)   #out的shape=(1, batch_size, num_hiddens)
#             print('hidden_state的shape:', hidden_state.shape) #两层循环层:hidden_state的shape=(2, batch_size, num_hiddens)
            #将输出结果添加到列表中
            outputs.append(out)
            self._attention_weights.append(self.attention_weights)
        
        outputs = self.dense(torch.cat(outputs, dim=0))
#         print('outputs的shape:', outputs.shape)  #outputs的shape=(num_steps, batch_size, vocab_size)
        return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]
    
    @property
    def attention_weights(self):
        return self._attention_weights



#测试代码
#创建编码器对象
encoder = dltools.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
#需要预测, 要加encoder.eval()
encoder.eval()
#创建解码器对象
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
decoder.eval()

#假设数据
batch_size, num_steps = 4, 7
X = torch.zeros((4, 7), dtype = torch.long)
#初始化状态state
state = decoder.init_state(encoder(X), None)
outputs, state = decoder(X, state)
#state包含三个东西(enc_outputs, hidden_state, enc_valid_lens)
#state[0]是 enc_outputs
#state[1]是 hidden_state, 两层循环层,就会有两个hidden_state, state[1][0]是第一层的hidden_state
outputs.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
query的shape: torch.Size([4, 1, 16])
context的shape: torch.Size([4, 1, 16])
x的shape: torch.Size([4, 1, 24])
out的shape: torch.Size([1, 4, 16])
hidden_state的shape: torch.Size([2, 4, 16])
outputs的shape: torch.Size([7, 4, 10])

Out[11]:

(torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([4, 16]))

2. 训练

#声明变量
embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 200, dltools.try_gpu()

#加载数据
train_iter, src_vocab, tgt_vocab = dltools.load_data_nmt(batch_size, num_steps)

#创建编辑器对象
encoder = dltools.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
#创建编辑器对象
decoder = Seq2SeqAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)

#创建网络模型
net = dltools.EncoderDecoder(encoder, decoder)

#模型训练
dltools.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

 

 3.定义评估函数BLEU

def bleu(pred_seq, label_seq, k):
    print('pred_seq:', pred_seq)
    print('label_seq:', label_seq)
    #将pred_seq, label_seq分别进行空格分隔
    pred_tokens, label_tokens = pred_seq.split(' '), label_seq.split(' ')
    #获取pred_seq, label_seq的长度
    len_pred, len_label = len(pred_seq), len(label_seq)
    
    score = math.exp(min(0, 1 - (len_label / len_pred)))
    for n in range(1, k+1): #n的取值范围,  range()左闭右开
        num_matches, label_subs = 0, collections.defaultdict(int)
        for i in range(len_label - n + 1):
            label_subs[' '.join(label_tokens[i: i+n])] += 1
            
        for i in range(len_pred - n + 1):
            if label_subs[' '.join(pred_tokens[i: i+n])] > 0:
                num_matches += 1
                label_subs[' '.join(pred_tokens[i: i+n])] -=1
        score *= math.pow(num_matches / (len_pred -n + 1), math.pow(0.5, n))
    return score

 4.预测

import math
import collections


engs = ['go .', 'i lost .', 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
    translation = dltools.predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device)
    print(f'{eng} => {translation}, bleu {dltools.bleu(translation[0], fra, k=2):.3f}')

go . => ('va !', []), bleu 1.000
i lost . => ("j'ai perdu .", []), bleu 1.000
he's calm . => ('il est bon .', []), bleu 0.658
i'm home . => ('je suis chez moi .', []), bleu 1.000

 5.知识点个人理解

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2152701.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

为什么年轻人都热衷找搭子,而不是找对象?

在繁华的都市中,有一个名叫晓悦的年轻人。晓悦每天穿梭于忙碌的工作和快节奏的生活之间,渐渐地,她发现身边的朋友们都开始找起了 “搭子”。 有一天,晓悦结束了一天疲惫的工作,坐在咖啡店里,看着窗外匆匆而…

在SpringBoot项目中利用Redission实现布隆过滤器(布隆过滤器的应用场景、布隆过滤器误判的情况、与位图相关的操作)

文章目录 1. 布隆过滤器的应用场景2. 在SpringBoot项目利用Redission实现布隆过滤器3. 布隆过滤器误判的情况4. 与位图相关的操作5. 可能遇到的问题(Redission是如何记录布隆过滤器的配置参数的)5.1 问题产生的原因5.2 解决方案5.2.1 方案一:…

DBeaver启动报错 Faild to load the JNI shared library

DBeaver启动报错 Faild to load the JNI shared library 问题现象 安装完成后,启动dbeaver报错 查看版本为64位版本,JAVA也为64为版本 dbeaver版本 java版本 解决 在dberver.ini添加指定配置,即可启动成功

msvcp100.dll是什么意思?msvcp100.dll丢失有什么可靠的解决方法

当我们在电脑中试图启动某些程序或游戏时,可能会遇到一个错误消息:"程序无法启动,因为计算机缺少msvcp100.dll"。其实遇到这种情况是非常的常见的,只要你是经常使用电脑的人,我们要解决它也非常的简单&#…

工作中遇到的问题总结(1)

文章目录 第一题问题描述解决思路 第二题问题描述解决思路核心大表如何优化数据迁移过程是怎么样的如何将流量从旧系统迁移到新系统上 第三题问题描述解决思路 第四题问题描述解决思路方案一:双写机制方案二:基于时间戳的分流机制方案三:灰度…

数据结构之线性表——LeetCode:707. 设计链表,206. 反转链表,92. 反转链表 II

707. 设计链表 题目描述 707. 设计链表 你可以选择使用单链表或者双链表,设计并实现自己的链表。 单链表中的节点应该具备两个属性:val 和 next 。val 是当前节点的值,next 是指向下一个节点的指针/引用。 如果是双向链表,则…

【滑动窗口】算法总结

文章目录 滑动窗口算法总结1.暴力求解vs滑动窗口2.需要注意的细节问题 2.滑动窗口的基本模板1.非固定窗口大小的滑动窗口2.固定窗口大小的滑动窗口细节 滑动窗口算法总结 1.暴力求解vs滑动窗口 遇到那些可以转化成一个子数组的长度的问题时,往往需要用到双指针。 …

二,MyBatis -Plus 关于映射 Java Bean 对象的注意事项和细节(详细说明)

二,MyBatis -Plus 关于映射 Java Bean 对象的注意事项和细节(详细说明) 文章目录 二,MyBatis -Plus 关于映射 Java Bean 对象的注意事项和细节(详细说明)1. 映射2. 表的映射3. 字段映射4. 字段失效5. 视图属性6. 总结:7. 最后: 1.…

【C/C++】速通涉及string类的经典编程题

【C/C】速通涉及string类的经典编程题 一.字符串最后一个单词的长度代码实现:(含注释) 二.验证回文串解法一:代码实现:(含注释) 解法二:(推荐)1. 函数isalnum…

单卡3090 选用lora微调ChatGLM3-6B

环境配置 Python 3.10.12 transformers 4.36.2 torch 2.0.1 下载demo代码 在官方网址https://github.com/THUDM/ChatGLM3/blob/main/finetune_demo 下载demo代码cd 进入文件夹 pip install -r requirements.txt 安装一些包 基本知识 SFT 全量微调: 4张显卡平均分配&#…

13年计算机考研408-数据结构

解析: 这个降序链表不影响时间复杂度,因为是链表,所以你想要升序就使用头插法,你想要降序就使用尾插法。 然后我们来分析一下最坏的情况是什么样的。 因为m和n都是两个有序的升序序列。 如果刚好m的最大值小于n的最小值&#xff0…

AI宠物拟人化新玩法,教你如何用0成本打造爆款创意内容!

近年来,随着AI技术的快速发展,各种创新玩法不断涌现,尤其是在内容创作领域,AI带来的变革尤为显著。 **其中,宠物拟人化逐渐成为社交媒体上的一大热门话题。**通过AI生成工具,我们不仅可以将宠物拟人化&…

Snapchat API 访问:Objective-C 实现示例

Snapchat 是一个流行的社交媒体平台,它允许用户发送和接收短暂存在的图片和视频。对于开发者来说,访问 Snapchat API 可以为应用程序添加独特的社交功能。本文将介绍如何在 Objective-C 中实现对 Snapchat API 的访问,并提供一个详细的代码示…

GD32F103单片机-EXTI外部中断

GD32F103单片机-EXTI外部中断 一、EXTI及NVIC介绍二、编程实验2.1 相关库函数2.2 实验代码 一、EXTI及NVIC介绍 GD32和STM32的EXTI基本相似,具体见STM32F1单片机-外部中断GD32的EXTI包括20个相互独立的边沿检测电路请求产生中断或事件,4位优先级配置寄存…

热像仪是如何工作的?

红外热像仪是一种非接触式设备,能够检测红外能量(热量)并将其转变成可见光图像。让我们深入了解红外热像仪的科学原理,以及借助红外热像仪我们能够看到的隐形世界。 捕捉红外波,而不是可见光 首先必须清楚的是&#…

windows环境下配置MySQL主从启动失败 查看data文件夹中.err发现报错unknown variable ‘log‐bin=mysql‐bin‘

文章目录 问题解决方法 问题 今天在windows环境下配置MySQL主从同步,在修改my.ini文件后发现MySQL启动失败了 打开my.ini检查参数发现没有问题 [mysqld] #开启二进制日志,记录了所有更改数据库数据的SQL语句 log‐bin mysql‐bin #设置服务id&#x…

Vue(13)——router-link

router-link vue-router提供了一个全局组件router-link(取代a标签) 能跳转,配置to属性指定路径(必须)。本质还是a标签。默认会提供高亮类名,可以直接设置高亮样式 右键检查,发现多了两个类: 可以直接写样式…

Java数据结构专栏介绍

专栏导读 在软件工程的世界里,数据结构是构建高效、可靠程序的基石。"Java数据结构"专栏致力于为Java开发者提供一个全面、深入的学习平台,帮助他们掌握各种数据结构的原理、实现及其在Java中的应用。通过这个专栏,读者将能够提升…

IPsec-Vpn

网络括谱图 IPSec-VPN 配置思路 1 配置IP地址 FWA:IP地址的配置 [FW1000-A]interface GigabitEthernet 1/0/0 [FW1000-A-GigabitEthernet1/0/0]ip address 10.1.1.1 24 [FW1000-A]interface GigabitEthernet 1/0/2 [FW1000-A-GigabitEthernet1/0/2]ip address

分布式Id生成策略-美团Leaf

之前在做物流相关的项目时候,需要在分布式系统生成运单的id。 1.需求: 1.全局唯一性:不能出现重复的ID。(基本要求) 2.递增:大多数关系型数据库(如 MySQL)使用 B 树作为索引结构。…