机器翻译之多头注意力(MultiAttentionn)在Seq2Seq的应用

news2024/12/30 1:50:16

目录

1.多头注意力(MultiAttentionn)的理念图

2.代码实现 

2.1创建多头注意力函数 

2.2验证上述封装的代码 

2.3 创建 添加了Bahdanau的decoder 

 2.4训练

 2.5预测

3.知识点个人理解 


 

1.多头注意力(MultiAttentionn)的理念图

2.代码实现 

2.1创建多头注意力函数 

class MultiHeadAttention(nn.Module):
    #初始化属性和方法
    def __init__(self, query_size, key_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        """
        query_size_size: query_size的特征数features
        key_size: key_size的特征数features
        value_size: value_size的特征数features
        num_hiddens:隐藏层的神经元的数量
        num_heads:多头注意力的header的数量
        dropout: 释放模型需要计算的参数的比例
        bias=False:没有偏差
        **kwargs : 不定长度的关键字参数
        """
        super().__init__(**kwargs)
        #接收参数
        self.num_heads = num_heads
        #初始化注意力,    #使用DotProductAttention时, keys与 values具有相同的长度, 经过decoder,他们长度相同
        self.attention = dltools.DotProductAttention(dropout)
        #初始化四个w模型参数
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
        
    def forward(self, queries, keys, values, valid_lens):
        def transpose_qkv(X, num_heads):
            """实现queries, keys, values的数据维度转化"""
            #输入的X的shape=(batch_size, 查询数/键值对数量, num_hiddens)
            #这里,不能直接用reshape,需要索引维度,防止数据不能一一对应
            X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)   #将原维度的num_hiddens拆分成num_heads, -1,  -1相当于num_hiddens/num_heads的数值
            X = X.permute(0, 2, 1, 3)  #X的shape=(batch_size, num_size, 查询数/键值对数量, num_hiddens/num_heads)
            return X.reshape(-1, X.shape[2], X.shape[3])  #X的shape=(batch_size*num_heads, 查询数/键值对数量, num_hiddens/num_heads)

        def transpose_outputs(X, num_heads):
            """逆转transpose_qkv的操作"""
            #此时数据的X的shape =(batch_size*num_heads, 查询数/键值对数量, num_hiddens/num_heads)
            X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])  #X的shape=(batch_size, num_heads, 查询数/键值对数量, num_hiddens/num_heads)
            X = X.permute(0, 2, 1, 3)  #X的shape=(batch_size, 查询数/键值对数量, num_heads,  num_hiddens/num_heads)
            return X.reshape(X.shape[0], X.shape[1], -1)  #X的shape还原了=(batch_size, 查询数/键值对数, num_hiddens)

        #queries, keys, values,传入的shape=(batch_size, 查询数/键值对数, num_hiddens)
        #获取转换维度之后的queries, keys, values,
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        #若valid_len不为空,存在
        if valid_lens is not None:
            #将valid_lens重复数据self.num_heads次,在0维度上
            valid_lens = torch.repeat_interleave(valid_lens, repeats = self.num_heads, dim=0)
        #若为空,什么都不做,跳出if判断,继续执行其他代码

        #通过注意力函数获取输出outputs
        #outputs的shape = (batch_size*num_heads, 查询的个数, num_hiddens/num_heads)
        outputs = self.attention(queries, keys, values, valid_lens)

        #逆转outputs的维度
        outputs_concat = transpose_outputs(outputs, self.num_heads)

        return self.W_o(outputs_concat)

2.2验证上述封装的代码 

#假设变量
num_hiddens, num_heads, dropout = 100, 5, 0.2
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, dropout)
attention.eval()  #需要预测,加上
MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (W_q): Linear(in_features=100, out_features=100, bias=False)
  (W_k): Linear(in_features=100, out_features=100, bias=False)
  (W_v): Linear(in_features=100, out_features=100, bias=False)
  (W_o): Linear(in_features=100, out_features=100, bias=False)
)
#假设变量
batch_size, num_queries = 2, 4
num_kvpairs, valid_lens = 6, torch.tensor([3, 2])

X = torch.ones((batch_size, num_queries, num_hiddens))  #shape(2,4,100)
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))  #shape(2,6,100) 

attention(X, Y, Y, valid_lens).shape

torch.Size([2, 4, 100])

2.3 创建 添加了Bahdanau的decoder 

# 添加Bahdanau的decoder
class Seq2SeqMultiHeadAttentionDecoder(dltools.AttentionDecoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_heads, num_layers, dropout=0, **kwargs):
        super().__init__(**kwargs)
        self.attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, dropout)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)
        self.dense = nn.Linear(num_hiddens, vocab_size)
        
    def init_state(self, enc_outputs, enc_valid_lens, *args):
        # outputs : (batch_size, num_steps, num_hiddens)
        # hidden_state: (num_layers, batch_size, num_hiddens)
        outputs, hidden_state = enc_outputs
        return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)
    
    def forward(self, X, state):
        # enc_outputs (batch_size, num_steps, num_hiddens)
        # hidden_state: (num_layers, batch_size, num_hiddens)
        enc_outputs, hidden_state, enc_valid_lens = state
        # X : (batch_size, num_steps, vocab_size)
        X = self.embedding(X) # X : (batch_size, num_steps, embed_size)
        X = X.permute(1, 0, 2)
        outputs, self._attention_weights = [], []
        
        for x in X:
            query = torch.unsqueeze(hidden_state[-1], dim=1) # batch_size, 1, num_hiddens

            context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)

            x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)

            out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)

            outputs.append(out)
            self._attention_weights.append(self.attention_weights)
            

        outputs = self.dense(torch.cat(outputs, dim=0))

        return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]
    
    @property
    def attention_weights(self):
        return self._attention_weights

 2.4训练

# 训练
embed_size, num_hiddens, num_layers, dropout = 32, 100, 2, 0.1
batch_size, num_steps, num_heads = 64, 10, 5
lr, num_epochs, device = 0.005, 200, dltools.try_gpu()

train_iter, src_vocab, tgt_vocab = dltools.load_data_nmt(batch_size, num_steps)

encoder = dltools.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)

decoder = Seq2SeqMultiHeadAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_heads, num_layers, dropout)

net = dltools.EncoderDecoder(encoder, decoder)

dltools.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

 2.5预测

engs = ['go .', 'i lost .', 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
    translation = dltools.predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device)
    print(f'{eng} => {translation}, bleu {dltools.bleu(translation[0], fra, k=2):.3f}')

go . => ('va !', []), bleu 1.000
i lost . => ("j'ai perdu .", []), bleu 1.000
he's calm . => ('trouvez tom .', []), bleu 0.000
i'm home . => ('je suis chez moi .', []), bleu 1.000

3.知识点个人理解 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2151284.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

云服务器使用

最近搭建一个内网穿透工具,推荐一个云服务器: 三丰台:https://www.sanfengyun.com/ 作为学生党这个服务器是免费的可以体验使用!可以使用免费虚拟主机和云服务器,写一个申请的基本步骤方便大家构建 申请步骤&#x…

11.1图像的腐蚀和膨胀

基本概念-图像腐蚀 图像腐蚀是一种用于去除图像中小的对象或者突出物体边缘的形态学操作。 图像腐蚀(erosion)的基本概念 图像腐蚀通常用于二值图像,其基本原理是从图像中“侵蚀”掉一些像素点,这些像素点通常是边界上的或者是孤…

Word中引用参考文献和公式编号的方法

文章目录 应用参考文献对于单个文献引用多于多个文献同时引用 公式编号手动编号自动编号 参考: 应用参考文献 对于单个文献引用 word中的参考文献用交叉应用实现。 首先,将参考文献编号: 然后,在需要引用的地方用交叉引用插入…

VM虚拟机使用的镜像文件下载

文章目录 Windows系统进入微软官网下载工具以Windows10为例下载镜像文件 Windows系统 进入微软官网下载工具 微软中国官网:https://www.microsoft.com/zh-cn/ 以Windows10为例下载镜像文件 选择下载的路径 开始下载 安装windows10操作系统出现Time out问题及解决办…

【AI视频】Runway:Gen-2 运镜详解

博客主页: [小ᶻZ࿆] 本文专栏: AI视频 | Runway 文章目录 💯前言💯Camera Control(运镜)💯Camera Control功能测试Horizonta(左右平移)Vertical(上下平移&#xff0…

Python 中的 Kombu 类库

Kombu 是一个用于 Python 的消息队列库,提供了高效、灵活的消息传递机制。它是 Celery 的核心组件之一,但也可以单独使用。Kombu 支持多种消息代理(如 RabbitMQ、Redis、Amazon SQS 等),并提供了消息生产者和消费者的功…

ByteTrack多目标跟踪流程图

ByteTrack多目标跟踪流程图 点个赞吧,谢谢。

用 Pygame 实现一个乒乓球游戏

用 Pygame 实现一个乒乓球游戏 伸手需要一瞬间,牵手却要很多年,无论你遇见谁,他都是你生命该出现的人,绝非偶然。若无相欠,怎会相见。 引言 在这篇文章中,我将带领大家使用 Pygame 库开发一个简单的乒乓球…

python文字转wav音频

借鉴博客 一.前期准备 1. pip install baidu-aip 2. pip install pydub 3. sudo apt-get install ffmpeg 二.代码 from aip import AipSpeech from pydub import AudioSegment import time#input your own APP_ID/API_KEY/SECRET_KEY APP_ID 14891501 API_KEY EIm2iXtvD…

【可变模板参数】

文章目录 可变参数模板的概念可变参数模板的定义方式参数包的展开方式递归展开参数包逗号表达式展开参数包 STL容器中的emplace相关接口函数 可变参数模板的概念 可变参数模板是C11新增的最强大的特性之一,它对参数高度泛化,能够让我们创建可以接受可变…

linux入门——“linux基本指令”下

1.mv指令 mv指令用于移动文件或者目录。语法是mv 源文件 目标文件。它的用法需要注意: 当目标文件不存在的时候,默认是将源文件进行重命名操作,名字就是目标文件的名字,当目标文件存在的时候才会把源文件移动到目标文件。 目标文…

centos 7.9安装k8s

前言 Kubernetes单词来自于希腊语,含义是领航员,生产环境级别的容器编排技术,可实现容器的自动部署扩容以及管理。Kubernetes也称为K8S,其中8代表中间8个字符,是Google在2014年的开源的一个容器编排引擎技术&#xff…

WebLogic 后台弱⼝令GetShell

漏洞描述 通过弱⼝令进⼊后台界⾯ , 上传部署war包 , getshell 影响范围 全版本(前提后台存在弱⼝令) 环境搭建 cd vulhub-master/weblogic/weak_password docker-compose up -d 漏洞复现 默认账号密码:weblogic/Oracle123 weblogic…

SQL编程题复习(24/9/19)

练习题 x25 10-145 查询S001学生选修而S003学生未选修的课程(MSSQL)10-146 检索出 sc表中至少选修了’C001’与’C002’课程的学生学号10-147 查询平均分高于60分的课程(MSSQL)10-148 检索C002号课程的成绩最高的二人学号&#xf…

34. 模型材质父类Material

学习到现在大家对threejs的材质都有简单的了解,本节课主要结合文档,从JavaScript语法角度,给大家总结一下材质API的语法。 材质父类Material 查询threejs文档,你可以看到基础网格材质MeshBasicMaterial、漫反射网格材质MeshLamb…

-bash: apt-get: command not found -bash: yum: command not found

1. 现象: 1.1. 容器内使用apt-get, yum 提示命令未找到 1.2. dockerfile制作镜像时候,使用apt-get, yum同样报此错误。 2.原因: 2.1. linux 分为: 1. RedHat系列: Redhat、Centos、Fedora等 2. Debian系列&#xff1a…

4G 网络下资源加载失败?一次运营商封禁 IP 的案例分享

在工作中,网络问题是不可避免的挑战之一。最近,我们在项目中遇到了一起网络资源加载异常的问题:某同事在使用 4G 网络连接公司 VPN 时,云服务的前端资源居然无法加载!通过一系列的排查和分析,我们发现问题的…

alias 后门从入门到应急响应

目录 1. alias 后门介绍 2. alias 后门注入方式 2.1 方式一(以函数的方式执行) 2.2 方式二(执行python脚本) 3.应急响应 3.1 查看所有连接 3.2 通过PID查看异常连接的进程,以及该进程正在执行的命令行命令 3.3 查看别名 3.4 其他情况 3.5 那么检查这些…

Arthas jvm(查看当前JVM的信息)

文章目录 二、命令列表2.1 jvm相关命令2.1.3 jvm(查看当前JVM的信息) 二、命令列表 2.1 jvm相关命令 2.1.3 jvm(查看当前JVM的信息) 基础语法: jvm [arthas18139]$ jvmRUNTIME …

C++_21_模板

模板 简介&#xff1a; 一种用于实现通用编程的机制。 通过使用模板我们可以编写可复用的代码&#xff0c;可以适用于多种数据类型。 C模板的语法使用角括号 < > 来表示泛型类型&#xff0c;并使用关键字 template 来定义和声明模板 概念&#xff1a; c范式编程 特点&…