Tiny Transformer:从零开始构建简化版Transformer模型

news2024/12/27 13:23:20
引言

        自然语言处理(NLP)与计算机视觉(CV)有显著差异,各自任务的独特性决定了它们适用的模型架构。在CV中,卷积神经网络(CNN)长期占据主导地位,而在NLP领域,循环神经网络(RNN)和长短期记忆网络(LSTM)曾是主流。然而,这些传统模型在处理长序列时效率较低,难以捕捉长期依赖关系。

        针对这些问题,Vaswani等人在2017年提出了一种全新的、完全基于注意力机制的模型——Transformer。该模型解决了RNN串行计算的效率问题,并通过自注意力机制有效处理了长序列的长期依赖问题。本文将带领大家一步步构建一个简化版的Transformer模型,称之为Tiny Transformer,帮助大家深入理解其工作原理。

1. 注意力机制

        Transformer的核心是注意力机制,它通过计算Query、Key和Value之间的相关性,动态地为不同位置分配注意力权重。我们将通过多头注意力机制(Multi-Head Attention)来扩展这种计算,以便模型能同时关注多个不同的相关性。

1.1 什么是Attention?

        Attention机制通过计算Query(查询向量)与Key(键向量)之间的相似度来为Value(值向量)加权求和。它的本质是根据当前输入的每个词与其他词的相关性动态调整注意力分布。

        例如,给定一个句子,我们可以通过Attention机制来计算每个词对其他词的关注程度。Attention公式如下:

1.2 Multi-Head Attention

        多头注意力机制扩展了单头注意力的概念,通过并行化多个注意力头来捕获序列中不同层次的相关性。每个注意力头对输入进行独立的Attention计算,然后将所有头的输出拼接起来,形成最终的输出。

import torch.nn as nn
import torch
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.head_dim = d_model // num_heads
        
        assert self.head_dim * num_heads == d_model, "d_model must be divisible by num_heads"

        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.fc_out = nn.Linear(d_model, d_model)
    
    def forward(self, x):
        B, T, C = x.shape
        qkv = self.qkv(x).reshape(B, T, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)
        
        attn_scores = (q @ k.transpose(-2, -1)) * (1.0 / torch.sqrt(k.size(-1)))
        attn_weights = F.softmax(attn_scores, dim=-1)
        
        attn_output = (attn_weights @ v).transpose(1, 2).reshape(B, T, C)
        return self.fc_out(attn_output)
2. 编码器和解码器

        Transformer的结构包括编码器(Encoder)和解码器(Decoder),二者均由多层的注意力机制和前馈神经网络(Feed-Forward Neural Network, FFN)组成。

2.1 编码器

        编码器的主要任务是对输入序列进行编码,并生成上下文表示供解码器使用。每个编码器层包括一个自注意力层和一个前馈网络。

class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, ff_hidden_dim, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, ff_hidden_dim),
            nn.ReLU(),
            nn.Linear(ff_hidden_dim, d_model)
        )
        self.layernorm1 = nn.LayerNorm(d_model)
        self.layernorm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        attn_output = self.mha(x)
        x = self.layernorm1(x + self.dropout(attn_output))
        
        ffn_output = self.ffn(x)
        return self.layernorm2(x + self.dropout(ffn_output))
2.2 解码器

        解码器的结构与编码器类似,但它包含了一个额外的“交叉注意力”层,用于将编码器的输出作为上下文信息输入,结合解码器自身的输入进行生成。

class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, ff_hidden_dim, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.mha1 = MultiHeadAttention(d_model, num_heads)
        self.mha2 = MultiHeadAttention(d_model, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, ff_hidden_dim),
            nn.ReLU(),
            nn.Linear(ff_hidden_dim, d_model)
        )
        self.layernorm1 = nn.LayerNorm(d_model)
        self.layernorm2 = nn.LayerNorm(d_model)
        self.layernorm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out):
        attn_output1 = self.mha1(x)
        x = self.layernorm1(x + self.dropout(attn_output1))
        
        attn_output2 = self.mha2(x, enc_out, enc_out)
        x = self.layernorm2(x + self.dropout(attn_output2))
        
        ffn_output = self.ffn(x)
        return self.layernorm3(x + self.dropout(ffn_output))
3. 位置编码

        Transformer由于完全摒弃了递归结构,不能自然捕捉输入序列中的位置信息。因此,位置编码(Positional Encoding)被引入,用于为每个词添加位置信息。位置编码通过正弦和余弦函数为不同位置生成独特的表示。

import math
import torch

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]
4. 完整的Transformer模型

        有了上面各个模块后,我们可以将它们组合成一个完整的Transformer模型。该模型包括一个嵌入层、多个编码器层、解码器层以及一个线性层用于生成输出。

class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_encoder_layers, num_decoder_layers, ff_hidden_dim, dropout):
        super(Transformer, self).__init__()
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        
        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, ff_hidden_dim, dropout) for _ in range(num_encoder_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, ff_hidden_dim, dropout) for _ in range(num_decoder_layers)])
        
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt):
        src = self.positional_encoding(self.src_embedding(src))
        tgt = self.positional_encoding(self.tgt_embedding(tgt))
        
        for layer in self.encoder_layers:
            src = layer(src)
        
        for layer in self.decoder_layers:
            tgt = layer(tgt, src)
        
        return self.fc_out(tgt)
结语

        本文通过逐步实现简化版的Transformer,展示了Transformer模型的核心组成部分——多头注意力、编码器-解码器架构、位置编码等。通过这些模块,Transformer能够高效处理序列数据,实现并行计算,广泛应用于自然语言处理、机器翻译等任务。

        Transformer的灵活性和强大的性能使其成为现代深度学习的基石。在掌握了这些基本模块后,大家可以进一步研究更复杂的模型,如BERT、GPT等预训练模型,以更好地理解和应用Transformer在实际任务中的强大能力。

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2186361.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于微信小程序的四六级词汇+ssm(lw+演示+源码+运行)

摘 要 随着我国经济迅速发展,人们对手机的需求越来越大,各种手机软件也都在被广泛应用,但是对于手机进行数据信息管理,对于手机的各种软件也是备受用户的喜爱,四六级词汇小程序被用户普遍使用,为方便用户能…

Python入门--判断语句

目录 1. 布尔类型和比较运算符 2. if语句的基本格式 3. if-else语句 4. if-elif-else语句 5. 判断语句的嵌套 6. 应用--猜数字游戏 进行逻辑判断,是生活中常见的行为。同样,在程序中,进行逻辑判断也是最为基础的功能。 1. 布尔类型和比…

OceanBase—02(入门篇——对于单副本单节点,由1个observer扩容为3个observer集群)——之前的记录,当初有的问题未解决,目前新版未尝试

OceanBase—02(入门篇——对于单副本单节点,由1个observer扩容为3个observer集群)——之前的记录,有的问题未解决,新版未尝试 1、前言—安装单副本单节点集群1.1 docker安装OB 2、查看现有集群情况2.1 进入容器&#x…

设置服务器走本地代理

勾选: 然后: git clone https://github.com/rofl0r/proxychains-ng.git./configure --prefix/home/wangguisen/usr --sysconfdir/home/wangguisen/etcmakemake install# 在最后配置成本地代理地址 vim /home/wangguisen/etc/proxychains.confsocks4 17…

Python编写的贪吃蛇小游戏

安装包 pip install pygame完整代码 import pygame import randompygame.init()# 定义颜色 white (255, 255, 255) black (0, 0, 0) red (213, 50, 80) green (0, 255, 0) blue (50, 153, 213)# 定义屏幕大小 dis_width 800 dis_height 600dis pygame.display.set_mo…

【数据结构】什么是平衡二叉搜索树(AVL Tree)?

🦄个人主页:修修修也 🎏所属专栏:数据结构 ⚙️操作环境:Visual Studio 2022 目录 📌AVL树的概念 📌AVL树的操作 🎏AVL树的插入操作 ↩️右单旋 ↩️↪️右左双旋 ↪️↩️左右双旋 ↪️左单旋 🎏AVL树的删…

CTF刷题buuctf

[WUSTCTF2020]颜值成绩查询 拿到相关题目,其实根据功能和参数分析。需要传入一个学号然后进行针对于对应的学号进行一个查询,很可能就会存在sql注入。 其实这道题最难的点,在于过滤了空格,因此我们使用 /**/来过滤空格的限制。…

智能化焊接数据管理系统:系统功能设计与应用场景,OEM定制

在快速发展的工业4.0时代,智能化技术正以前所未有的速度改变着各行各业,其中焊接行业也不例外。随着物联网、大数据、人工智能等技术的不断融合,智能化焊接数据管理系统应运而生,成为提高焊接效率、保障焊接质量、优化生产流程的重…

半监督学习与数据增强(论文复现)

半监督学习与数据增强(论文复现) 本文所涉及所有资源均在传知代码平台可获取 文章目录 半监督学习与数据增强(论文复现)概述算法原理核心逻辑效果演示使用方式 概述 本文复现论文提出的半监督学习方法,半监督学习&…

C题(二)字符串转数字 --- atoi

———————————————————**目录**—————————————————— 一、 atoi函数介绍 功能函数原型使用示例 二、题解之一 三、留言 问题引入👉 输入样例👉 5 01234 00123 00012 00001 00000 输出样例👉 1234 123 …

‌文件名称与扩展名:批量重命名的技巧与指南

在日常的文件管理中,我们经常需要处理大量的文件,这些文件可能有着各种各样的名称和扩展名。为了更好地管理和识别这些文件,批量重命名成为了一项非常实用的技能。能够帮助我们快速整理文件,提高工作效率。本文将深入探讨文件名称…

vue2圆形标记(Marker)添加点击事件不弹出信息窗体(InfoWindow)的BUG解决

目录 一、问题详情 二、问题排查 三、解决方案 一、问题详情 地图上面的轨迹点希望能通过点击看到详细的经纬度信息,但是点击的时候就是显示不出来。 二、问题排查 代码都是参考高德的官方文档,初步看没有问题啊,但是点击事件就感觉失效…

10.3今日错题解析(软考)

目录 前言计算机网络——路由配置数据库系统——封锁协议 前言 这是用来记录我备考软考设计师的错题的,今天知识点为路由配置、封锁协议,大部分错题摘自希赛中的题目,但相关解析是原创,有自己的思考,为了复习&#xf…

Pix2Pix实现图像转换

tutorials/application/source_zh_cn/generative/pix2pix.ipynb MindSpore/docs - Gitee.com Pix2Pix概述 Pix2Pix是基于条件生成对抗网络(cGAN, Condition Generative Adversarial Networks )实现的一种深度学习图像转换模型,该模型是由Ph…

Comparable接口和Comparator接口

前言 Java中基本数据类型可以直接比较大小,但引用类型呢?同时引用对象中可能存在多个可比较的字段,那么我们该怎么比较呢? Java中引用类型不能直接进行大小的比较,这种行为在编译器看来是危险的,所以会编译…

程序员在AI时代的生存指南:打造不可替代的核心竞争力

在这个AI大行其道的时代,似乎每天都有新的语言模型像变魔术一样涌现出来,比如ChatGPT、midjourney、claude等等。这些家伙不仅会聊天,还能帮忙写代码,让程序员们感受到了前所未有的“压力”。我身边的一些程序员朋友开始焦虑&…

SpringCloud入门(十)统一网关Gateway

一、网关的作用 Spring Cloud Gateway 是 Spring Cloud 的一个全新项目,该项目是基于 Spring 5.0,Spring Boot 2.0 和 Project Reactor 等响应式编程和事件流技术开发的网关,它旨在为微服务架构提供一种简单有效的统一的 API 路由管理方式。 …

E. Tree Pruning Codeforces Round 975 (Div. 2)

原题 E. Tree Pruning 解析 本题题意很简单, 思路也很好想到, 假设我们保留第 x 层的树叶, 那么对于深度大于 x 的所有节点都要被剪掉, 而深度小于 x 的节点, 如果没有子节点深度大于等于 x, 那么也要被删掉 在做这道题的时候, 有关于如何找到一个节点它的子节点能通到哪里,…

关于鸿蒙next 调用系统权限麦克风

使用app的时候都清楚,想使用麦克风、摄像头,存储照片等,都需要调用系统的权限,没有手机操作系统权限你也使用不了app所提供的功能,虽然app可以正常打开,但是你需要的功能是没办法使用的。今天把自己在鸿蒙学…

想怎样书写HTML5自结束标签,您随意就好(✪▽✪)

书写后接斜杠还是不接,看过ai给的详细解析就不再迷茫了。 (笔记模板由python脚本于2024年10月03日 10:42:41创建,本篇笔记适合HTML5标签的coder翻阅) 【学习的细节是欢悦的历程】 Python 官网:https://www.python.org/ Free:大咖…