变形金刚:第 2 部分:变形金刚的架构

news2025/1/11 6:05:24

目录

一、说明

二、实现Transformer的过程

        第 1 步:代币化(Tokenization)

        第 2 步:对每个单词进行标记嵌入

        第 3 步:对每个单词进行位置嵌入

        第 4 步:输入嵌入

第 5 步:编码器层

2.5.1 多头自注意力 

2.5.2 剩余连接(添加)

2.5.3 层归一化

2.5.4 前馈神经网络

第6步:解码器层

三、小结


一、说明

        本文对transformer的体系进行系统梳理,总的来说,transformer实现分五个步骤,除了一般化处理,token、词法、句法、词嵌入式,编码、解码过程。

二、实现Transformer的过程

        涉及的步骤是:

        第 1 步:代币化(Tokenization

        输入序列“how are you”被标记为单个单词或子单词。我们假设它被标记为以下标记:[“how”、“are”、“you”]。

        第 2 步:对每个单词进行标记嵌入

  • 嵌入“如何”:[0.5, 0.2, -0.1]
  • “are”的嵌入:[-0.3, 0.8, 0.4]
  • 嵌入“你”:[0.6,-0.5,0.3]
令牌嵌入和位置嵌入之间的区别
model_name = 'gpt2'  # Replace with the specific transformer model you want to use
model = GPT2Model.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

# Example input text
input_text = "Hello, how are you today?"

# Tokenize input text
tokens = tokenizer.tokenize(input_text)
input_ids = tokenizer.convert_tokens_to_ids(tokens)
input_tensor = torch.tensor([input_ids])

# Generate token embeddings
token_embeddings = model.transformer.wte(input_tensor)

      

        第 3 步:对每个单词进行位置嵌入

        位置输入嵌入是基于变压器的模型的关键组成部分,包括用于会话聊天机器人的模型。它们提供有关序列中标记相对位置的信息,使模型能够理解输入的顺序。

        在 Transformer 中,位置输入嵌入被添加到令牌嵌入中以对位置信息进行编码。

        生成位置嵌入有不同的方法。一种常见的方法是使用正弦函数来创建具有固定模式的嵌入。序列中每个标记的位置被映射到一个唯一的向量,其中向量的值对应于不同频率的正弦和余弦函数。这些频率控制每个位置嵌入对令牌的最终表示的贡献程度。

# Generate positional embeddings
position_ids = torch.arange(input_tensor.size(1), dtype=torch.long)
position_ids = position_ids.unsqueeze(0)
position_embeddings = model.transformer.wpe(position_ids)

        第 4 步:输入嵌入

        输入嵌入是令牌嵌入和位置嵌入的总和。

# Sum token embeddings and positional embeddings
input_embeddings = token_embeddings + position_embeddings

第 5 步:编码器层

2.5.1 多头自注意力 

        a.含义

        编码器层中的第一个子层是多头自注意力机制。它允许输入序列中的每个位置关注所有其他位置,捕获序列的不同元素之间的依赖关系

        b.自注意力机制

        自注意力机制通过考虑每个元素与所有其他元素的关系来计算每个元素的注意力权重。例如:

        输入:“你好吗”

        在自注意力机制中,每个元素的权重都是相对于其他元素计算的。这就是找到“如何”与“是”、“你”和“做”的关系。这种机制被并行地(头向)多次应用以捕获不同类型的依赖关系。

        问→查询

        K → 键

        V→值

        步骤a:为每个输入嵌入找到查询、键和值

        在自注意力中,从每个输入单词导出三个向量查询向量、键向量和值向量。这些向量用于计算注意力权重。单词的注意力权重表示在对特定单词进行编码时应该注意多少。

        这里的查询、键和值就像 YouTube 搜索引擎一样。给出查询(在搜索框中),像视频标题、描述这样的键用于查找值(这里是视频)。

        步骤b.注意力权重计算(Scores)

        使用 Matmul 找到注意力权重或分数

        分数矩阵将决定每个单词对另一个单词的重要性。分数越高,越受关注。

        每个词都会影响另一个词

        步骤c:在分数矩阵上进行缩放

        步骤 d:对分数矩阵进行 Softmax

        经过softmax之后,较高的分数会得到提高,较低的分数会受到抑制。这使得重要的分数在下一步中被削弱

        步骤 e:MatMul,其中注意力权重与值向量相乘

    

        将注意力权重与值相乘得到输出

        步骤f: 找到每个单词的上下文向量

        当涉及的向量数量较多时。它们连接成一个向量

        创建各种自注意力头

        自注意力头被连接起来并发送到线性块中。

        连接的输出被送入线性层进行处理,并找到每个单词的上下文向量。

2.5.2 剩余连接(添加)

添加输入嵌入和上下文向量

2.5.3 层归一化

剩余连接的输出进入层归一化。分层归一化有助于稳定网络。

“你好吗”的输出向量:

  • “如何”的上下文向量:[0.9,-0.4,0.2]
  • “are”的上下文向量:[-0.2, 0.6, -0.3]
  • “你”的上下文向量:[0.3,0.1,0.5]

2.5.4 前馈神经网络

现在,上下文向量通过位置前馈神经网络传递,该网络独立地对每个位置进行操作。

我们将前馈网络的权重和偏差表示如下:

  • 第一个线性变换权重:W1 = [[0.1, 0.3, -0.2], [0.4, 0.2, 0.5], [-0.3, 0.1, 0.6]]
  • 第一个线性变换偏差:b1 = [0.2, 0.1, -0.3]
  • 第二次线性变换权重:W2 = [[0.2, -0.4, 0.1], [-0.3, 0.5, 0.2], [0.1, 0.2, -0.3]]
  • 第二个线性变换偏差:b2 = [0.1, -0.2, 0.3]

对于每个上下文向量,具有Relu 激活函数的位置前馈神经网络应用以下计算:

位置前馈神经网络将上下文向量转换为每个单词的新表示。

  • 编码器可以堆叠n次,以便每个单词可以学习每个单词的不同表示。

描述编码器的代码:

import torch
import torch.nn as nn
import torch.nn.functional as F

class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        
        self.self_attention = nn.MultiheadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm2 = nn.LayerNorm(d_model)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        # Multi-head self-attention
        attn_output, _ = self.self_attention(x, x, x)
        x = x + self.dropout(attn_output)
        x = self.norm1(x)
        
        # Feed-forward network
        ff_output = self.feed_forward(x)
        x = x + self.dropout(ff_output)
        x = self.norm2(x)
        
        return x

第6步:解码器层

        转换器解码器从编码器获取编码表示并处理输出序列(“我很好”)以生成相应的序列。

在解码的每个步骤中,解码器都会关注:

  • 之前生成的单词
  • 使用自注意力和编码器-解码器注意力机制的编码输入表示。

解码器

目标:“<开始>I am

步骤 1:解码器接收序列开始标记“<start>”作为其初始输入。

步骤 2:解码器使用编码器-解码器注意机制关注编码的输入表示。这有助于解码器与输入中的相关信息保持一致。

步骤 3:解码器使用自注意力机制关注其之前生成的单词,并考虑到目前为止生成的嵌入。

步骤 4:解码器应用位置前馈神经网络来细化表示。

步骤5:解码器根据细化的表示生成第一个单词“I”。

步骤6-8:解码器对后续单词“am”和“fine”重复该过程,根据之前的上下文、自注意力和前馈网络生成它们。

(代码较为复杂,整理以后给出---更新中.... )

三、小结

        以上概括第给出transformer的实现提纲,至于每一个步骤还有更多细节,可以在此文的引导下继续细化完成。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1449145.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux makefile 大型多文件的处理

最简单的例子是 main.cpp test.cpp test.h 首先将这三个写好 然后的话 test.cpp 上面输出 helloworld 首先我们在同一个目录下创建一个makefile 文件 然后用vim 编辑它 如下图&#xff08;使用的c&#xff09; mybin 是我们的可执行程序 gcc是编译的命令 gcc 前面必…

[HCIE]vxlan --静态隧道

实验目的:1.pc2与pc3互通&#xff08;二层互通&#xff09;&#xff1b;2.pc1与pc3互通&#xff08;三层互通&#xff09; 实验说明&#xff1a;sw1划分vlan10 vlan20 ;sw2划分vlan30&#xff1b;上行接口均配置为Trunk 实验步骤&#xff1a; 1.配置CE1/CE2/CE3环回口互通&a…

深度学习之反向传播算法

反向传播算法 数学公式算法代码结果 算法中一些函数的区别 数学公式 算法代码 这里用反向传播算法&#xff0c;计算 y w * x模型 import numpy as np import matplotlib.pyplot as ply#反向传播算法&#xff0c;需要使用pytorch框架&#xff0c; #这里导入pytorch框架&#xf…

力扣_面试题:配对交换

配对交换 链接&#xff1a;力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台 题目意思就是交换相邻两个二进制位 &#xff0c;用&分别取出even&#xff08;偶位和&#xff09;odd&#xff08;奇位和&#xff09; 偶位和用0xAAAAAAAA&#xff0c;奇…

[数学建模] 计算差分方程的收敛点

[数学建模] 计算差分方程的收敛点 差分方程&#xff1a;差分方程描述的是在离散时间下系统状态之间的关系。与微分方程不同&#xff0c;差分方程处理的是在不同时间点上系统状态的变化。通常用来模拟动态系统&#xff0c;如在离散时间点上更新状态并预测未来状态。 收敛点&…

4核16g云服务器多少钱?

4核16G服务器租用优惠价格26元1个月&#xff0c;腾讯云轻量4核16G12M服务器32元1个月、96元3个月、156元6个月、312元一年&#xff0c;阿腾云atengyun.com分享4核16服务器租用费用价格表&#xff0c;阿里云和腾讯云详细配置报价和性能参数表&#xff1a; 腾讯云4核16G服务器价…

2024年2月份实时获取地图边界数据方法,省市区县街道多级联动【附实时geoJson数据下载】

首先&#xff0c;来看下效果图 在线体验地址&#xff1a;https://geojson.hxkj.vip&#xff0c;并提供实时geoJson数据文件下载 可下载的数据包含省级geojson行政边界数据、市级geojson行政边界数据、区/县级geojson行政边界数据、省市区县街道行政编码四级联动数据&#xff0…

C#使用迭代器显示公交车站点

目录 一、涉及到的知识点 1.迭代器 2.IList接口及实现IList接口的Add方法 二、实例 1.源码 2.生成效果 一、涉及到的知识点 1.迭代器 迭代器是.NET 4.5开始的一个新特性&#xff0c;它是可以返回相同类型的值的有序序列的一段代码。迭代器可用作方法、运算符或get访问器…

Java安全 CC链6分析

CC链6分析 前言CC链分析核心transform链Lazymap类TiedMapEntry类HashMap方法 最终exp 前言 CC链6不受jdk版本与cs版本的影响&#xff0c;在Java安全中最为通用&#xff0c;并且非常简洁&#xff0c;非常有学习的必要&#xff0c;建议在学习CC链6之前先学习一下 URLDNS链 和 CC…

【Python如何通过多种方法输出九九乘法表】

1、九九乘法表方法一&#xff1a; for i in range(1, 10): # 对i在1到9进行循环&#xff08;不包括10&#xff09;for j in range(1, i 1): # 对j在1到i进行循环&#xff08;不包括i&#xff09;print(%d * %d %2d % (j, i, j * i), end ) # 对j,i进行格式化输出&#x…

【C++函数探幽】内联函数inline

&#x1f4d9; 作者简介 &#xff1a;RO-BERRY &#x1f4d7; 学习方向&#xff1a;致力于C、C、数据结构、TCP/IP、数据库等等一系列知识 &#x1f4d2; 日后方向 : 偏向于CPP开发以及大数据方向&#xff0c;欢迎各位关注&#xff0c;谢谢各位的支持 目录 1. 前言2.概念3.特性…

位图

目录 位图的概念 位图的实现 寻找位置 set reset test 面试题 1.给定100亿个整数&#xff0c;设计算法找到只出现一次的整数&#xff1f; 2. 给两个文件&#xff0c;分别有100亿个整数&#xff0c;我们只有1G内存&#xff0c;如何找到两个文件交集&#xff1f; 3. 位…

SpringBoot Starter造了个自动锁轮子

可能有人会有疑问&#xff0c;为什么外面已经有更好的组件&#xff0c;为什么还要重复的造轮子&#xff0c;只能说&#xff0c;别人的永远是别人的&#xff0c;自己不去造一下&#xff0c;就只能知其然&#xff0c;而不知其所以然。&#xff08;其实就为了卷&#xff09; 在日常…

VS Code之Java代码重构和源代码操作

文章目录 支持的代码操作列表调用重构分配变量字段和局部变量的差别Assign statement to new local variable在有参构造函数中将参数指定成一个新的字段 将匿名类转换为嵌套类什么是匿名类&#xff1f;匿名类转换为嵌套类的完整演示 转换为Lambda表达式Lambda 表达式是什么?转…

【OrangePi Zero2 智能家居】智能家居项目的软件实现

一、项目整体设计 二、项目代码的前期准备 三、实现语音监听接口 四、实现socket监听接口 五、实现烟雾报警监听接口 六、实现设备节点代码 七、实现接收消息处理接口 一、项目整体设计 整体的软件框架大致如下&#xff1a; 整个项目开启4个监听线程&#xff0c; 分别是&…

Java常用类与基础API--String的构造器与常用方法

文章目录 一、String的常用API-1&#xff08;1&#xff09;构造器1、介绍2、举例 &#xff08;2&#xff09;String与其他结构间的转换1、基本数据类型、包装类 --> 字符串2、字符串 --> 基本数据类型、包装类3、字符串 --> 字符数组4、字符数组 --> 字符串5、字符…

C++类和对象-C++运算符重载->加号运算符重载、左移运算符重载、递增运算符重载、赋值运算符重载、关系运算符重载、函数调用运算符重载

#include<iostream> using namespace std; //加号运算符重载 class Person { public: Person() {}; Person(int a, int b) { this->m_A a; this->m_B b; } //1.成员函数实现 号运算符重载 Person operator(const Per…

4核16G服务器价格腾讯云PK阿里云

4核16G服务器租用优惠价格26元1个月&#xff0c;腾讯云轻量4核16G12M服务器32元1个月、96元3个月、156元6个月、312元一年&#xff0c;阿腾云atengyun.com分享4核16服务器租用费用价格表&#xff0c;阿里云和腾讯云详细配置报价和性能参数表&#xff1a; 腾讯云4核16G服务器价…

JavaWeb学习|Filter与ThreadLocal

学习材料声明 所有知识点都来自互联网&#xff0c;进行总结和梳理&#xff0c;侵权必删。 引用来源&#xff1a;尚硅谷最新版JavaWeb全套教程,java web零基础入门完整版 Filter 1、Filter 过滤器它是 JavaWeb 的三大组件之一。三大组件分别是&#xff1a;Servlet 程序、Liste…

Oracle数据库自动维护任务(Automated Maintenance Tasks)

Oracle数据库自动维护任务(Automated Maintenance Tasks) Oracle数据库有以下预定义的自动维护任务: Automatic Optimizer Statistics Collection - 收集数据库中没有统计信息或只有过时统计信息的所有模式对象的优化器统计信息。SQL查询优化器使用该任务收集的统计信息来提高…