用 Pytorch 训练一个 Transformer模型

news2024/11/26 12:46:51

昨天说了一下Transformer架构,今天我们来看看怎么 Pytorch 训练一个Transormer模型,真实训练一个模型是个庞大工程,准备数据、准备硬件等等,我只是做一个简单的实现。因为只是做实验,本地用 CPU 也可以运行。
本文包含以下几部分:

  1. 准备环境。
  2. 然后就是跟据架构来定义每一层,包括Embedding、Position Encoding、多头注意力、 网络层。
  3. 准备Encoder。
  4. 准备Decoder。
  5. 运行 Transformer,包括训练和评估。

安装Pytorch 环境

!pip3 install torch torchvision torchaudio

引入所需工具类库

引入需要的类库,pytorch 是强大的训练框架,深度学习中需要的一些函数和基本功能都已经实现。

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy

Position Embedding

生成位置信息。

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()
        
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        self.register_buffer('pe', pe.unsqueeze(0))
        
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

多头注意力

  1. 初始化model 维度,头数,每个头的维度。
  2. 计算是在 forward 这个方法,主要看这个方法。
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        # Ensure that the model dimension (d_model) is divisible by the number of heads
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        # Initialize dimensions
        self.d_model = d_model # Model's dimension
        self.num_heads = num_heads # Number of attention heads
        self.d_k = d_model // num_heads # Dimension of each head's key, query, and value
        
        # Linear layers for transforming inputs
        self.W_q = nn.Linear(d_model, d_model) # Query transformation
        self.W_k = nn.Linear(d_model, d_model) # Key transformation
        self.W_v = nn.Linear(d_model, d_model) # Value transformation
        self.W_o = nn.Linear(d_model, d_model) # Output transformation
        
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        # Calculate attention scores
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        # Apply mask if provided (useful for preventing attention to certain parts like padding)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        
        # Softmax is applied to obtain attention probabilities
        attn_probs = torch.softmax(attn_scores, dim=-1)
        
        # Multiply by values to obtain the final output
        output = torch.matmul(attn_probs, V)
        return output
        
    def split_heads(self, x):
        # 转换,每一个 head 独立处理
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
        
    def combine_heads(self, x):
        # Combine the multiple heads back to original shape
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
        
    def forward(self, Q, K, V, mask=None):
        # 线性转换并切分
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        
        # 运行计算公式
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        
        # 合并并返回
        output = self.W_o(self.combine_heads(attn_output))
        return output

网络定义

class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

Encoder

在这里插入图片描述
跟据这张图看下面的实现比较直观,初始化了MultiHeadAttention、PositionWiseFeedForward、两个LayerNorm。 forward 方法中 x 是 Encoder 的输入。

class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask):
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x

Decoder

在这里插入图片描述
看代码的方式和 Encoder 类似,比较好理解,2 个MultiHeadAttention、3个 Norm,forward 中cross_attn 把 enc_output作为传入的参数。

class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_output, src_mask, tgt_mask):
        attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x

Transformer

Transformer主类,包括初始化 embedding、position embedding、encoder 和 decoder。forward 方法进行计算。

class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        super(Transformer, self).__init__()
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)

        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])

        self.fc = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def generate_mask(self, src, tgt):
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
        seq_length = tgt.size(1)
        nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool()
        tgt_mask = tgt_mask & nopeak_mask
        return src_mask, tgt_mask

    def forward(self, src, tgt):
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))

        enc_output = src_embedded
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)

        dec_output = tgt_embedded
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

        output = self.fc(dec_output)
        return output

训练

首先准备数据,这里的数据是随机生成,只是做演示。

src_vocab_size = 5000
tgt_vocab_size = 5000
d_model = 512
num_heads = 8
num_layers = 6
d_ff = 2048
max_seq_length = 100
dropout = 0.1

transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)

# Generate random sample data
src_data = torch.randint(1, src_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)
tgt_data = torch.randint(1, tgt_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)

开始训练

criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

transformer.train()

for epoch in range(100):
    optimizer.zero_grad()
    output = transformer(src_data, tgt_data[:, :-1])
    loss = criterion(output.contiguous().view(-1, tgt_vocab_size), tgt_data[:, 1:].contiguous().view(-1))
    loss.backward()
    optimizer.step()
    print(f"Epoch: {epoch+1}, Loss: {loss.item()}")

评估

将模型运行在验证集或者测试集上,这里数据也是随机生成的,只为体验一下完整流程。

transformer.eval()

# Generate random sample validation data
val_src_data = torch.randint(1, src_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)
val_tgt_data = torch.randint(1, tgt_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)

with torch.no_grad():

    val_output = transformer(val_src_data, val_tgt_data[:, :-1])
    val_loss = criterion(val_output.contiguous().view(-1, tgt_vocab_size), val_tgt_data[:, 1:].contiguous().view(-1))
    print(f"Validation Loss: {val_loss.item()}")

如果你对 Pytorch 和神经网络比较熟悉,Transformer整体实现起来并不复杂,如果想我一样对深度学习不太熟悉,理解起来还是有些困难,这里只是大概跑了一下流程,对Transformer训练有一个概念。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1625916.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

python 脚本头(PyCharm+python头部信息、py头部信息、python头信息、py头信息、py文件头部)

文章目录 参考PyCharm设置脚本头头部信息 参考 https://developer.aliyun.com/article/1166544 https://blog.csdn.net/Dontla/article/details/131743495 https://blog.csdn.net/dongyouyuan/article/details/54408413 PyCharm设置脚本头 打开pycharm,点击file–…

深入了解MySQL:从基础到特性,全面解读关系数据库管理系统的历史与应用

文章目录 1. MySQL简介1.1 概述1.2 架构与兼容性1.3 开源与社区支持 2. MySQL的历史2.1 创始与初衷2.2 发展历程2.3 在Oracle的持续发展2.4 开源与商业结合 3. MySQL的核心特性4. MySQL在实际应用中的作用4.1 网站建设与内容管理4.2 商业智能与客户关系管理4.3 企业级应用与云集…

【大数据】分布式数据库HBase

目录 1.概述 1.1.前言 1.2.数据模型 1.3.列式存储的优势 2.实现原理 2.1.region 2.2.LSM树 2.3.完整读写过程 2.4.master的作用 1.概述 1.1.前言 本文式作者大数据系列专栏中的一篇文章,按照专栏来阅读,循序渐进能更好的理解,专栏…

解锁3D技术文档创建新可能,提升跨部门沟通效率

随着制造业的蓬勃发展与数字化转型的持续推进,产品设计与制作部门间的协作与沟通显得愈发关键。然而,传统沟通方式存在的效率低下、信息传递失准等弊端,已成为制约产品设计、制作效率和质量的瓶颈。 与此同时,CAD软件作为产品设计…

Scala 04 —— Scala Puzzle 拓展

Scala 04 —— Scala Puzzle 拓展 文章目录 Scala 04 —— Scala Puzzle 拓展一、占位符二、模式匹配的变量和常量模式三、继承 成员声明的位置结果初始化顺序分析BMember 类BConstructor 类 四、缺省初始值与重载五、Scala的集合操作和集合类型保持一致性第一部分代码解释第二…

海外云服务对比: AWS、GCP、Azure 与 DigitalOcean

云计算市场持续增长,预计到2030年将达到 2432.87 亿美元。在这个庞大的市场中,三家云服务提供商——亚马逊(AWS)、谷歌云平台(GCP)和微软Azure——共占云市场份额的64%。当用户选择云服务提供商来托管他们的…

【树莓派4B】如何点亮树莓派的LED灯

在之前一系列文章中,使用python、行人入侵检测,确没有使用树莓派的硬件。控制引脚进行输出: 如何写python点亮led灯闪烁,我灯接在gpio13,GPIO19,gpio26。我都想闪烁。 你可以使用Python的GPIO库来控制树莓派上的LED灯。首先&…

一年期免费SSL证书申请方法

免费SSL证书的申请已经成为当今互联网安全实践中的重要环节,它不仅有助于保护网站数据传输的隐私性和完整性,还能提升用户信任度,因为现代浏览器会明确标识出未使用HTTPS(即未部署SSL证书)的网站为“不安全”。以下是一…

(windows ssh) windows开启ssh服务,并通过ssh登录该win主机

☆ 问题描述 想要通过ssh访问win主句 ★ 解决方案 安装ssh服务 打开服务 如果这里开不来就“打开服务”,找到下面两个开启服务 然后可以尝试ssh链接,注意,账号密码,账号是这个: 密码是这个 同理,如果…

看企业中很多老师傅都说没前途,该不该放弃嵌入式单片机行业?

在企业中,我们经常会听到很多老师傅感叹嵌入式单片机行业没有前途,这也让不少人陷入了迷茫,不知道该不该放弃这个行业。其实,我发现很多新手在嵌入式和单片机领域都存在一个误区,那就是他们过于专注于工作技能的提升&a…

适合初学者的自然语言处理 (NLP) 综合指南

一、简述 自然语言处理 (NLP) 是人工智能 (AI) 最热门的领域之一,现在主要指大语言模型了。这要归功于人们热衷于能编写故事的文本生成器、欺骗人们的聊天机器人以及产生照片级真实感的文本到图像程序等应用程序。近年来,计算机理解人类语言、编程语言&a…

从零到一,打造高效团队!项目经理必备的5大核心能力

快速晋升项目经理的绝密成功法则,你掌握了吗? 项目经理通常具备一系列关键的能力和素质,并付诸实践。以下是项目经理快速晋升需要掌握的技能: 一、精通核心技能,奠定项目基础 1、专业知识技能:项目经理…

Linux-进程和计划任务管理⭐

目录 一、程序和进程 1.程序 2.进程 3.线程与进程 二、ps查看静态进程信息 1.ps aux 命令 2.ps-静态查看系统进程 3.ps -elf 三、top-查看进程动态信息 四、pgrep查看进程信息 五、pstree-查看进程树 六、控制进程 1.进程启动方式 2.调度启动 3.进程的前后台调…

电商ERP是什么,电商ERP有什么用?||电商API接口

目录 一. 电商ERP是什么? 二. 电商ERP的常见功能 三. 小结 电商ERP接口接入 01 一. 电商ERP是什么? 随着电商经济的快速发展,电商企业面临着机遇和挑战。企业都希望快速拓展市场,通过多个渠道增加销售额。电商ERP系统是一种先进的应用系统&#…

EaseUS RecExperts for Mac/Win:你的专属屏幕录像专家

在信息爆炸的时代,屏幕录像软件已经成为我们工作和生活中的得力助手。无论是教学演示、产品介绍,还是游戏录制、会议记录,一款功能强大的屏幕录像软件都能轻松应对。而EaseUS RecExperts,正是这样一款值得你信赖的屏幕录像专家。 …

网络安全之弱口令与命令爆破(上篇)(技术进阶)

目录 一,什么是弱口令? 二,为什么会产生弱口令呢? 三,字典的生成 四,使用Burpsuite工具弱口令爆破 总结 一,什么是弱口令? 弱口令就是容易被人们所能猜到的密码呗,…

springboot+thymeleaf实现一个简单的监听在线人数功能

功能步骤: 1. 当用户访问登录页面时,Logincontroller的showLoginForm方法被调用,返回登录页面的视图名字。 2. 用户提交表单,调用LoginController的login方法。 3.login方法 4.登录验证通过,home方法会被调用&#xf…

JavaEE >> Spring Boot(1)

Spring Boot 前面已经介绍了 Spring ,是为了简化 Java 程序开发的,而在前面创建的过程中就会发现其实 Spring 还是有点复杂,此时 Spring Boot 就诞生了, Spring Boot 是为了简化 Spring 程序开发的。 Spring Boot 即 Spring 脚手…

transformer 最简单学习3, 训练文本数据输入的形式

1、输入数据中,源数据和目标数据的定义 def get_batch(source,i):用于获取每个批数据合理大小的源数据和目标数据参数source 是通过batchfy 得到的划分batch个 ,的所有数据,并且转置列表示i第几个batchbptt 15 #超参数,一次输入多少个ba…

【opencv 加速推理】如何安装 支持cuda的opencv 包 用于截帧加速

要在支持CUDA的系统上安装OpenCV,您可以使用pip来安装支持CUDA的OpenCV版本。OpenCV支持CUDA加速,但需要安装额外的库,如cuDNN和NVIDIA CUDA Toolkit。以下是一般步骤: 安装NVIDIA CUDA Toolkit: 首先,您需要安装NVID…