Ai 算法之Transformer 模型的实现: 一 、Input Embedding模块和Positional Embedding模块的实现

news2024/12/23 11:04:54

一 文章生成模型简介

比较常见的文章生成模型有以下几种:

  1. RNN:循环神经网络。可以处理长度变化的序列数据,比如自然语言文本。RNN通过隐藏层中的循环结构来传递时间序列中的信息,从而使当前的计算可以参照之前的信息。但这种模型有梯度爆炸和梯度消失的风险,所以只能做简单的生成任务。
  2. LSTM:长短记忆网络。通过引入门控制机制来控制信息传递。有效避免了梯度消失和梯度保障的问题。LSTM可以做些复杂的生成任务。
  3. Transformer:目前最火的,一种基于自注意力机制(self-attention mechanism)的神经网络模型。Transformer 和 以上所述的几个生成模型主要的区别是,RNN、LSTM的训练迭代是串行的,必须要处理完当前字才可以处理下一个。而 Transformer 所有字符是同时训练的,也就是并行的。因此它效率更高,同样,由于参考了全文位置信息,因此效果更好。

值得一提的是这几个模型的价值并不仅限于在文章生成中。所有需要"经验值"的应用场景应该都适合借鉴。比如19年我曾尝试用LSTM来实现物联网小车自动驾驶。将操作指令转换为文字编码,实现了自动巡航、避障、撞墙倒车等操作。效果还不错。相信更换为注意力机制效果会更好

本文无意重塑轮子,纯是基于兴趣学习,尝试复现模型构造过程,本文所使用环境为python3.9+pytorch,参考论文为Google的Attention Is All You Need 2017。欢迎骚扰探讨

关于RNN和LSTM的实现代码,请查看我博客中的相关文章

1.1 Transformer 结构图

左侧为外国原版,右侧为在下翻译版
请添加图片描述
Transformer 模型主要分为两大部分,分别是 Encoder 、 Decoder,即组码器和解码器。组码器负责把输入语言序列映射成隐藏层,然后解码器再把隐藏层映射为其他自然语言序列。在原文中解码器和编码器都被设为6层(N = 6)。据说这个6没有特殊的含义。只是根据经验平衡了训练和精度的尝试数字。
在输入语句进入组码器前需要对数据进行预处理。这就是本章的主要内容:Embedding模块的实现

二 Input Embedding 字符编码模块的实现

字符编码本质上就相当于映射,将现实中的物体用数学的方式映射到计算机中。以翻译任务为例,我们需要准备两种不同的语言数据,并使用索引将他们一一对应。比如英文字符[i, eat, shit], 中文[我,吃,屎],这就相当于我们知道了问题和答案,剩下的就是训练隐藏层的参数了。

在npl中,为了使字符可以计算,首先要先将输入的词汇进行数学转化。在比较在其的语言处理中,一般使用one hot(独热)编码。即指定一个表值范围数组,单独改变某个位置上的值来决定其特征。
独热编码示例:
[1,0 ,0 ,0] = 我
[0,1 ,0 ,0] = 吃
[0,0, 1 ,0] = 屎
独热编码简单清晰,但无法对比两个值之间的相似性,无法进行降维操作。所以在tranfomer中 使用多维向量来表示单词的编码信息。一个向量表示一个单词。多个单词在一起就是一个矩阵。相比较以前的独热编码,词向量可以便于计算单词之间的相似性(点积),也可以进行降维操作。
单词向量示例:
[11,23,31,32]
[23,21,31,23]
[13,32,33,93]

单词的 Embedding 有很多种方式可以获取,例如可以采用 Word2Vec、Glove 等算法预训练得到,也可以在 Transformer 中训练得到。以下是使用pythoch获取Embedding向量的代码脚本,复制可用。

import torch
import torch.nn as nn

# padding:当句子长度不一,有空白时用0补缺
embedding = nn.Embedding(单词数量, 向量维度,padding=0)
# 根据索引获取8个单词向量
input = torch.LongTensor([[1, 2, 3, 4], [11, 12, 13, 13]])
print(embedding(input))
print(embedding(input).shape)

三 Positional Embedding 位置编码模块的实现

位置编码模块负责将输入序列中的位置信息写入词向量,输入到transformer中的句子没有顺序信息,因此需要通过计算句子的长度,单词长度以及单词所在的位置通过编码来为输入系列添加位置信息。Tranformer原文作者使用的是正弦余弦编码

位置 Embedding 用 PE表示,PE 的维度与单词 Embedding 是一样的。PE 可以通过训练得到,也可以使用某种公式计算得到。在 Transformer 中采用了后者,计算公式如下:

那么单词向量是怎么得来的呢?
单词向量 = 原始单词编码 + 单词位置编码
举个例子:我吃屎 = i eat shit

在这里插入图片描述
位置编码计算公式

偶数索引: P E ( p o s , 2 i ) = s i n ( p o s / 1000 0 2 i / d ) 偶数索引:PE(pos,2i)=sin(pos/10000^2i/d) 偶数索引:PE(pos,2i)=sin(pos/100002i/d)
单数索引: P E ( p o s , 2 i ) = c o s ( p o s / 1000 0 2 i / d ) 单数索引:PE(pos,2i)=cos(pos/10000^2i/d) 单数索引:PE(pos,2i)=cos(pos/100002i/d)

import torch
import torch.nn as nn
import ludash as ld
import cv2
import seaborn    
import matplotlib.pyplot as plt

term = (10000**2/i)
pe[:, 0::2] = torch.sin(position * term )
pe[:, 1::2] = torch.cos(position * term )

四 获取预处理数据

获取到字符编码和位置编码后,就可以计算出参考了字符位置的权重矩阵

公式: [ q , k , v ] = ( I n p u t E m b e d d i n g + p o s i t i o n a l E m b e d d i n g ) ∗ [ W q , W k , W v ] 公式: [q, k, v] =(Input Embedding + positional Embedding)* [Wq, Wk, Wv] 公式:[q,k,v]=InputEmbedding+positionalEmbedding[Wq,Wk,Wv]
q = 查询向量, k = 键值向量, v = 值向量 q = 查询向量,k = 键值向量,v = 值向量 q=查询向量,k=键值向量,v=值向量






取得这个值后就可以进行下一步:传入Transfrom的组码器进行组码处理了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1307185.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

一个算法一个例题教会你算法---0-1背包问题

动态规划 0-1背包问题 0-1背包问题就是求在有重量限制的情况下如何装入价值最大的物品。 啥也别说,直接看题: 现在有四个可以放的物品,w代表重量,v代表价值。 step1: 我们列一个背包重量 j 从0到5的表格&#xff0…

【STM32】STM32学习笔记-LED闪烁 LED流水灯 蜂鸣器(06-2)

00. 目录 文章目录 00. 目录01. GPIO之LED电路图02. GPIO之LED接线图03. LED闪烁程序示例04. LED闪烁程序下载05. LED流水灯接线图06. LED流水灯程序示例07. 蜂鸣器接线图08. 蜂鸣器程序示例09. 下载10. 附录 01. GPIO之LED电路图 电路图示例1 电路图示例2 02. GPIO之LED接线图…

Navicat 技术指引 | 适用于 GaussDB 分布式的数据查看器

Navicat Premium(16.3.3 Windows 版或以上)正式支持 GaussDB 分布式数据库。GaussDB 分布式模式更适合对系统可用性和数据处理能力要求较高的场景。Navicat 工具不仅提供可视化数据查看和编辑功能,还提供强大的高阶功能(如模型、结…

CNN发展史脉络 概述图整理

CNN发展史脉络概述图整理,学习心得,供参考,错误请批评指正。 相关论文: LeNet:Handwritten Digit Recognition with a Back-Propagation Network; Gradient-Based Learning Applied to Document Recogniti…

leaflet使用热力图报L找不到的问题ReferenceError: L is not defined at leaflet-heat.js:11:3

1.在main.js中直接引入会显示找不到L 2.解决办法 直接在组件中单独引入使用 可以直接显示出来。 至于为什么main中不能引入为全局,我是没找到,我的另外一个项目可以,新项目不行,不知哪里设置的问题

在linux服上使用nginx+tomcat部署若依前后端分离版本(RuoYi-Vue)

一、先拉工程,地址:RuoYi-Vue: 🎉 基于SpringBoot,Spring Security,JWT,Vue & Element 的前后端分离权限管理系统,同时提供了 Vue3 的版本 二、在window上用idea打开跑通,可参考…

GPT-4V 在保险行业的应用

在科技的进步中,人工智能与大数据技术的结合产生了巨大的能量,推动了各行各业的创新与变革。OpenAI,作为全球领先的人工智能研发机构,在今年的9月25日,以一种崭新的方式,升级了其旗下的GPT-4模型。这次的升…

解决msvcr120.dll文件丢失问题

项目场景: 在VMware虚拟机中安装win7家庭版系统,安装MySQL数据库,部署项目文件。 问题描述 安装MySQL数据库过程中提示“msvcr120.dll文件丢失”。 原因分析: 提示丢失msvcr120.dll文件,我们首先要到C:\Windows\Sys…

SD-WAN解决外贸企业网络问题

为了获取全球客户,占领更多的市场,越来越多的外贸企业出现。外贸企业在发展业务的过程中会遇到很多困难,海外网络访问问题就是其中之一。目前该问题主要有三种解决方案:VPN、MPLS专线以及SD-WAN专线。 VPN通过在公网上面建立专用网…

3、ollvm移植

github: https://github.com/obfuscator-llvm/obfuscator/tree/llvm-4.0 先复制 include Obfuscation: /home/nowind/llvm/ollvm/obfuscator/include/llvm/Transforms/Obfuscation /home/nowind/llvm/llvm-project-9.0.1/llvm/include/llvm/Transforms/Obfuscation lib Ob…

【算法Hot100系列】两数相加

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

深度学习|词嵌入的演变

文本嵌入,也称为词嵌入,是文本数据的高维、密集向量表示,可以测量不同文本之间的语义和句法相似性。它们通常是通过在大量文本数据上训练 Word2Vec、GloVe 或 BERT 等机器学习模型来创建的。这些模型能够捕获单词和短语之间的复杂关系&#x…

利用poi实现将数据库表字段信息导出到word中

研发文档对于开发人员来说都不陌生了,而研发文档里重要的一部分就是表结构设计,需要我们在word建个表格把我们数据库中的表字段信息填进去,表多的话靠我们手动去填非常累人!!! 因此作为开发人员可不可以写…

HarmonyOS—实现UserDataAbility

UserDataAbility接收其他应用发送的请求,提供外部程序访问的入口,从而实现应用间的数据访问。Data提供了文件存储和数据库存储两组接口供用户使用。 文件存储 开发者需要在Data中重写FileDescriptoropenFile(Uriuri,Stringmode)方法来操作文件&#xf…

记账本选择标签选择时间,计算器---记录一下

html部分 <template><view class"pages-main"><!-- 标题栏 --><!-- #ifndef MP-TOUTIAO --><view class"" :style"height:barHeight px;"></view><!-- #endif --><!-- #ifdef MP-TOUTIAO -->&…

zookeeper1==zookeeper源码阅读,源码启动ZK集群

下载源码 Tags apache/zookeeper GitHub https://codeload.github.com/apache/zookeeper/zip/refs/tags/release-3.9.1 JDK8 MAVEN3.8.6 mvn -DskipTeststrue package 配置ZK1 zkServer.cmd中指出了启动类是 QuorumPeerMain QuorumPeer翻译成集群成员比较合理&#xf…

Nacos配置Mysql数据库

目录 前言1. 配置2. 测试前言 关于Nacos的基本知识可看我之前的文章: Nacos基础版 从入门到精通云服务器 通过docker安装配置Nacos 图文操作以下Nacos的版本为1.1.3 1. 配置 对应的配置文件路径如下: 对应的application.properties为配置文件 需配置端口号 以及 mysql中的…

Mysql的事务日志

Mysql的事务具有四个特性&#xff1a;原子性、一致性、隔离性、持久性。那么事务的四种特性分别是靠什么机制实现的呢&#xff1f; 事务的隔离性由锁机制来保证 事务的原子性、一致性、持久性则由redo log和Undo log来保证。 - redo log是重做日志&#xff0c;提供再写入操作&…

docker-compose安装nacos和msql

docker-compose安装nacos和msql 前言前提已经安装docker-compose&#xff0c;如果没有安装&#xff0c;则可以查看上面系列文章中的安装教程。并且文章中使用的是mobaxterm连接虚拟机。 1、下载2、创建并运行 前言 前提已经安装docker-compose&#xff0c;如果没有安装&#x…

BugKu-Web-Flask_FileUpload(模板注入与文件上传)

Flask Flask是一个使用Python编写的轻量级Web应用框架。它是一个微型框架&#xff0c;因为它的核心非常简单&#xff0c;但可以通过扩展来增加其他功能。Flask的核心组件包括Werkzeug&#xff0c;一个WSGI工具箱&#xff0c;以及Jinja2&#xff0c;一个模板引擎。 Flask使用BSD…