【机器学习】一文搞懂算法模型之:Transformer

news2024/11/15 8:48:25

Transformer

  • 1、引言
  • 2、Transformer
    • 2.1 定义
    • 2.2 原理
    • 2.3 算法公式
      • 2.3.1 自注意力机制
      • 2.3.1 多头自注意力机制
      • 2.3.1 位置编码
    • 2.4 代码示例
  • 3、总结

1、引言

小屌丝:鱼哥, 你说transformer是个啥?
小鱼:嗯… 啊… 嗯…就是…
小屌丝:你倒是说啊,是个啥?
小鱼:你不知道?
小屌丝:我知道啊,
小鱼:你知道,你还问我?
小屌丝:考考你
小鱼:kao…kao… wo??
小屌丝:对
小鱼:看你你很懂 transformer
小屌丝: 必须的
小鱼:那你来跟我说一说?
小屌丝:可以。
小鱼:秃头顶上点灯
小屌丝:啥意思? 你说的是啥意思?你说清楚,啥意思???

在这里插入图片描述
小鱼:唉~ ~没文化,真可怕。

2、Transformer

2.1 定义

Transformer是一种基于自注意力机制的深度学习模型,它完全摒弃了传统的循环神经网络(RNN)和卷积神经网络(CNN)结构,转而通过自注意力机制来计算输入序列中不同位置之间的依赖关系。

Transformer模型由谷歌在2017年提出,并在自然语言处理领域取得了显著的成绩,如机器翻译、文本生成、问答系统等。

2.2 原理

Transformer的核心思想是利用自注意力机制来捕捉输入序列中的上下文信息。
它采用Encoder-Decoder架构,其中Encoder负责将输入序列编码为一系列的隐藏状态,而Decoder则根据这些隐藏状态生成输出序列。

在Transformer中,自注意力机制是通过计算输入序列中每个位置与其他位置之间的相似度来实现的。

具体来说,对于输入序列中的每个位置,Transformer会生成一个查询向量、一个键向量和一个值向量。然后,它会计算查询向量与所有键向量之间的相似度,得到一个相似度分数矩阵。

最后,根据这个相似度分数矩阵对值向量进行加权求和,得到每个位置的输出表示。

2.3 算法公式

2.3.1 自注意力机制

自注意力机制

  • 输入:查询矩阵Q、键矩阵K、值矩阵V
  • 输出:自注意力输出Z
  • 公式 Z = s o f t m a x ( Q K T / √ d k ) V Z = softmax(QK^T/√d_k)V Z=softmax(QKT/√dk)V

其中,d_k是键向量的维度,用于缩放点积结果以防止梯度消失。

2.3.1 多头自注意力机制

多头自注意力机制

  • 输入:查询矩阵Q、键矩阵K、值矩阵V、头数h
  • 输出:多头自注意力输出MultiHead(Q, K, V)
  • 公式 M u l t i H e a d ( Q , K , V ) = C o n c a t ( h e a d 1 , . . . , h e a d h ) W O MultiHead(Q, K, V) = Concat(head_1, ..., head_h)W^O MultiHead(Q,K,V)=Concat(head1,...,headh)WO

其中,每个头head_i都执行一次自注意力机制,W^O是输出线性变换的权重矩阵。

2.3.1 位置编码

位置编码
为了捕捉序列中的位置信息,Transformer使用位置编码将位置信息嵌入到输入向量中。

公式 P E ( p o s , 2 i ) = s i n ( p o s / 1000 0 ( 2 i / D ) ) P E ( p o s , 2 i + 1 ) = c o s ( p o s / 1000 0 ( 2 i / D ) ) PE(pos, 2i) = sin(pos / 10000^(2i / D)) PE(pos, 2i+1) = cos(pos / 10000^(2i / D)) PE(pos,2i)=sin(pos/10000(2i/D))PE(pos,2i+1)=cos(pos/10000(2i/D))

其中, p o s pos pos是位置索引, i i i是维度索引, D D D是输入向量的维度。

2.4 代码示例

# -*- coding:utf-8 -*-
# @Time   : 2024-03-17
# @Author : Carl_DJ

'''
import torch  
import torch.nn as nn  
import torch.nn.functional as F  
  
# 定义位置编码类  
class PositionalEncoding(nn.Module):  
    def __init__(self, d_model, max_len=5000):  
        super(PositionalEncoding, self).__init__()  
          
        # 计算位置编码  
        pe = torch.zeros(max_len, d_model)  
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *  
                                 -(np.log(torch.tensor(10000.0)) / d_model))  
        pe[:, 0::2] = torch.sin(position * div_term)  
        pe[:, 1::2] = torch.cos(position * div_term)  
        pe = pe.unsqueeze(0).transpose(0, 1)  
        self.register_buffer('pe', pe)  
          
    def forward(self, x):  
        x = x + self.pe[:, :x.size(1)]  
        return x  
  
# 定义Transformer编码器层  
class TransformerEncoderLayer(nn.Module):  
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):  
        super(TransformerEncoderLayer, self).__init__()  
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)  
        self.linear1 = nn.Linear(d_model, dim_feedforward)  
        self.dropout = nn.Dropout(dropout)  
        self.linear2 = nn.Linear(dim_feedforward, d_model)  
          
        self.norm1 = nn.LayerNorm(d_model)  
        self.norm2 = nn.LayerNorm(d_model)  
        self.dropout1 = nn.Dropout(dropout)  
        self.dropout2 = nn.Dropout(dropout)  
          
    def forward(self, src, src_mask=None, src_key_padding_mask=None):  
        # 自注意力机制  
        src2, attn = self.self_attn(src, src, src, attn_mask=src_mask,  
                                      key_padding_mask=src_key_padding_mask)  
        src = src + self.dropout1(src2)  
        src = self.norm1(src)  
          
        # 前馈神经网络  
        src2 = self.linear2(self.dropout(F.relu(self.linear1(src))))  
        src = src + self.dropout2(src2)  
        src = self.norm2(src)  
          
        return src, attn  
  
# 定义Transformer编码器  
class TransformerEncoder(nn.Module):  
    def __init__(self, encoder_layer, num_layers, norm=None):  
        super(TransformerEncoder, self).__init__()  
        self.layers = nn.ModuleList([encoder_layer for _ in range(num_layers)])  
        self.num_layers = num_layers  
        self.norm = norm  
          
    def forward(self, src, mask=None, src_key_padding_mask=None):  
        output = src  
          
        # 堆叠多个编码器层  
        for mod in self.layers:  
            output, attn = mod(output, mask, src_key_padding_mask)  
          
        if self.norm:  
            output = self.norm(output)  
          
        return output  
  
# 定义完整的Transformer模型  
class TransformerModel(nn.Module):  
    def __init__(self, src_vocab_size, d_model, nhead, num_encoder_layers, dim_feedforward, max_position_embeddings):  
        super(TransformerModel, self).__init__()  
        self.src_mask = None  
        self.position_enc = PositionalEncoding(d_model, max_position_embeddings)  
        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout=0.1)  
        encoder_norm = nn.LayerNorm(d_model)  
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, norm=encoder_norm)  
		self.src_embedding = nn.Embedding(src_vocab_size, d_model)  
    	self.d_model = d_model  
      
def generate_square_subsequent_mask(self, sz):  
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)  
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))  
    return mask  
  
def forward(self, src, src_mask=None):  
    if src_mask is None:  
        device = src.device  
        bsz, seq_len = src.size()  
        src_mask = self.generate_square_subsequent_mask(seq_len).to(device)  
      
    src = self.src_embedding(src) * math.sqrt(self.d_model)  
    src = self.position_enc(src)  
    output = self.encoder(src, src_mask)  
    output = self.decoder(output)  
      
    return output


代码解析

  • PositionalEncoding:定义位置编码类,用于给输入序列添加位置信息。Transformer模型是位置无关的,因此需要通过位置编码来提供序列中每个位置的信息。

  • TransformerEncoderLayer:定义Transformer编码器层,包含自注意力机制和前馈神经网络。每个编码器层都会接收输入序列,经过自注意力机制和前馈神经网络处理,输出新的序列表示。

  • TransformerEncoder:定义Transformer编码器,通过堆叠多个编码器层来构建更深的模型。

  • TransformerModel:定义完整的Transformer模型,包括嵌入层、位置编码、编码器和解码器。模型将源语言序列作为输入,经过一系列变换后,输出每个位置的目标语言词汇的预测概率。

  • generate_square_subsequent_mask:生成一个上三角矩阵的掩码,用于在自注意力机制中防止模型看到未来的信息。

  • forward:模型的前向传播函数。首先,将源语言序列通过嵌入层转换为向量表示,并乘以模型维度的平方根进行缩放。然后,将位置编码加到嵌入向量上。接下来,将结果输入到编码器中。最后,将编码器的输出传递给解码器,得到每个位置的预测概率。

在这里插入图片描述

3、总结

Transformer模型通过自注意力机制打破了传统RNN和CNN的局限性,能够更好地捕捉序列中的长期依赖关系,因此在自然语言处理任务中取得了显著的效果。

Transformer的Encoder-Decoder架构和多头自注意力机制使得模型能够同时处理输入序列中的多个位置,提高了模型的并行性和效率。通过位置编码,Transformer还能够处理序列中的位置信息,使得模型在处理自然语言等序列数据时更加灵活和准确。

虽然Transformer模型具有强大的能力,但它也需要大量的数据和计算资源来进行训练。

因此,在实际应用中,我们通常使用预训练的Transformer模型(如BERT、GPT等)作为基础,通过微调来适应特定的任务需求。

我是小鱼

  • CSDN 博客专家
  • 阿里云 专家博主
  • 51CTO博客专家
  • 企业认证金牌面试官
  • 多个名企认证&特邀讲师等
  • 名企签约职场面试培训、职场规划师
  • 多个国内主流技术社区的认证专家博主
  • 多款主流产品(阿里云等)测评一、二等奖获得者

关注小鱼,一起学习机器学习&深度学习领域的知识。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1536354.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Rust基本类型

数值类型 整数类型 无符号整数只能取正数和0,有符号整数可以取正数负数和0。isize和usize类型取决于程序运行的计算机CPU类型,若CPU是32位的,则这两个类型是32位的,若CPU是64位的,则它们是64位的。rust整型 默认使用…

CentOS/RHEL 6.5 上 NFS mount 挂起kernel bug

我本身有四台机器做WAS集群,挂载nfs,其中随机一台客户端计算机端口关闭释放将进入不良状态,对 NFSv4 挂载的任何访问都将挂起(例如“ls,cd 或者df均挂起”)。这意味着没有人并且所有需要访问共享的用户进程…

C语言例:设 int x; 则表达式 (x=4*5,x*5),x+25 的值

代码如下&#xff1a; #include<stdio.h> int main(void) {int x,m;m ((x4*5,x*5),x25);printf("(x4*5,x*5),x25 %d\n",m);//x4*520//x*5100//x2545return 0; } 结果如下&#xff1a;

wy的leetcode刷题记录_Day92

wy的leetcode刷题记录_Day92 声明 本文章的所有题目信息都来源于leetcode 如有侵权请联系我删掉! 时间&#xff1a;2024-3-22 前言 目录 wy的leetcode刷题记录_Day92声明前言2617. 网格图中最少访问的格子数题目介绍思路代码收获 695. 岛屿的最大面积题目介绍思路代码收获 2…

图论必备:Dijkstra、Floyd与Bellman-Ford算法在最短路径问题中的应用

&#x1f3ac;慕斯主页&#xff1a;修仙—别有洞天 ♈️今日夜电波&#xff1a;アンビバレント—Uru 0:24━━━━━━️&#x1f49f;──────── 4:02 &#x1f504; ◀️ ⏸ ▶️ ☰ …

Javaweb学习记录(二)web开发入门(请求响应)

第一个基于springboot的web请求程序 通过创建一个带有springboot的spring项目&#xff0c;项目会自动生成一个程序启动类&#xff0c;该类启动时会启动该整个项目&#xff0c;而我们需要写一个web请求类&#xff0c;要求在本地浏览器上发送请求后&#xff0c;浏览器显示Hello&…

python --- 练习题3

目录 1、猜数字游戏&#xff08;使用random模块完成&#xff09; &#xff1a;继上期题目&#xff0c;附加 2、用户登录注册案例 3、求50~150之间的质数是那些&#xff1f; 4、打印输出标准水仙花数&#xff0c;输出这些水仙花数 5、验证:任意一个大于9的整数减去它的各位…

【数据库系统】数据库完整性和安全性

第六章 数据库完整性和安全性 基本内容 安全性&#xff1b;完整性&#xff1b;数据库恢复技术&#xff1b;SQL Server的数据恢复机制&#xff1b; 完整性 实体完整性、参照完整性、用户自定义完整性 安全性 身份验证权限控制事务日志&#xff0c;审计数据加密 数据库恢复 冗余…

中国贸易金融跨行交易区块链平台CTFU、区块链福费廷交易平台BCFT、中国人民银行贸易金融区块链平台CTFP、银行函证区块链服务平台BPBC

中国人民银行贸易金融区块链平台CTFP介绍 贸易金融的发展概况及存在的问题 1.1 贸易金融的概况 贸易金融是指商业银行在贸易双方债权债务关系的基础上&#xff0c;为国内或跨国的商品和服务贸易提供的贯穿贸易活动整个价值链、全程全面性的综合金融服务。伴随全球化的进程&a…

互联网思维:息共享、开放性、创新和快速反应、网络化、平台化、数据驱动和用户体验 人工智能思维:模拟人、解放劳动力、人工智能解决方案和服务

互联网思维&#xff1a;信息共享、开放性、创新和快速反应、网络化、平台化、数据驱动和用户体验 互联网思维是指一种以互联网为基础的思考方式&#xff0c;强调信息共享、开放性、创新和快速反应的特点。这种思维方式注重网络化、平台化、数据驱动和用户体验&#xff0c;以适…

simulink里枚举量的使用--在m文件中创建枚举量实践操作(推荐)

本文将介绍一种非常重要的概念&#xff0c;枚举量&#xff0c;以及它在simulink状态机中的使用&#xff0c;并且给出模型&#xff0c;方便大家学习。 枚举量&#xff1a;实际上是用一个名字表示了一个变量&#xff0c;能够比较方便的表示标志信息 A.简单举例&#xff1a; 1&a…

Hack The Box-Analytics

目录 信息收集 namp whatweb WEB 信息收集 feroxbuster RCE漏洞 提权 get user get root 信息收集 namp 端口信息探测┌──(root㉿ru)-[~/kali/hackthebox] └─# nmap -p- 10.10.11.233 --min-rate 10000 Starting Nmap 7.94SVN ( https://nmap.org ) at 2024-03-…

经典双指针问题

思路;先找到第一个包含m家店的区间&#xff08;l-r&#xff09;&#xff0c;然后开始进行双指针&#xff08;l&#xff0c;r&#xff09;滑动(如下滑动) while(r<n){while(vis[a[l]]>1)//当前l-r之间a[l]店铺有多个&#xff08;大于一个&#xff09;&#xff0c;那即可去…

macOS下Java应用的打包和安装程序制作

macOS应用程序结构 macOS通常以dmg或pkg作为软件发行包&#xff0c;安装到/Applications下后&#xff0c;结构比较统一。 info.plist里的CFBundleExecutable字段可以指定入口&#xff0c;如果不指定&#xff0c;则MacOS下必须存在同名可执行文件。即abc.app下必须存在abc.app/…

从原理到实践:深入探索Linux安全机制(一)

前言 本文将从用户和权限管理、文件系统权限、SELinux、防火墙、加密和安全传输、漏洞管理和更新等几个Linux安全机制中的重要方面&#xff0c;深入探索其工作原理和使用方法。在当今数字化时代&#xff0c;网络安全问题备受关注&#xff0c;Linux作为广泛应用的操作系统之一&…

【GPT概念04】仅解码器(only decode)模型的解码策略

一、说明 在我之前的博客中&#xff0c;我们研究了关于生成式预训练转换器的整个概述&#xff0c;以及一篇关于生成式预训练转换器&#xff08;GPT&#xff09;的博客——预训练、微调和不同的用例应用。现在让我们看看所有仅解码器模型的解码策略是什么。 二、解码策略 在之前…

财报解读:“高端化”告一段落,华住开始“全球化”?

2023年旅游业快速复苏&#xff0c;全球酒店业直接受益&#xff0c;总体运营指标大放异彩&#xff0c;多数酒店企业都实现了营收上的明显增长&#xff0c;身为国内龙头的华住也不例外。 3月20日晚&#xff0c;华住集团发布2023年四季度及全年财报。整体实现扭亏为盈&#xff0c;…

阿里云安装宝塔后面板打不开

前言 按理来说装个宝塔面板应该很轻松的&#xff0c;我却装了2天&#xff0c;真挺恼火的&#xff0c;网上搜的教程基本上解决不掉我的问题点&#xff0c;问了阿里云和宝塔客服&#xff0c;弄了将近2天&#xff0c;才找出问题出在哪里&#xff0c;在此记录一下问题的处理。 服…

深度探析:7天后不过期的微信群二维码生成的优势

在日常生活和工作中&#xff0c;微信不过期二维码深受用户的欢迎。因为传统的微信群二维码被下载下来后&#xff0c;只有7天有效期。但企业在日常运营中&#xff0c;如果直接使用下载下来的微信群二维码&#xff0c;会造成很多的不便和宣传资源浪费。这些问题&#xff0c;可以通…

华为ensp中ospf基础 原理及配置命令(详解)

CSDN 成就一亿技术人&#xff01; 作者主页&#xff1a;点击&#xff01; ENSP专栏&#xff1a;点击&#xff01; CSDN 成就一亿技术人&#xff01; ————前言———— OSPF 的全称是 Open Shortest Path First&#xff0c;意为“开放式最短路径优先”。是一种内部网关协…