TR3复现Tramsformer

news2024/11/25 2:30:17
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

前言

Transformer模型是深度学习中的一个革命性架构,自从在NLP领域引入以来,就因其高效处理序列数据的能力而迅速成为主流。本文将通过代码实现详细剖析Transformer模型的各个组件,包括多头注意力机制、前馈神经网络、位置编码、编码器和解码器等部分。

1. Transformer模型简介

Transformer模型由Vaswani等人在2017年提出,首次在自然语言处理任务中完全摆脱了循环神经网络(RNN),依赖于自注意力机制来处理序列数据。它不仅在机器翻译、文本生成等任务中表现优异,还在各种任务中展现了良好的扩展性和性能。

Transformer的核心思想是通过自注意力机制(Self-Attention)来捕捉序列中词与词之间的关系,并通过多头注意力(Multi-Head Attention)和前馈神经网络(Feedforward Neural Network)来进一步处理这些关系。

2. Transformer的核心组件
2.1 多头注意力机制

在Transformer中,多头注意力机制(Multi-Head Attention)是最重要的组件之一。它通过对输入序列的不同部分进行多次并行的注意力计算,从而捕捉到更多的上下文信息。以下是多头注意力机制的实现:

class MultiHeadAttention(nn.Module):
    def __init__(self, hid_dim, n_heads):
        super(MultiHeadAttention, self).__init__()
        self.hid_dim = hid_dim
        self.n_heads = n_heads

        # 确保hid_dim可以被n_heads整除
        assert hid_dim % n_heads == 0

        # 定义线性变换矩阵
        self.w_q = nn.Linear(hid_dim, hid_dim)
        self.w_k = nn.Linear(hid_dim, hid_dim)
        self.w_v = nn.Linear(hid_dim, hid_dim)
        self.fc  = nn.Linear(hid_dim, hid_dim)

        # 缩放因子
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim // n_heads]))

    def forward(self, query, key, value, mask=None):
        bsz = query.shape[0]
        Q = self.w_q(query)
        K = self.w_k(key)
        V = self.w_v(value)

        # 将Q, K, V拆分成多个头
        Q = Q.view(bsz, -1, self.n_heads, self.hid_dim // self.n_heads).permute(0, 2, 1, 3)
        K = K.view(bsz, -1, self.n_heads, self.hid_dim // self.n_heads).permute(0, 2, 1, 3)
        V = V.view(bsz, -1, self.n_heads, self.hid_dim // self.n_heads).permute(0, 2, 1, 3)

        # 计算注意力得分
        attention = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale

        if mask is not None:
            attention = attention.masked_fill(mask == 0, -1e10)

        attention = torch.softmax(attention, dim=-1)

        # 计算多头注意力的输出
        x = torch.matmul(attention, V)

        # 拼接多个头的输出
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(bsz, -1, self.n_heads * (self.hid_dim // self.n_heads))
        x = self.fc(x)
        return x

在这个实现中,MultiHeadAttention类首先对输入的querykeyvalue进行线性变换,然后将它们拆分为多个注意力头,并分别计算每个头的注意力得分。最后,将所有头的结果拼接起来并通过线性层输出。

2.2 前馈神经网络

前馈神经网络(Feedforward Neural Network)是Transformer中的另一核心组件。它通常由两层线性变换和一个ReLU激活函数组成:

class Feedforward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(Feedforward, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        x = torch.nn.functional.relu(self.linear1(x))
        x = self.dropout(x)
        x = self.linear2(x)
        return x

这个前馈网络将每个位置的表示独立地通过一个全连接层映射到更高维空间,再映射回原来的维度,从而增强模型的表达能力。

2.3 位置编码

由于Transformer模型不再使用RNN来处理序列数据,因此需要一种方法让模型感知到输入序列中词的顺序信息。为此,引入了位置编码(Positional Encoding):

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model).to(device)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)].requires_grad_(False)
        return self.dropout(x)

位置编码通过正弦和余弦函数生成固定的编码,能够为不同位置的词提供唯一的表示。

3. 编码器和解码器

Transformer模型的编码器和解码器分别由多个层堆叠而成。每一层都包含一个多头注意力机制和一个前馈神经网络。

3.1 编码器层

编码器层的实现如下:

class EncoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.self_attn   = MultiHeadAttention(d_model, n_heads)
        self.feedforward = Feedforward(d_model, d_ff, dropout)
        self.norm1   = nn.LayerNorm(d_model)
        self.norm2   = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        attn_output = self.self_attn(x, x, x, mask)
        x = x + self.dropout(attn_output)
        x = self.norm1(x)

        ff_output = self.feedforward(x)
        x = x + self.dropout(ff_output)
        x = self.norm2(x)

        return x

每个编码器层首先通过多头自注意力机制处理输入序列,然后通过前馈神经网络进一步处理。最后,通过残差连接和LayerNorm层来规范化输出。

3.2 解码器层

解码器层的实现如下:

class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.self_attn   = MultiHeadAttention(d_model, n_heads)
        self.enc_attn    = MultiHeadAttention(d_model, n_heads)
        self.feedforward = Feedforward(d_model, d_ff, dropout)
        self.norm1   = nn.LayerNorm(d_model)
        self.norm2   = nn.LayerNorm(d_model)
        self.norm3   = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, self_mask, context_mask):
        attn_output = self.self_attn(x, x, x, self_mask)
        x           = x + self.dropout(attn_output)
        x           = self.norm1(x)

        attn_output = self.enc_attn(x, enc_output, enc_output, context_mask)
        x           = x + self.dropout(attn_output)
        x           = self.norm2(x)

        ff_output = self.feedforward(x)
        x = x + self.dropout(ff_output)
        x = self.norm3(x)

        return x

解码器层包含三个部分:自注意力机制、编码器-解码器注意力机制,以及前馈神经网络。解码器层不仅需要关注目标序列的自身信息,还需要从编码器的输出中提取上下文信息。

4. Transformer模型的整体架构

最后,我们将所有组件组合成完整的Transformer模型:

class Transformer(nn.Module):
    def __init__(self, vocab_size, d_model, n

_heads, n_encoder_layers, n_decoder_layers, d_ff, dropout=0.1):
        super(Transformer, self).__init__()
        self.embedding           = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, dropout)
        self.encoder_layers      = nn.ModuleList([EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_encoder_layers)])
        self.decoder_layers      = nn.ModuleList([DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_decoder_layers)])
        self.fc_out              = nn.Linear(d_model, vocab_size)
        self.dropout             = nn.Dropout(dropout)

    def forward(self, src, trg, src_mask, trg_mask):
        src = self.embedding(src)
        src = self.positional_encoding(src)
        trg = self.embedding(trg)
        trg = self.positional_encoding(trg)

        for layer in self.encoder_layers:
            src = layer(src, src_mask)

        for layer in self.decoder_layers:
            trg = layer(trg, src, trg_mask, src_mask)

        output = self.fc_out(trg)

        return output

在这个完整的Transformer模型中,我们首先对源语言和目标语言进行嵌入并加上位置编码。然后,经过多层编码器和解码器的处理,最后通过一个线性层输出最终的预测结果。

结果
# 使用示例
vocab_size = 10000  # 假设词汇表大小为10000
d_model    = 512
n_heads    = 8
n_encoder_layers = 6
n_decoder_layers = 6
d_ff             = 2048
dropout          = 0.1

transformer_model = Transformer(vocab_size, d_model, n_heads, n_encoder_layers, n_decoder_layers, d_ff, dropout)

# 定义输入,这里的输入是假设的,需要根据实际情况修改
src = torch.randint(0, vocab_size, (32, 10))  # 源语言句子
trg = torch.randint(0, vocab_size, (32, 20))  # 目标语言句子
src_mask = (src != 0).unsqueeze(1).unsqueeze(2)  # 掩码,用于屏蔽填充的位置
trg_mask = (trg != 0).unsqueeze(1).unsqueeze(2)  # 掩码,用于屏蔽填充的位置

# 模型前向传播
output = transformer_model(src, trg, src_mask, trg_mask)
print(output.shape)

在这里插入图片描述

5. 总结

本这周通过PyTorch代码实现了Transformer模型的各个核心组件,并详细解释了它们的原理和作用。Transformer模型凭借其高效的自注意力机制和并行处理能力,已经成为自然语言处理领域的标准工具。理解其内部原理,不仅有助于更好地应用这个模型,还能为进一步的改进和创新提供坚实的基础。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1992734.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

谷粒商城实战笔记-nginx问题记录

记录在使用nginx中遇到的问题。 文章目录 1,网关路由匹配不生效2,网关路由配置前后顺序导致的问题(非nginx问题)3,nginx.conf upstream配置缺少端口4,配置结尾少分号5, proxy_pass 后跟的服务器 URL 是否以 / 结尾5.1 …

C语言学习:汉诺塔问题

汉诺塔_百度百科 (baidu.com)https://baike.baidu.com/item/%E6%B1%89%E8%AF%BA%E5%A1%94/3468295 // // Created by zzh on 2024/8/6. ////汉诺塔问题#include<stdio.h>void move(char x, char y) {printf("%c --> %c \n", x, y); }int hanoi(int count, i…

2024新版软件测试八股文及答案解析

前言 前面看到了一些面试题&#xff0c;总感觉会用得到&#xff0c;但是看一遍又记不住&#xff0c;所以我把面试题都整合在一起&#xff0c;都是来自各路大佬的分享&#xff0c;为了方便以后自己需要的时候刷一刷&#xff0c;不用再到处找题&#xff0c;今天把自己整理的这些…

WEB渗透未授权访问篇-Redis

测试 redis-cli redis-cli -h 127.0.0.1 flunshall 192.168.0.110:6379>ping PONG 存在未授权访问 JS打内网 var cmd new XMLHttpRequest(); cmd.open("POST", "http://127.0.0.1:6379"); cmd.send(flushall\r\n); var c…

51单片机之LED篇(二)独立按键

一、独立按键的介绍 1.1 独立按键的基本原理 相当于一种电子开关&#xff0c;按下时开关接通&#xff0c;松开时开关断开。 开关功能&#xff1a;独立按键内部通常包含一个有弹性的金属片&#xff0c;当按键被按下时&#xff0c;金属片与触点接触&#xff0c;电路连通&#x…

鸿蒙AI功能开发【hiai引擎框架-语音识别】 基础语音服务

hiai引擎框架-语音识别 介绍 本示例展示了使用hiai引擎框架提供的语音识别能力。 本示例展示了对一段音频流转换成文字的能力展示。 需要使用hiai引擎框架文本转语音接口kit.CoreSpeechKit.d.ts. 效果预览 使用说明&#xff1a; 在手机的主屏幕&#xff0c;点击”asrDemo…

CMake基础教程二

常用 环境变量 SET(ENV{VAR} VALUE)**常用变量&#xff1a;**| 变量名 | 含义 | | ----------------------------- | ---------------------------------------------------------…

Bitwise 首席投资官:忽略短期的市场波动,关注加密货币的发展前景

原文标题&#xff1a;《The Crypto Market Sell-Off: What Happened and Where We Go From Here》撰文&#xff1a;Matt Hougan&#xff0c;Bitwise 首席投资官编译&#xff1a;Chris&#xff0c;Techub News 加密货币市场在周末经历了大幅下跌。从上周五下午 4 点到周一早上 7…

2024年下软考报名全流程+备考指南(八月最新版)

2024年下半年软考备考&#xff0c;一定要知道这几点&#xff01; 2024年下半年软考报名已迫在眉睫&#xff0c;不知不觉间&#xff0c;留给下半年考试小伙伴们的复习时间只有三个月。备考的小伙伴们准备好了吗&#xff1f;这些全程重点&#xff0c;请务必收藏保存&#xff0c;…

C/C++数字与字符串互相转换

前言&#xff1a; 在C/C程序中&#xff0c;会需要把数字与字符串做出互相转换的操作&#xff0c;用于实现程序想要的效果。下面将介绍多种方法实现数字与字符串互相转换。 字符串转为数字 一、利用ASCII 我们知道每个字符都有一个ASCII码&#xff0c;利用这一点可以将字符-0…

vue文件style标签变成黄色,media query is expected

效果如下图所示&#xff0c;红色波浪线&#xff0c;鼠标放上去提示 media query is expected 对比其他文件后发现是引入scss文件后后面少了分号&#xff0c;导致报错&#xff0c;加上分号&#xff0c;效果如下图&#xff0c;完美解决~

文件操作常用函数及makefile的使用

文件操作中常用函数 1. getpwuid 定义: struct passwd *getpwuid(uid_t uid);功能: 根据用户ID&#xff08;UID&#xff09;返回与之对应的passwd结构体指针&#xff0c;该结构体包含用户的详细信息。常用字段: pw_name: 用户名。pw_uid: 用户ID。pw_gid: 用户的组ID。pw_dir…

Qt实现类似淘宝商品看板的界面,带有循环翻页以及点击某页跳转的功能

效果如下&#xff1a; #ifndef ModelDashboardGroup_h__ #define ModelDashboardGroup_h__#include <QGridLayout> #include <QLabel> #include <QPushButton> #include <QWidget>#include <QLabel> #include <QWidget> #include <QMou…

Jenkins保姆笔记(3)——Jenkins拉取Git代码、编译、打包、远程多服务器部署Spring Boot项目

前面我们介绍过&#xff1a; Jenkins保姆笔记&#xff08;1&#xff09;——基于Java8的Jenkins安装部署 Jenkins保姆笔记&#xff08;2&#xff09;——基于Java8的Jenkins插件安装 本篇主要介绍基于Java8的Jenkins第一个Hello World项目&#xff0c;一起实践下Jenkins拉…

第十九节 大语言模型与多模态大模型loss计算

文章目录 前言一、大语言模型loss计算1、loss计算代码解读2、构建模型输入内容与label标签二、多模态大模型loss计算方法1、多模态loss计算代码解读2、多模态输入内容2、大语言模型输入内容3、图像embending如何嵌入文本embeding前言 如果看了我前面文章,想必你基本对整个代码…

Java学习Day24:基础篇14:多线程

1.程序、进程和线程 程序 进程 进程(process)是程序的一次执行过程&#xff0c;或是一个正在执行的程序。是一个动态的过程&#xff1a;有它自身的产 生、存在和消亡的过程。 如&#xff1a; 运行中的QQ运行中的音乐播放器视频播放器等&#xff1b;程序是静态的&#xff0c…

写给小白程序员的一封信

文章目录 1.编程小白如何成为大神&#xff1f;大学新生的最佳入门攻略2.程序员的练级攻略3.编程语言的选择4.熟悉Linux5.学会git6.知道在哪寻求帮助7.多结交朋友8.参加开源项目9.坚持下去 1.编程小白如何成为大神&#xff1f;大学新生的最佳入门攻略 编程已成为当代大学生的必…

音视频开发,最新学习心得与感悟

音视频技术的知识海洋浩瀚无垠&#xff0c;自学之路显得尤为崎岖&#xff0c;技术门槛的存在是毋庸置疑的事实。 对于渴望踏入这一行业的初学者而言&#xff0c;学习资源的匮乏成为了一道难以逾越的障碍。 本次文章主要是给大家分享音视频开发进阶学习路线&#xff0c;虽然我…

三大口诀不一样的代码,小小的制表符和换行符玩的溜呀

# 小案例&#xff0c;打印输出加法口诀 for i in range(1,10):for j in range(1,10):if j>i:breakprint(f"{j}{i}{ji}".strip(),end\t)print() print(\n) for i in range(1,10):for j in range(1,10):if j>i:breakprint(f"{j}x{i}{j*i}",end\t)print…

[Spring] Spring AOP

&#x1f338;个人主页:https://blog.csdn.net/2301_80050796?spm1000.2115.3001.5343 &#x1f3f5;️热门专栏: &#x1f9ca; Java基本语法(97平均质量分)https://blog.csdn.net/2301_80050796/category_12615970.html?spm1001.2014.3001.5482 &#x1f355; Collection与…