【深度学习】最强算法之:Transformer

news2024/11/26 10:28:57

Transformer

  • 1、引言
  • 2、Transformer
    • 2.1 引言
    • 2.2 核心概念
    • 2.3 应用
    • 2.4 算法公式
    • 2.5 代码示例
  • 3、总结

1、引言

小屌丝:鱼哥,昨天的感受咋样
小鱼:啥感受啊?
小屌丝:你确定不知道?
小鱼:我…
小屌丝:再想一想… 嘿嘿~
小鱼:嗯… 确实值得回味。
小屌丝:那咱再去体验体验?
小鱼:这… 不太好吧。
小屌丝:那有啥不太好的。
在这里插入图片描述

小鱼:我可是正经人,体验一次就够了,可不能上瘾的
小屌丝: 那你不去,我可去了。反正都约好了
小鱼:额… 都约好了,那不去岂不是浪费了。
小屌丝:但是,咱的有个条件
小鱼:啥条件?
小屌丝:就是,跟我讲一讲 Transformer
小鱼: 这岂不是张口就来。
小屌丝:那讲讲。
小鱼:路上讲。
小屌丝:也就这个时候,你最积极。

2、Transformer

2.1 引言

Transformer模型是一种深度学习算法,自2017年由Google的研究者Vaswani等人在论文《Attention is All You Need》中首次提出以来,它迅速成为了自然语言处理(NLP)领域的一种革命性技术。

Transformer模型的核心思想是利用“自注意力(Self-Attention)”机制来处理序列数据,这使得它能够在处理长距离依赖问题时表现得更加出色。

2.2 核心概念

  • 自注意力(Self-Attention):
    • 自注意力机制允许模型在处理每个序列元素时,考虑到序列中的所有其他元素,这样能够捕获序列内的长距离依赖关系。
    • 自注意力机制通过计算序列中每个元素对其他所有元素的注意力分数来实现,这些分数表示了元素间的相关性强度。
  • 多头注意力(Multi-Head Attention):
    • Transformer模型使用多头注意力机制来增强模型的注意力能力。
    • 通过并行地执行多个自注意力操作,模型可以从不同的表示子空间中学习信息,这样能够让模型更好地理解和处理数据。
  • 位置编码(Positional Encoding):
    • 由于Transformer完全基于注意力机制,没有像循环神经网络(RNN)那样的递归结构来处理序列数据,因此需要一种方式来理解序列中元素的位置信息。
    • Transformer通过给输入元素添加位置编码来实现这一点,位置编码与元素的嵌入向量相加,从而让模型能够利用序列的顺序信息。
  • 编码器-解码器架构:
    • Transformer模型采用编码器-解码器架构。
      • 编码器负责处理输入序列,解码器则负责生成输出序列。
      • 在编码器和解码器中,都使用了多层自注意力和全连接网络。
        在这里插入图片描述

2.3 应用

Transformer模型已经被广泛应用于各种自然语言处理任务中,包括但不限于:

  • 机器翻译
  • 文本摘要
  • 问答系统
  • 文本生成

2.4 算法公式

Transformer模型的核心是自注意力机制,其关键的计算公式如下:

  • 自注意力得分计算
    Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V
    其中,

    • Q Q Q K K K V V V分别代表查询(Query)、键(Key)和值(Value),
    • d k d_k dk是键的维度。

    这个公式通过计算查询和所有键之间的相似度,然后用softmax归一化得到权重分布,最后用这个分布来加权值。

  • 多头注意力

    • Transformer采用多头注意力机制来并行地执行多次自注意力操作,从而能够让模型从不同的子空间学习信息。
    • 多头注意力的计算可以表示为: MultiHead ( Q , K , V ) = Concat ( head 1 , head 2 , . . . , head h ) W O \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \text{head}_2, ..., \text{head}_h)W^O MultiHead(Q,K,V)=Concat(head1,head2,...,headh)WO
    • 其中,每个 head i = Attention ( Q W i Q , K W i K , V W i V ) \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) headi=Attention(QWiQ,KWiK,VWiV) W O W^O WO是输出线性变换的权重矩阵。
  • 位置编码

    • 位置编码的公式如下,用于给每个位置的元素编码一个唯一的位置信息: P E ( p o s , 2 i ) = sin ⁡ ( p o s 1000 0 2 i / d model ) PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right) PE(pos,2i)=sin(100002i/dmodelpos) P E ( p o s , 2 i + 1 ) = cos ⁡ ( p o s 1000 0 2 i / d model ) PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right) PE(pos,2i+1)=cos(100002i/dmodelpos)
    • 其中, p o s pos pos是位置序号, i i i是维度序号, d model d_{\text{model}} dmodel是模型的维度。

2.5 代码示例

# -*- coding:utf-8 -*-
# @Time   : 2024-04-06
# @Author : Carl_DJ

import torch  
import torch.nn as nn  
import torch.nn.functional as F  
  
class MultiHeadAttention(nn.Module):  
    def __init__(self, d_model, num_heads):  
        super(MultiHeadAttention, self).__init__()  
        # 初始化参数  
        self.num_heads = num_heads  
        self.d_k = d_model // num_heads  
          
        # 定义线性层用于生成查询(Q)、键(K)和值(V)  
        self.queries_linear = nn.Linear(d_model, d_model)  
        self.keys_linear = nn.Linear(d_model, d_model)  
        self.values_linear = nn.Linear(d_model, d_model)  
          
        # 初始化最后的线性层,用于将多头注意力结果合并回原始维度  
        self.fc_out = nn.Linear(d_model, d_model)  
          
        # 定义缩放因子  
        self.scale = 1 / (self.d_k ** 0.5)  
      
    def forward(self, values, keys, query, mask=None):  
        # 将输入通过线性层得到Q、K、V  
        N = query.shape[0]  
        query = self.queries_linear(query).view(N, -1, self.num_heads, self.d_k).transpose(1, 2)  
        keys = self.keys_linear(keys).view(N, -1, self.num_heads, self.d_k).transpose(1, 2)  
        values = self.values_linear(values).view(N, -1, self.num_heads, self.d_k).transpose(1, 2)  
          
        # 计算注意力分数  
        energy = torch.einsum("nqhd,nkhd->nhqk", [query, keys]) * self.scale  
          
        # 如果提供了掩码,则将其应用到注意力分数上  
        if mask is not None:  
            energy = energy.masked_fill(mask == 0, float("-1e20"))  
          
        # 通过softmax计算注意力权重  
        attention = torch.softmax(energy / (self.d_k ** 0.5), dim=3)  
          
        # 应用注意力权重到值上,得到输出  
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).transpose(1, 2).contiguous().view(N, -1, self.d_model)  
          
        # 通过最后的线性层得到最终输出  
        out = self.fc_out(out)  
          
        return out  
  
# 假设我们有以下输入维度和头数  
d_model = 512  
num_heads = 8  
  
# 实例化自注意力层  
self_attn = MultiHeadAttention(d_model, num_heads)  
  
# 创建一些模拟数据  
batch_size = 64  
seq_length = 10  
values = torch.randn(batch_size, seq_length, d_model)  
keys = torch.randn(batch_size, seq_length, d_model)  
query = torch.randn(batch_size, seq_length, d_model)  
  
# 创建掩码(例如,为了处理填充)  
mask = torch.ones(batch_size, 1, seq_length).bool()  
  
# 将数据输入自注意力层进行前向传播  
output = self_attn(values, keys, query, mask=mask)  
  
print(output.shape)  # 输出形状应该是 (batch_size, seq_length, d_model)


步骤

  • 首先,定义了一个MultiHeadAttention类,并创建了一个自注意力层的实例self_attn;
  • 然后,创建了一些模拟数据values、keys、query以及一个掩码mask;
  • 最后,我们调用self_attn的forward方法,将模拟数据作为输入进行前向传播,得到输出output;

3、总结

Transformer模型通过其独特的自注意力机制解决了序列处理中的长距离依赖问题,其多头注意力和位置无关的特性使其在处理各种序列数据时都能获得卓越的表现。

由于其强大的性能和灵活的结构,Transformer及其衍生模型已经成为了自然语言处理领域的基石,影响了整个人工智能领域的发展方向。

我是小鱼

  • CSDN 博客专家
  • 阿里云 专家博主
  • 51CTO博客专家
  • 企业认证金牌面试官
  • 多个名企认证&特邀讲师等
  • 名企签约职场面试培训、职场规划师
  • 多个国内主流技术社区的认证专家博主
  • 多款主流产品(阿里云等)测评一、二等奖获得者

关注小鱼,学习【机器学习】&【深度学习】领域的知识。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1604388.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

软考 - 系统架构设计师 - 设计模式

概念 每一个设计模式描述了一个在我们周围不断重复发生的问题,以及该问题解决方案的核心,这样,就可以在遇到相同的问题时使用该解决方案进行解决,不必进行重复的工作,设计模式的核心在于提供了问题的解决方案&#xff…

代码学习记录25---单调栈

随想录日记part45 t i m e : time: time: 2024.04.17 主要内容:今天开始要学习单调栈的相关知识了,今天的内容主要涉及:每日温度 ;下一个更大元素 I 739. 每日温度 496.下一个更大元素 I Topic…

数仓建模—逻辑数据模型

数仓建模—逻辑数据模型 逻辑数据模型 (LDM Logical Data Model) 是一种详细描述数据元素的数据模型,用于开发对数据实体、属性、键和关系的直观理解。 这种模型独特地独立于特定的数据库,以便为数据管理系统中语义层的组件建立基础结构。将 LDM 视为一个蓝图:它代表在整个…

软件无线电安全之GNU Radio基础 -上

GNU Radio介绍 GNU Radio是一款开源的软件工具集,专注于软件定义无线电(SDR)系统的设计和实现。该工具集支持多种SDR硬件平台,包括USRP、HackRF One和RTL-SDR等。用户可以通过GNU Radio Companion构建流程图,使用不同…

记录Windows XP系统安装详细图文版安装日志

一、准备工作 一、下载镜像文件 我用的镜像文件在网盘可自行下载: 点击下载OSI镜像文件提取码:888999 系统安装介质准备:首先,你需要准备一个Windows XP的安装介质。这可以是光盘,也可以是U盘。确保你的安装介质是…

MyBatis 源码分析 - SQL 的执行过程

MyBatis 源码分析 - SQL 的执行过程 * 本文速览 本篇文章较为详细的介绍了 MyBatis 执行 SQL 的过程。该过程本身比较复杂,牵涉到的技术点比较多。包括但不限于 Mapper 接口代理类的生成、接口方法的解析、SQL 语句的解析、运行时参数的绑定、查询结果自动映射、延…

Linux(磁盘管理与文件系统)

目录 1. 磁盘基础 1.1 磁盘结构 1.2 MBR 1.3 磁盘分区结构 2. 文件系统类型 2.1 XFS文件系统 2.2 SWAP 2.3 fdisk命令 2.4 创建新硬盘 3.创建文件系统 3.1 mkfs 3.2 挂载、卸载文件系统 3.3 查看磁盘使用情况 1. 磁盘基础 1.1 磁盘结构 磁盘的物理结构 盘片:硬…

Java项目实现图形验证码(Hutool)

项目架构: 使用SpringCloudmysqlmybatis-plus需要将数据库中的数据导出到Excel文件中 前端为Vue2 业务场景: 登录时使用验证码登录 1.1 打开hutool, 搜索 图片验证码 1.2后端编写生产验证码方法 1.3前端 1.3.1展示验证码 1.3.2 前端方法 1.3.2.1UU…

快速入门Web开发(上) 黑马程序员JavaWeb开发教程

快速入门Web开发(上) 本文档是黑马程序员的 黑马程序员JavaWeb开发教程,实现javaweb企业开发全流程(涵盖SpringMyBatisSpringMVCSpringBoot等)_哔哩哔哩_bilibili上这篇没有写什么很深的个人见解 但下篇有 开发Web网…

02 MySQL --DQL专题--条件查询、函数、分组查询

一些盲点. 数据库中仅有月薪字段(month_salary),要求查询所有员工的年薪,并以年薪(year_salary)输出: 分析: 查询操作中,字段可以参与数学运算as 起别名,但实际上可以省略 #以下…

Stronghold Village

有了近2000个预制件和大量资产,您可以用基本的或先进的模块化预制件建造您的设防城镇或梦幻村庄,其中有许多定制选项和大量道具和物品 【资产描述】 你准备好建造你的史诗般的奇幻设防小镇了吗?有了这个庞大的资产库,您将能够创建村庄、城市、据点、乡村建筑、大教堂、城堡…

Tricentis测试生成式人工智能系统和红队:入门指南

Tricentis测试生成式人工智能系统和红队:入门指南 测试人工智能并确保其责任、安全和保障的话题从未如此紧迫。自 2021 年以来,人工智能滥用的争议和事件增加了26 倍,凸显了日益增长的担忧。用户很快就会发现,人工智能工具并非万无一失。他们可能会表现出过度自信,并且缺…

JavaWeb--04YApi,Vue-cli脚手架Node.js环境搭建,创建第一个Vue项目

04 1 Yapi2 Vue-cli脚手架Node.js环境搭建配置npm的全局安装路径 3 创建项目(这个看下一篇文章吧) 1 Yapi 前后端分离中的重要枢纽"接口文档",以下一款为Yapi的接口文档 介绍:YApi 是高效、易用、功能强大的 api 管理平台&#…

使用vite从头搭建一个vue3项目(四)使用axios封装request.js文件,并使用proxy解决跨域问题

目录 一、创建request.js文件二、创建axios实例三、创建请求、响应拦截器四、使用 request.js,测试接口:https://api.uomg.com/api/rand.qinghua1、调取接口代码书写2、注意(跨域问题) axios 的二次封装有三个要点: 创…

解决程序化刷新EXCEL提示更新外部链接的弹窗问题

解决方法 【信任中心】-> 【消息栏】->勾选如下策略提示 2. 【信任中心】->【外部内容】->启用下面的三项链接 3. 【信任中心】->【宏设置】->启用所有宏

【Python小游戏】植物大战僵尸的实现与源码分享

文章目录 Python版植物大战僵尸环境要求方法源码分享初始化页面(部分)地图搭建(部分)定义植物类 (部分)定义僵尸类(部分)游戏运行入口 游戏源码获取 Python版植物大战僵尸 已有的植…

MySQL-笔记-06.数据高级查询

目录 6.1 连接查询 6.1.1 交叉连接(cross join) 6.1.2 内连接(inner join) 6.1.3 外连接(outer join) 6.1.3.1 左外连接(left [outer] join) 6.1.3.2 右外连接(rig…

Meta因露骨AI图片陷入困境

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…

Postgresql源码(126)TupleStore使用场景与原理分析

相关 《Postgresql源码(125)游标恢复执行的原理分析》 《Postgresql游标使用介绍(cursor)》 总结 开源PG中使用tuple store来缓存tuple集,默认使用work_mem空间存放,超过可以落盘。在PL的returns setof场景…

回文链表leecode

回文链表 偶数情况奇数情况 回文链表leecode 偶数情况 public boolean isPalindrome(ListNode head) {if (head null) {return true;}ListNode fast head;ListNode slow head;while (fast ! null && fast.next ! null) {fast fast.next.next;slow slow.next;}//反…