经典网络解析(四) transformer | 自注意力、多头、发展

news2024/11/29 16:42:31

文章目录

  • 1 背景
    • 1.1 困境
    • 1.2 基本架构
  • 2 嵌入层
  • 3 编码器部分
    • 3.1 自注意力层
    • 3.2 多头注意力机制
    • 3.3 LayerNorm归一化层
  • 4 解码器
  • 5 transformer的发展
  • 6 代码

1 背景

1.1 困境

transformer可以并行训练,也是用来实现attention注意力机制

之前RNN的困境

(1)特征的有效性不够,某一时刻拿到了前一时刻的所有特征,特征没有聚焦点

(2)训练效率不够,必须得等前一时刻训练出来之后才能进行这一时刻的训练,而我们的transformer通过矩阵运算实现并行

注意力模型其实有很多

transformer是做的比较好的一类

1.2 基本架构

Encoder和Decoder构架

Encoder部分把原语言特征抽取出来 ,送到Decoder

堆叠多层Encoder和Decoder为什么?

特征提取能力可以更强

transformer

Encoder 抽取特征 多层累加形式(所以输入输出维度不变)例子中都是[L.512]维度

​ 比如有十个单词,那么经过Encoder以后就会输出十个特征,输入序列长度等于输出序列长度

Decoder 根据特征翻译出来找到对应的层

左边编码器堆叠六层 N=6

右边解码器堆叠六层 N=6

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

2 嵌入层

input Embeding 嵌入层 把单词降为成512维 所以输入L个单词 512维 则输入为L×512 维,

将ont-hot 表示映射到连续空间上,可以用nn.Embedding

embeding嵌入层 一般进行降维 将到d_model (一般512维度)

​ 通过一个变换将单词的one-hot表示映射到连续空间上(降维)可以使用nn.Embedding实现

比如我们翻译L个单词,词典大小10000,则利用独热编码后,我们的inputs维度为[L,10000]

然后送入embeding层,变为[L,512]

3 编码器部分

我们想要把上下文信息编码进去,就是通过自注意力实现的

比如吃苹果,苹果手机

3.1 自注意力层

之前的注意力就是建立权重

采用如下方式产生权重

每个单词经过3个线性变换生成3个向量q,k,v ,用每个单词自己的q和其他单词的k做点乘,得到ai 然后再softmax 转换为权值大小,则最后的翻译参考的x等于权值×对应的v累加

最后再直接输出

在这里插入图片描述

输出 翻译得到单词的概率

注意这里的每一个xi都有一个输出,而且都利用了上下文的信息

具体计算的时候

(1)x1怎么到 q1,k1,v1呢,是通过三个矩阵Q1,K1,V1 这三个矩阵是可学习的,x1×Q1得到q1 ,x1×K1得到k2 ……其他同理,进而决定了我们的每一个x对应的q,k,v是多少?

(2)然后对于每一个单词的qi,和其余所有单词包括自己的k做点积,得到一组score ,然后除以维度开根号(主要解决有些输出过大有些输出过小数据分布方差过大的情况)然后再softmax操作就可以得到权重

(3)然后对于每一个单词的输出,将每一个单词的vi和所有其余单词的(2)得到的权重做点乘然后累加

以上所有操作都可以矩阵并行运算

3.2 多头注意力机制

经过一组参数,多头注意力就是学习多个矩阵Q,K,V,一模一样的操作,(实际上就是多组Q,K,V)得到八组, 形成八组结果

然后连接contact层,再经过一个线性变换层得到最后的结果

在这里插入图片描述

3.3 LayerNorm归一化层

注意LayerNorm和batch Norm的区别

batch Norm是一个数据集批次之间的操作

而layerNorm是我们的一组句子提取特征后进行的操作

残差 减均值除方差 前向神经网络,批归一化

全连接 做特征提取,加强非线性操作

4 解码器

解码器的输入是已经翻译过的语言特征,并且中途会插入原语言特征(来自编码器)

源语言的K,V传入Decoder

利用原语言提取的特征,加上已经翻译过的特征进行预测未翻译的语言的特征

通过Mask 就可以并行的训练了,已经翻译过来的内部关系

attention 就是权重

为什么用sin,cos位置编码

主要是可以记住相对位置,也可以记住绝对位置

5 transformer的发展

GPT

BERT模型 完形填空

GPT-2 GPT-3

探索了无监督的

完形填空,预测下一个词

Non-local

ViT 沿用Bert

6 代码

class Transformer(nn.Module):
    ''' A sequence to sequence model with attention mechanism. '''

    def __init__(
            self, n_src_vocab, n_trg_vocab, src_pad_idx, trg_pad_idx,
            d_word_vec=512, d_model=512, d_inner=2048,
            n_layers=6, n_head=8, d_k=64, d_v=64, dropout=0.1, n_position=200,
            trg_emb_prj_weight_sharing=True, emb_src_trg_weight_sharing=True):

        super().__init__()

        self.src_pad_idx, self.trg_pad_idx = src_pad_idx, trg_pad_idx

        # Encoder
        self.encoder = Encoder(
            n_src_vocab=n_src_vocab, n_position=n_position,
            d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
            n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v,
            pad_idx=src_pad_idx, dropout=dropout)

        # Decoder
        self.decoder = Decoder(
            n_trg_vocab=n_trg_vocab, n_position=n_position,
            d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
            n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v,
            pad_idx=trg_pad_idx, dropout=dropout)

        # 最后的linear输出层
        self.trg_word_prj = nn.Linear(d_model, n_trg_vocab, bias=False)

        # xavier初始化
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p) 

        assert d_model == d_word_vec, \
        'To facilitate the residual connections, \
         the dimensions of all module outputs shall be the same.'

        self.x_logit_scale = 1.
        if trg_emb_prj_weight_sharing:
            # Share the weight between target word embedding & last dense layer
            self.trg_word_prj.weight = self.decoder.trg_word_emb.weight
            self.x_logit_scale = (d_model ** -0.5)

        if emb_src_trg_weight_sharing:
            self.encoder.src_word_emb.weight = self.decoder.trg_word_emb.weight


    def forward(self, src_seq, trg_seq):

        # mask
        src_mask = get_pad_mask(src_seq, self.src_pad_idx)
        trg_mask = get_pad_mask(trg_seq, self.trg_pad_idx) & get_subsequent_mask(trg_seq)

        # encoder & decoder
        enc_output, *_ = self.encoder(src_seq, src_mask)
        dec_output, *_ = self.decoder(trg_seq, trg_mask, enc_output, src_mask)
        
        # final linear layer得到logit vector
        seq_logit = self.trg_word_prj(dec_output) * self.x_logit_scale

        return seq_logit.view(-1, seq_logit.size(2))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1055184.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【PostgreSQL】【存储管理】表和元组的组织方式

外存管理负责处理数据库与外存介质(PostgreSQL8.4.1版本中只支持磁盘的管理操作)的交互过程。在PostgreSQL中,外存管理由SMGR(主要代码在smgr.c中)提供了对外存的统一接口。SMGR负责统管各种介质管理器,会根据上层的请求选择一个具体的介质管理器进行操作…

【最优化理论】线性规划标准模型的基本概念与性质

我们在中学阶段就遇到过线性规划问题,主要是二维的情况,而求解的方法一般是非常直观、高效的图解法。根据过往的经验,线性规划问题的最优目标值一般在可行域的顶点处取得,那么本文就对这个问题进行更深入的探讨,维度也…

找不到msvcp140.dll解决方法的5个解决方法以及msvcp140.dll丢失原因分析

msvcp140.dll 是 Microsoft Visual C 2017 Redistributable 的一部分,许多应用程序和游戏都需要这个动态链接库(DLL)才能正常运行。如果您的系统中找不到 msvcp140.dll,您可能会遇到无法打开某些应用程序或游戏的困境。小编将讨论…

运用动态内存实现通讯录(增删查改+排序)

目录 前言: 实现通讯录: 1.创建和调用菜单: 2.创建联系人信息和通讯录: 3.初始化通讯录: 4.增加联系人: 5.显示联系人: 6.删除联系人: ​编辑 7.查找联系人: ​…

nodejs+vue健身服务应用elementui

第三章 系统分析 10 3.1需求分析 10 3.2可行性分析 10 3.2.1技术可行性:技术背景 10 3.2.2经济可行性 11 3.2.3操作可行性: 11 3.3性能分析 11 3.4系统操作流程 12 3.4.1管理员登录流程 12 3.4.2信息添加流程 12 3.4.3信息删除流程 13 第四章 系统设计与…

AWS Lambda Golang HelloWorld 快速入门

操作步骤 以下测试基于 WSL2 Ubuntu 22.04 环境 # 下载最新 golang wget https://golang.google.cn/dl/go1.21.1.linux-amd64.tar.gz# 解压 tar -C ~/.local/ -xzf go1.21.1.linux-amd64.tar.gz# 配置环境变量 PATH echo export PATH$PATH:~/.local/go/bin >> ~/.bashrc …

【小沐学前端】Node.js实现基于Protobuf协议的WebSocket通信

文章目录 1、简介1.1 Node1.2 WebSocket1.3 Protobuf 2、安装2.1 Node2.2 WebSocket2.2.1 nodejs-websocket2.2.2 ws 2.3 Protobuf 3、代码测试3.1 例子1:websocket(html)3.1.1 客户端:yxy_wsclient1.html3.1.2 客户端&#xff1a…

绘制动图,金星木星月亮太阳绕圆

图💫 input绘制 行星 木星 太阳 地球 金星💫 地球 月亮各自旋转 1年 角度 360.gif import numpy as np import matplotlib.pyplot as plt import matplotlib.animation as animation import math import os# 设置中文字体 font_style 宋体 plt.rcParam…

11Spark

1.安装 anaconda 在官网上下载anaconda linux 后缀为.sh的安装包 运行sh ./Anaconda3-2021.05-Linux-x86_64.sh 安装过程: 输入yes后就安装完成了. 验证: 安装完成后, 退出SecureCRT 重新进来: 看到这个base开头表明安装好了. base是默认的虚拟环…

条件查询和数据查询

一、后端 1.controller层 package com.like.controller;import com.like.common.CommonDto; import com.like.entity.User; import com.like.service.UserService; import jakarta.annotation.Resource; import org.springframework.web.bind.annotation.GetMapping; import …

用于YOLO格式分割的咖啡叶病害数据集。

下载链接:https://download.csdn.net/download/qq_40840797/88389334 数据集,一共1164张照片 随机选取几张照片及对应的目标标签 因为健康,所以标签为空

【嵌入式】使用MultiButton开源库驱动按键并控制多级界面切换

目录 一 背景说明 二 参考资料 三 MultiButton开源库移植 四 设计实现--驱动按键 五 设计实现--界面处理 一 背景说明 需要做一个通过不同按键控制多级界面切换以及界面动作的程序。 查阅相关资料,发现网上大多数的应用都比较繁琐,且对于多级界面的…

十大常见排序算法详解(附Java代码实现和代码解析)

文章目录 十大排序算法⛅前言🌱1、排序概述🌴2、排序的实现🌵2.1 插入排序🐳2.1.1 直接插入排序算法介绍算法实现 🐳2.1.2 希尔排序算法介绍算法实现 🌵2.2 选择排序🐳2.2.1 选择排序算法介绍算…

结构体运算符重载

1.降序 struct Point{int x, y;//重载比较符bool operator < (const Point &a) const{return x > a.x;//当前元素大时&#xff0c;是降序} };2.升序 struct Point{int x, y;//重载比较符 // bool operator < (const Point &a) const{ // return x…

如何初始化一个vue项目

如何初始化一个vue项目 安装 vue-cli 后 ,终端执行 vue ui npm install vue-cli --save-devCLI 服务 | Vue CLI (vuejs.org) 等一段时间后 。。。 进入项目仪表盘 设置其他模块 项目构建后目录 vue.config.js 文件相关配置 官方vue.config.js 参考文档 https://cli.vuejs.o…

【vue3】Suspense组件和动态引入defineAsyncComponent的搭配使用

假期第五篇&#xff0c;对于基础的知识点&#xff0c;我感觉自己还是很薄弱的。 趁着假期&#xff0c;再去复习一遍 在app中定义子组件child //静态引入&#xff0c;网速慢的时候&#xff0c;父子组件也是同时渲染出来 <template><div><h3>APP父组件</…

BI神器Power Query(27)-- 使用PQ实现表格多列转换(3/3)

实例需求&#xff1a;原始表格包含多列属性数据,现在需要将不同属性分列展示在不同的行中&#xff0c;att1、att3、att5为一组&#xff0c;att2、att3、att6为另一组&#xff0c;数据如下所示。 更新表格数据 原始数据表&#xff1a; Col1Col2Att1Att2Att3Att4Att5Att6AAADD…

BI神器Power Query(26)-- 使用PQ实现表格多列转换(2/3)

实例需求&#xff1a;原始表格包含多列属性数据,现在需要将不同属性分列展示在不同的行中&#xff0c;att1、att3、att5为一组&#xff0c;att2、att3、att6为另一组&#xff0c;数据如下所示。 更新表格数据 原始数据表&#xff1a; Col1Col2Att1Att2Att3Att4Att5Att6AAADD…

APP或小程序突然打开显示连接网络失败,内容一片空白的原因是,SSL证书到期啦,续签即可

由于我们使用的是https&#xff0c;所以SSL证书到期了&#xff0c;通过https进入读取内容的APP或网站或小程序就会打开后连接网络失败&#xff0c;出现空白&#xff0c;这是因为我们申请的SSL证书到期了&#xff0c;因为我们申请的证书有效期有时是1个月或3个月&#xff0c;到期…

建筑能源管理(2)——建筑用能分类与计算方法

1、按输入建筑的能源形式分类 根据《民用建筑能耗分类及表示方法》GB/T 34913-2017&#xff0c;建筑用能边界位于建筑入口处(图2.2)&#xff0c;对应为满足建筑各项功能需求从外部输入的电力、燃料、冷/热量及可再生能源等&#xff0c;其中冷热量由外部区域能源系统制备&#…