【变形金刚03】使用 Pytorch 开始构建transformer

news2024/11/25 10:29:05

 

一、说明

        在本教程中,我们将使用 PyTorch 从头开始构建一个基本的转换器模型。Vaswani等人在论文“注意力是你所需要的一切”中引入的Transformer模型是一种深度学习架构,专为序列到序列任务而设计,例如机器翻译和文本摘要。它基于自我注意机制,已成为许多最先进的自然语言处理模型的基础,如GPT和BERT。

二、准备活动

        若要生成转换器模型,我们将按照以下步骤操作:

  1. 导入必要的库和模块
  2. 定义基本构建块:多头注意力、位置前馈网络、位置编码
  3. 构建编码器和解码器层
  4. 组合编码器和解码器层以创建完整的转换器模型
  5. 准备示例数据
  6. 训练模型

        让我们从导入必要的库和模块开始。

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy

现在,我们将定义转换器模型的基本构建基块。

三、多头注意力

图2.多头注意力(来源:作者创建的图像)

        多头注意力机制计算序列中每对位置之间的注意力。它由多个“注意头”组成,用于捕获输入序列的不同方面。

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_probs = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_probs, V)
        return output
        
    def split_heads(self, x):
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
        
    def combine_heads(self, x):
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
        
    def forward(self, Q, K, V, mask=None):
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        output = self.W_o(self.combine_heads(attn_output))
        return output

        MultiHeadAttention 代码使用输入参数和线性变换层初始化模块。它计算注意力分数,将输入张量重塑为多个头部,并将所有头部的注意力输出组合在一起。前向方法计算多头自我注意,允许模型专注于输入序列的某些不同方面。

四、位置前馈网络

class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

PositionWiseFeedForward 类扩展了 PyTorch 的 nn。模块并实现按位置的前馈网络。该类使用两个线性转换层和一个 ReLU 激活函数进行初始化。forward 方法按顺序应用这些转换和激活函数来计算输出。此过程使模型能够在进行预测时考虑输入元素的位置。

五、位置编码

        位置编码用于注入输入序列中每个令牌的位置信息。它使用不同频率的正弦和余弦函数来生成位置编码。

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()
        
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        self.register_buffer('pe', pe.unsqueeze(0))
        
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

PositionalEncoding 类使用 d_model 和 max_seq_length 输入参数进行初始化,从而创建一个张量来存储位置编码值。该类根据比例因子div_term分别计算偶数和奇数指数的正弦和余弦值。前向方法通过将存储的位置编码值添加到输入张量中来计算位置编码,从而使模型能够捕获输入序列的位置信息。

现在,我们将构建编码器层和解码器层。

六、编码器层

图3.变压器网络的编码器部分(来源:图片来自原文)

编码器层由多头注意力层、位置前馈层和两个层归一化层组成。

class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask):
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x

类使用输入参数和组件进行初始化,包括一个多头注意模块、一个 PositionWiseFeedForward 模块、两个层规范化模块和一个 dropout 层。前向方法通过应用自注意、将注意力输出添加到输入张量并规范化结果来计算编码器层输出。然后,它计算按位置的前馈输出,将其与归一化的自我注意输出相结合,并在返回处理后的张量之前对最终结果进行归一化。

七、解码器层

图4.变压器网络的解码器部分(Souce:图片来自原始论文)

解码器层由两个多头注意力层、一个位置前馈层和三个层归一化层组成。

class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_output, src_mask, tgt_mask):
        attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x

解码器层使用输入参数和组件进行初始化,例如用于屏蔽自我注意和交叉注意力的多头注意模块、PositionWiseFeedForward 模块、三层归一化模块和辍学层。

转发方法通过执行以下步骤来计算解码器层输出:

  1. 计算掩蔽的自我注意输出并将其添加到输入张量中,然后进行 dropout 和层归一化。
  2. 计算解码器和编码器输出之间的交叉注意力输出,并将其添加到规范化的掩码自注意力输出中,然后进行 dropout 和层规范化。
  3. 计算按位置的前馈输出,并将其与归一化交叉注意力输出相结合,然后是压差和层归一化。
  4. 返回已处理的张量。

这些操作使解码器能够根据输入和编码器输出生成目标序列。

现在,让我们组合编码器和解码器层来创建完整的转换器模型。

八、变压器型号

图5.The Transformer Network(来源:图片来源于原文)

将它们全部合并在一起:

class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        super(Transformer, self).__init__()
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)

        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])

        self.fc = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def generate_mask(self, src, tgt):
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
        seq_length = tgt.size(1)
        nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool()
        tgt_mask = tgt_mask & nopeak_mask
        return src_mask, tgt_mask

    def forward(self, src, tgt):
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))

        enc_output = src_embedded
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)

        dec_output = tgt_embedded
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

        output = self.fc(dec_output)
        return output

类组合了以前定义的模块以创建完整的转换器模型。在初始化期间,Transformer 模块设置输入参数并初始化各种组件,包括源序列和目标序列的嵌入层、位置编码模块、用于创建堆叠层的编码器层和解码器层模块、用于投影解码器输出的线性层以及 dropout 层。

generate_mask 方法为源序列和目标序列创建二进制掩码,以忽略填充标记并防止解码器处理将来的令牌。前向方法通过以下步骤计算转换器模型的输出:

  1. 使用 generate_mask 方法生成源掩码和目标掩码。
  2. 计算源和目标嵌入,并应用位置编码和丢弃。
  3. 通过编码器层处理源序列,更新enc_output张量。
  4. 通过解码器层处理目标序列,使用enc_output和掩码,并更新dec_output张量。
  5. 将线性投影层应用于解码器输出,获取输出对数。

这些步骤使转换器模型能够处理输入序列,并根据其组件的组合功能生成输出序列。

九、准备样本数据

        在此示例中,我们将创建一个用于演示目的的玩具数据集。实际上,您将使用更大的数据集,预处理文本,并为源语言和目标语言创建词汇映射。

src_vocab_size = 5000
tgt_vocab_size = 5000
d_model = 512
num_heads = 8
num_layers = 6
d_ff = 2048
max_seq_length = 100
dropout = 0.1

transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)

# Generate random sample data
src_data = torch.randint(1, src_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)
tgt_data = torch.randint(1, tgt_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)

十、训练模型

        现在,我们将使用示例数据训练模型。在实践中,您将使用更大的数据集并将其拆分为训练集和验证集。

criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

transformer.train()

for epoch in range(100):
    optimizer.zero_grad()
    output = transformer(src_data, tgt_data[:, :-1])
    loss = criterion(output.contiguous().view(-1, tgt_vocab_size), tgt_data[:, 1:].contiguous().view(-1))
    loss.backward()
    optimizer.step()
    print(f"Epoch: {epoch+1}, Loss: {loss.item()}")

        我们可以使用这种方式在 Pytorch 中从头开始构建一个简单的转换器。所有大型语言模型都使用这些转换器编码器或解码器块进行训练。因此,了解启动这一切的网络非常重要。希望本文能帮助所有希望深入了解LLM的人。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/868090.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【传输层】Tcp协议的原理(二)

文章目录 一、TCP协议原理(二)总结 一、TCP协议原理(二) 1.解决TIME_WAIT状态引起的bind失败的方法 我们之前实现tcp服务器的时候发现,服务器经常有时候断开可以立即重启,有时候断开必须换端口号才能重新…

qt事件系统源码-----定时器

qt定时器的使用一般有以下几种方式: 1、直接使用QTimer对象,绑定定时器的timeout信号; 2、使用QTimer的静态方法singleshot方法,产生一个一次性的定时事件 3、在QObject子类中,调用startTimer方法,产生定…

Vue.js从入门到精通:软件开发视频大讲堂

💂 个人网站:【工具大全】【游戏大全】【神级源码资源网】🤟 前端学习课程:👉【28个案例趣学前端】【400个JS面试题】💅 寻找学习交流、摸鱼划水的小伙伴,请点击【摸鱼学习交流群】 前言 随着Web应用程序的…

nodejs+vue+elementui+express旅游出行指南网站_655ms

本文主要介绍了一种基于windows平台实现的旅游出行指南。该系统为用户找到景点信息和酒店信息提供了更安全、更高效、更便捷的途径。本系统有两个角色:管理员和用户,要求具备以下功能: (1)用户可以浏览主页了解旅游出行…

2023牛客暑期多校训练营8-C Clamped Sequence II

2023牛客暑期多校训练营8-C Clamped Sequence II https://ac.nowcoder.com/acm/contest/57362/C 文章目录 2023牛客暑期多校训练营8-C Clamped Sequence II题意解题思路代码 题意 解题思路 先考虑不加紧密度的情况,要支持单点修改,整体查询&#xff0…

阿里云服务器安装AMH面板建站教程

本文阿里云百科分享使用阿里云服务器安装AMH面板建站教程,AMH是一套通过Web控制和管理Linux服务器以及虚拟主机的管理系统。您可以使用云服务器ECS安装AMH来搭建PHP网站。本篇教程分别介绍如何在Linux系统实例中部署AMH并快速搭建PHP网站。 目录 前提条件 手动部…

Electron 应用实现截图并编辑功能

Electron 应用实现截图并编辑功能 Electron 应用如何实现截屏功能,有两种思路,作为一个框架是否可以通过框架实现截屏,另一种就是 javaScript 结合 html 中画布功能实现截屏。 在初步思考之后,本文优先探索使用 Electron 实现截屏…

CCF-CSP 29次 第三题【202303-3 LDAP】(多个STL+递归)

计算机软件能力认证考试系统 #include <iostream> #include <cstring> #include <algorithm> #include <vector> #include <unordered_map> #include <string>using namespace std;typedef long long LL;const int N 2510, M 510;int n…

Qt应用开发(基础篇)——拆分器窗口 QSplitter QSplitterHandle

一、前言 QSplitter继承于QFrame&#xff0c;QFrame继承于QWidget&#xff0c;是Qt的一个部件容器工具类。 框架类QFrame介绍 QSplitter拆分器&#xff0c;用户通过拖动子部件之间的边界来控制子部件的大小&#xff0c;在应用开发中数据分模块展示、图片展示等场景下使用。 二、…

Nuxt.js--》解锁 Nuxt 项目的潜力:从配置开始,迈向成功

博主今天开设Nuxt.js专栏&#xff0c;带您深入探索 Nuxt.js 的精髓&#xff0c;学习如何利用其强大功能构建出色的前端应用程序。我们将探讨其核心特点、灵活的路由系统、优化技巧以及常见问题的解决方案。无论您是想了解 Nuxt.js 的基础知识&#xff0c;还是希望掌握进阶技巧&…

FastDFS安装教程

FastDFS安装 软件下载 需要的软件&#xff1a;fastdfs-6.0.4、libfastcommon-1.0.42、fastdfs-nginx-module-1.22.tar.gz 下载地址 安装 fastdfs是使用c语言写的&#xff0c;需要先配置c语言环境。 yum install -y gcc gcc-cyum install libevent安装libfastcommon函数库…

Leetcode-每日一题【剑指 Offer 20. 表示数值的字符串】

题目 请实现一个函数用来判断字符串是否表示数值&#xff08;包括整数和小数&#xff09;。 数值&#xff08;按顺序&#xff09;可以分成以下几个部分&#xff1a; 若干空格一个 小数 或者 整数&#xff08;可选&#xff09;一个 e 或 E &#xff0c;后面跟着一个 整数若干空…

Win7累积补丁更新包_UpdatePack7R2-23.8.10

UpdatePack7是最新的Win7补丁累积更新包&#xff0c;Windows 7更新补丁安装包&#xff0c;Win7累积更新离线安装包包括所有关键更新和安全更新及Internet Explorer所有版本的更新&#xff0c;此外还集成了NVMe驱动和USB3.0驱动&#xff0c;使用它还可以将累积更新封装到系统内&…

通过BitMap实现签到

针对黑马点评。 bitmap签到 在传统的签到系统中的数据库的表一般都采取直接存储的形式&#xff0c;类似于一种记录表&#xff0c;但是如果用户的数量特别大&#xff0c;签到上几个月之后&#xff0c;这种表的数据量特别大&#xff0c;同时&#xff0c;存储的数据也会占用很多…

大汇总!各省杰青优青名单已出炉

【SciencePub学术】杰青&#xff0c;是国家杰出青年基金项目资助获得者的简称&#xff0c;与科技奖励计划类似&#xff0c;是我国重要的人才计划之一。一所学校的杰青数量&#xff0c;代表学校未来的学术发展潜力和在同类高校的学术地位&#xff0c;每所大学都非常看重。今年部…

Lombok的使用及注解含义

文章目录 一、简介二、如何使用2.1、在IDEA中安装Lombok插件2.2、添加maven依赖 三、常用注解3.1、Getter / Setter3.2、ToString3.3、NoArgsConstructor / AllArgsConstructor3.4、EqualsAndHashCode3.5、Data3.6、Value3.7、Accessors3.7.1、Accessors(chain true)3.7.2、Ac…

C++笔记之if(指针)的含义

C笔记之if(指针)的含义 code review! 文章目录 C笔记之if(指针)的含义例1例2 例1 例2

突然断电CAD图纸没保存怎么恢复?

CAD图纸绘制时&#xff0c;有时会遇到一些意外情况&#xff0c;比如突然断电、电脑意外关机或者软件异常退出&#xff0c;但是图纸还没保存&#xff0c;这该怎么办&#xff1f;这对于设计师来说&#xff0c;简直要崩溃&#xff0c;不仅白干一天&#xff0c;还得加班赶进度&…

【果树农药喷洒机器人】Part6:基于深度相机与分割掩膜的果树冠层体积探测方法

文章目录 一、引言二、树冠体积测量对比方法2.1冠层体积人工测量法2.2冠层体积拟合测量法 三、基于深度相机与分割掩膜探测树冠体积方法3.1像素值与深度值的转换3.2树冠体积视觉探测法3.3实验分析 总结 一、引言 果树靶标探测是实现农药精准喷施的关键环节&#xff0c;本章以果…

2. Hello World

Hello World 我们将用 Java 编写两个程序。发送单个消息的生产者和接收消息并打印 出来的消费者。我们将介绍 Java API 中的一些细节。 在下图中&#xff0c;“ P”是我们的生产者&#xff0c;“ C”是我们的消费者。中间的框是一个队列-RabbitMQ 代 表使用者保留的消息缓冲区 …