TR2 - Transformer模型的复现

news2025/1/1 22:12:22

目录

  • 理论知识
  • 模型结构
    • 结构分解
      • 黑盒
      • 两大模块
      • 块级结构
      • 编码器的组成
      • 解码器的组成
  • 模型实现
    • 多头自注意力块
    • 前馈网络块
    • 位置编码
    • 编码器
    • 解码器
    • 组合模型
    • 最后附上引用部分
  • 模型效果
  • 总结与心得体会


理论知识

Transformer是可以用于Seq2Seq任务的一种模型,和Seq2Seq不冲突。

模型结构

模型整体结构

结构分解

黑盒

以机器翻译任务为例
黑盒

两大模块

在Transformer内部,可以分成Encoder编码器和和Decoder解码器两部分,这也是Seq2Seq的标准结构。
两大模块

块级结构

继续拆解,可以发现模型的由许多的编码器块和解码器块组成并且每个解码器都可以获取到最后一层编码器的输出以及上一层解码器的输出(第一个当然是例外的)。
块组成

编码器的组成

继续拆解,一个编码器是由一个自注意力块和一个前馈网络组成。
编码器的组成

解码器的组成

而解码器,是在编码器的结构中间又插入了一个Encoder-Decoder Attention层。
解码器的组成

模型实现

通过前面自顶向下的拆解,已经基本掌握了模型的总体结构。接下来自底向上的复现Transformer模型。

多头自注意力块

class MultiHeadAttention(nn.Module):
    """多头注意力模块"""
    def __init__(self, dims, n_heads):
        """
        dims: 每个词向量维度
        n_heads: 注意力头数
        """
        super().__init__()

        self.dims = dims
        self.n_heads = n_heads

        # 维度必需整除注意力头数
        assert dims % n_heads == 0
        # 定义Q矩阵
        self.w_Q = nn.Linear(dims, dims)
        # 定义K矩阵
        self.w_K = nn.Linear(dims, dims)
        # 定义V矩阵
        self.w_V = nn.Linear(dims, dims)

        self.fc = nn.Linear(dims, dims)
        # 缩放
        self.scale = torch.sqrt(torch.FloatTensor([dims//n_heads])).to(device)

    def forward(self, query, key, value, mask=None):
        batch_size = query.shape[0]
        # 例如: [32, 1024, 300] 计算10头注意力
        Q = self.w_Q(query)
        K = self.w_K(key)
        V = self.w_V(value)

        # [32, 1024, 300] -> [32, 1024, 10, 30] 把向量重新分组
        Q = Q.view(batch_size, -1, self.n_heads, self.dims//self.n_heads).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.dims//self.n_heads).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.dims//self.n_heads).permute(0, 2, 1, 3)

        # 1. 计算QK/根dk
        # [32, 1024, 10, 30] * [32, 1024, 30, 10] -> [32, 1024, 10, 10] 交换最后两维实现乘法
        attention = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale

        if mask is not None:
            # 将需要mask掉的部分设置为很小的值
            attention = attention.masked_fill(mask==0, -1e10)
        # 2. softmax
        attention = torch.softmax(attention, dim=-1)

        # 3. 与V相乘
        # [32, 1024, 10, 10] * [32, 1024, 10, 30] -> [32, 1024, 10, 30]
        x = torch.matmul(attention, V)

        # 恢复结构
        # 0 2 1 3 把 第2,3维交换回去
        x = x.permute(0, 2, 1, 3).contiguous()
        # [32, 1024, 10, 30] -> [32, 1024, 300]
        x = x.view(batch_size, -1, self.n_heads*(self.dims//self.n_heads))
        # 走一个全连接层
        x = self.fc(x)
        return x

前馈网络块

class FeedForward(nn.Module):
    """前馈传播"""
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = self.dropout(x)
        x = self.linear2(x)
        return x

位置编码

class PositionalEncoding(nn.Module):
    """位置编码"""
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        # 用来存位置编码的向量
        pe = torch.zeros(max_len, d_model).to(device)
        # 准备位置信息
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2)* -(math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)

        # 注册一个不参数梯度下降的模型参数
        self.register_buffer('pe', pe)

    def forward(self, x):
        x  = x + self.pe[:, :x.size(1)].requires_grad_(False)
        return self.dropout(x)

编码器

class EncoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()

        self.self_attn = MultiHeadAttention(d_model, n_heads)
        self.feedforward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x, mask):
        attn_output = self.self_attn(x, x, x, mask)
        x = x + self.dropout(attn_output)
        x = self.norm1(x)

        ff_output = self.feedforward(x)
        x = x + self.dropout(ff_output)
        x = self.norm2(x)

        return x

解码器

class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads)
        self.enc_attn = MultiHeadAttention(d_model, n_heads)
        self.feedforward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, mask, enc_mask):
        # 自注意力
        attn_output = self.self_attn(x, x, x, mask)
        x = x + self.dropout(attn_output)
        x = self.norm1(x)

        # 编码器-解码器注意力
        attn_output = self.enc_attn(x, enc_output, enc_output, enc_mask)
        x = x + self.dropout(attn_output)
        x = self.norm2(x)

        # 前馈网络
        ff_output = self.feedforward(x)
        x = x + self.dropout(ff_output)
        x = self.norm3(x)

        return x

组合模型

class Transformer(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, n_encoder_layers, n_decoder_layers, d_ff, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_encoder_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_decoder_layers)])
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, trg, src_mask, trg_mask):
        # 词嵌入
        src = self.embedding(src)
        src = self.positional_encoding(src)
        trg = self.embedding(trg)
        trg = self.positional_encoding(trg)

        # 编码器
        for layer in self.encoder_layers:
            src = layer(src, src_mask)

        # 解码器
        for layer in self.decoder_layers:
            trg = layer(trg, src, trg_mask, src_mask)

        output = self.fc_out(trg)

        return output

最后附上引用部分

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

模型效果

编写代码测试模型的复现是否正确(没有跑任务)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

vocab_size = 10000
d_model = 512
n_heads = 8
n_encoder_layers = 6
n_decoder_layers = 6
d_ff = 2048
dropout = 0.1

transformer_model = Transformer(vocab_size, d_model, n_heads, n_encoder_layers, n_decoder_layers, d_ff, dropout).to(device)

src = torch.randint(0, vocab_size, (32, 10)).to(device) # 源语言
trg = torch.randint(0, vocab_size, (32, 20)).to(device) # 目标语言

src_mask = (src != 0).unsqueeze(1).unsqueeze(2).to(device)
trg_mask = (trg != 0).unsqueeze(1).unsqueeze(2).to(device)

output = transformer_model(src, trg, src_mask, trg_mask)
print(output.shape)

打印结果

torch.Size([32, 20, 10000])

说明模型正常运行了

总结与心得体会

我是从CV模型学到Transfromer来的,通过对Transformer模型的复现我发现:

  • 类似于残差的连接在Transformer中也十分常见,还有先缩小再放大的Bottleneck结构。
  • 整个Transformer模型的核心处理对特征的维度没有变化,这一点和CV模型完全不同。
  • Transformer的核心是多头自注意机制。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1554755.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

开源免费软件推荐PhotoPrism:一款基于TensorFlow的开源照片管理工具,实现自动图像分类与本地化部署

引言: PhotoPrism,这款基于机器学习软件Google TensorFlow的开源照片管理工具,不仅实现了自动图像分类,更能够精准检测图片的颜色、色度、亮度、质量等属性。无论是全景照片还是地理位置信息,它都能轻松识别。更重要的…

k8s-jenkins安装与流水线

k8s-jenkins安装与流水线 一、环境安装1.创建目录2.后台启动服务3.浏览器访问4.修改密码 二、流水线1.新建流水线任务2.运行流水线3.安装插件4.安装Kubernetes CLI 三、总结 一、环境安装 如果使用的是阿里云Kubernetes集群 ,可以安装其 ack-jenkins应用。 5分钟在…

KVM:尝试安装windows2008

最终目的是在lxd部署windows2008镜像 WindowsServer2008镜像: cn_windows_server_2008_r2_standard_enterprise_datacenter_and_web_with_sp1_x64_dvd_617598.iso 镜像参考链接: https://discussion.scottibyte.com/t/migrate-a-hyper-v-windows-vir…

on-my-zsh 命令自动补全插件 zsh-autosuggestions 安装和配置

首先 Oh My Zsh 是什么? Oh My Zsh 是一款社区驱动的命令行工具,正如它的主页上说的,Oh My Zsh 是一种生活方式。它基于 zsh 命令行,提供了主题配置,插件机制,已经内置的便捷操作。给我们一种全新的方式使用命令行。…

【unity】unity安装及路线图

学习路线图 二、有关unity的下载 由于unity公司是在国外,所以.com版(https://developer.unity.cn/)不一定稳定,学习时推荐从.cn国内版(https://developer.unity.cn/)前往下载,但是后期仍需回…

【全套源码教程】基于SpringBoot+MyBatis框架的智慧生活商城系统的设计与实现

目录 前言 需求分析 可行性分析 技术实现 后端框架:Spring Boot 持久层框架:MyBatis 前端框架:Vue.js 数据库:MySQL 功能介绍 前台功能拓展 商品详情单管理 个人中心 秒杀活动 推荐系统 评论与评分系统 后台功能拓…

桌面/WEB端3D开发工具HOOPS SDK简介

Tech Soft 3D在长达25年的时间内,一直通过卓越的3D技术帮助全球超过600家客户推动创新,这些客户包括HEXAGON、SolidWorks、SIEMENS、Aras、ANSYS、AVEVA等各个行业的领军者。 Tech Soft 3D旗下拥有4款原生产品,分别是:HOOPS Excha…

docker-compose mysql

使用docker-compose 部署 MySQL(所有版本通用) 一、拉取MySQL镜像 我这里使用的是MySQL8.0.18,可以自行选择需要的版本。 docker pull mysql:8.0.18二、创建挂载目录 mkdir -p /data/mysql8/log mkdir -p /data/mysql8/data mkdir -p /dat…

浏览器工作原理与实践--块级作用域:var缺陷以及为什么要引入let和const

在前面《07 | 变量提升:JavaScript代码是按顺序执行的吗?》这篇文章中,我们已经讲解了JavaScript中变量提升的相关内容,正是由于JavaScript存在变量提升这种特性,从而导致了很多与直觉不符的代码,这也是Jav…

【YOLOv5改进系列(8)】高效涨点----添加yolov7中Aux head 辅助训练头

文章目录 🚀🚀🚀前言一、1️⃣ Auxiliary head辅助头简单介绍二、2️⃣从损失函数和标签分配分析三、3️⃣正负样本标签分配四、4️⃣如何添加Aux head辅助训练头五、5️⃣实验部分(后续添加,还是跑模型,辅助头真是太慢…

Chrome 插件各模块使用 Fetch 进行接口请求

Chrome 插件各模块使用 Fetch 进行接口请求 常规网页可以使用 fetch() 或 XMLHttpRequest API 从远程服务器发送和接收数据,但受到同源政策的限制。 内容脚本会代表已注入内容脚本的网页源发起请求,因此内容脚本也受同源政策的约束,插件的来…

Arduino IDE导出esp8266工程编译后的bin文件

一、导出bin文件的方法一 1.通过IDE直接导出,选择 项目 --> 导出已编译的二进制文件,会在工程下生成 build 文件夹,里面有导出的bin文件。 一、导出bin文件的方法二 通过临时文件,找到生成的bin文件。 临时文件的位置&#…

MES系统怎么解决车间生产调度难的问题?

MES系统三个层次 1、MES决定了生产什么,何时生产,也就是说它使公司保证按照订单规定日期交付准确的产品; 2、MES决定谁通过什么方式(流程)生产,即通过优化资源配置,最有效运用资源; …

1500㎡全新展厅升级 无锡冠珠瓷砖旗舰店举行盛大开业典礼

3月23日,无锡冠珠旗舰店重装升级,举行盛大的开业典礼!截止到当天18时,本次开业活动共计成交近300单,收款超300万。新明珠集团董事兼常务副总裁梁旺娟、新明珠集团副总裁兼营销管理中心总经理邓勇、新明珠集团副总经理兼…

翻译 《The Old New Thing》 - Why is a registry file called a “hive“?

Why is a registry file called a “hive“?https://devblogs.microsoft.com/oldnewthing/20030808-00/?p42943 为什么注册表文件被称为‘蜂巢’? Raymond Chen 2003年8月8日 分享一个没用的知识: 话说有一位 Windows NT 的开发者十分讨厌蜜蜂。于是&a…

华清远见STM32U5开发板助力2024嵌入式大赛ST赛道智能可穿戴设备及IOT选题项目开发

第七届(2024)全国大学生嵌入式芯片与系统设计竞赛(以下简称“大赛”)已经拉开帷幕,大赛的报名热潮正席卷而来,高校电子电气类相关专业(电子、信息、计算机、自动化、电气、仪科等)全…

【chemistry 5】糖化学、脂化学和糖代谢

🌞欢迎来到生物化学的世界 🌈博客主页:卿云阁 💌欢迎关注🎉点赞👍收藏⭐️留言📝 🌟本文由卿云阁原创! 📆首发时间:🌹2024年3月29日&…

13 Games101 - 笔记 - 光线追踪(Whitted-Style光线追踪原理详解及实现细节)

13 光线追踪(Whitted-Style光线追踪原理详解及实现细节) 引入光线追踪的原因 光栅化的缺点:不能很好的处理全局光照。(因为Blinn-Phong这种局部模型无法处理全局效果!) 光栅化:快 real-time 质量低光线追…

LeetCode:718最长重复子数组 C语言

718. 最长重复子数组 提示 给两个整数数组 nums1 和 nums2 ,返回 两个数组中 公共的 、长度最长的子数组的长度 。 示例 1: 输入:nums1 [1,2,3,2,1], nums2 [3,2,1,4,7] 输出:3 解释:长度最长的公共子数组是 [3,…

【产品经理】华为IPD需求管理全思路分享!

作为一名产品经理,会在日常工作中接收到各种需求,而解决需求要提供对应的解决方案。本篇文章以华为的IPD需求管理流程为例,探讨其需求管理思路,帮助产品岗位的你快速做好需求管理并解决方案。 一、理清什么是产品需求 说到这个话…