Implementing a Transformer from Scratch in PyTorch


Table of Contents

    • Self-Attention
    • Transformer Block
    • Encoder
    • Decoder Block
    • Decoder
    • The Full Transformer
    • References
    • Complete Code (Runnable)

Self-Attention

Formula

The scaled dot-product attention used throughout is

Attention(Q, K, V) = softmax(Q·Kᵀ / √d_k) · V

where Q, K, V are the query, key, and value matrices and d_k is the dimension of the key vectors.
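
A minimal single-head sketch of this formula (not from the original article; it assumes plain tensors of shape (N, seq_len, d_k)):

import torch

def scaled_dot_product_attention(q, k, v, mask=None):
    # q, k, v: (N, seq_len, d_k); mask broadcastable to (N, seq_len, seq_len)
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / d_k ** 0.5  # (N, seq_len, seq_len)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-1e20"))
    weights = torch.softmax(scores, dim=-1)        # attention weights sum to 1 over the keys
    return weights @ v                             # (N, seq_len, d_k)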

Implementation

import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (self.head_dim * heads == embed_size), "Embed size needs to be divisible by heads"
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads*self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0] # the number of training examples
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split embedding into self.heads pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys]) # matrix multiplication via Einstein summation notation
        # queries shape: (N, query_len, heads, head_dim)
        # keys shape: (N, key_len, heads, head_dim)
        # energy shape: (N, heads, query_len, key_len)

        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))
            # set positions where mask == 0 to a large negative value so they vanish after softmax

        attention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)
        # Note: this follows the referenced implementation and scales by sqrt(embed_size);
        # the original paper scales by sqrt(head_dim).
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads*self.head_dim
        ) # matrix multiplication via Einstein summation (einsum)
        # attention shape: (N, heads, query_len, key_len)
        # values shape: (N, value_len, heads, head_dim)
        # after einsum: (N, query_len, heads, head_dim), then flatten the last two dimensions

        out = self.fc_out(out)
        return out
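
A quick shape check with hypothetical sizes (not part of the original code):

attn = SelfAttention(embed_size=256, heads=8)
x = torch.rand(2, 10, 256)      # (N, seq_len, embed_size)
out = attn(x, x, x, mask=None)  # values, keys, query are all x for self-attention
print(out.shape)                # torch.Size([2, 10, 256])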

Transformer Block

We define the Transformer block as a multi-head self-attention layer followed by Add & Norm, a position-wise feed-forward network, and another Add & Norm. This block appears in both the encoder and the decoder.

Implementation

class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion*embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion*embed_size, embed_size)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)

        # residual connection around attention, then layer norm and dropout
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        # residual connection around the feed-forward network
        out = self.dropout(self.norm2(forward + x))
        return out
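
A minimal usage sketch with made-up sizes:

block = TransformerBlock(embed_size=256, heads=8, dropout=0.1, forward_expansion=4)
x = torch.rand(2, 10, 256)
out = block(x, x, x, mask=None)  # value, key, query are all x for self-attention
print(out.shape)                 # torch.Size([2, 10, 256])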

Encoder

The encoder structure: the inputs go through the input embedding and positional encoding, and are then fed through a stack of Transformer blocks. Note that this implementation uses a learned position embedding (nn.Embedding) rather than the sinusoidal positional encoding of the original paper.

Implementation

class Encoder(nn.Module):
    def __init__(self, 
                 src_vocab_size,
                 embed_size,
                 num_layers,
                 heads,
                 device,
                 forward_expansion,
                 dropout,
                 max_length
                 ):
        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion
                )
            for _ in range(num_layers)]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        for layer in self.layers:
            out = layer(out, out, out, mask)

        return out
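
A small sanity check with hypothetical sizes (token ids below src_vocab_size, running on CPU):

enc = Encoder(src_vocab_size=10, embed_size=256, num_layers=2, heads=8,
              device="cpu", forward_expansion=4, dropout=0.1, max_length=100)
src = torch.randint(0, 10, (2, 9))  # (N, src_len) of token ids
out = enc(src, mask=None)
print(out.shape)                    # torch.Size([2, 9, 256])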

Decoder Block

The decoder block is a masked self-attention layer with Add & Norm, followed by a Transformer block whose keys and values come from the encoder output (cross-attention).

Implementation

class DecoderBlock(nn.Module):
    def __init__(self, embed_size, heads, forward_expansion, dropout, device):
        super(DecoderBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm = nn.LayerNorm(embed_size)
        self.transformer_block = TransformerBlock(embed_size, heads, dropout, forward_expansion)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, src_mask, trg_mask):
        # masked self-attention over the target sequence
        attention = self.attention(x, x, x, trg_mask)
        query = self.dropout(self.norm(attention + x))
        # cross-attention: queries from the decoder, keys/values from the encoder output
        out = self.transformer_block(value, key, query, src_mask)
        return out
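
A shape check for the decoder block (hypothetical sizes; enc_out stands in for the encoder output):

dec_block = DecoderBlock(embed_size=256, heads=8, forward_expansion=4, dropout=0.1, device="cpu")
x = torch.rand(2, 7, 256)        # target-side representations
enc_out = torch.rand(2, 9, 256)  # encoder output
out = dec_block(x, enc_out, enc_out, src_mask=None, trg_mask=None)
print(out.shape)                 # torch.Size([2, 7, 256])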

Decoder

Stacking decoder blocks on top of the word embedding and positional embedding, and adding a final linear projection to the target vocabulary, gives the decoder.

Implementation

class Decoder(nn.Module):
    def __init__(self, 
                 trg_vocab_size, 
                 embed_size, 
                 num_layers, 
                 heads, 
                 forward_expansion, 
                 dropout, 
                 device, 
                 max_length):
        super(Decoder, self).__init__()
        self.device = device
        self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
             for _ in range(num_layers)]
        )
        self.fc_out = nn.Linear(embed_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, trg_mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        x = self.dropout((self.word_embedding(x) + self.position_embedding(positions)))

        for layer in self.layers:
            x = layer(x, enc_out, enc_out, src_mask, trg_mask)

        out = self.fc_out(x)
        return out
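
And a shape check for the whole decoder (hypothetical sizes):

dec = Decoder(trg_vocab_size=10, embed_size=256, num_layers=2, heads=8,
              forward_expansion=4, dropout=0.1, device="cpu", max_length=100)
trg = torch.randint(0, 10, (2, 7))  # (N, trg_len) of token ids
enc_out = torch.rand(2, 9, 256)     # encoder output
out = dec(trg, enc_out, src_mask=None, trg_mask=None)
print(out.shape)                    # torch.Size([2, 7, 10]) - logits over the target vocabulary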

The Full Transformer

The full model wires the encoder and decoder together; it builds a padding mask for the source sequence and a lower-triangular (causal) mask for the target sequence.

Implementation

class Transformer(nn.Module):
    def __init__(self,
                 src_vocab_size, 
                 trg_vocab_size,
                 src_pad_idx,
                 trg_pad_idx,
                 embed_size=256,
                 num_layers=6,
                 forward_expansion=4,
                 heads=8,
                 dropout=0,
                 device="cuda",
                 max_length=100
                 ):
        super(Transformer, self).__init__()
        self.encoder = Encoder(
            src_vocab_size,
            embed_size,
            num_layers,
            heads,
            device,
            forward_expansion,
            dropout,
            max_length
        )

        self.decoder = Decoder(
            trg_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            device,
            max_length
        )
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_src_mask(self, src):
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        #(N, 1, 1, src_len)
        return src_mask.to(self.device)
    
    def make_trg_mask(self, trg):
        N, trg_len = trg.shape
        # (N, 1, trg_len, trg_len) lower-triangular (causal) mask
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
            N, 1, trg_len, trg_len
        )
        return trg_mask.to(self.device)
    
    def forward(self, src, trg):
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_src = self.encoder(src, src_mask)
        out = self.decoder(trg, enc_src, src_mask, trg_mask)
        return out
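
To see what the two masks look like, here is a toy example on CPU (values chosen arbitrarily, with 0 as the padding index):

model = Transformer(src_vocab_size=10, trg_vocab_size=10, src_pad_idx=0, trg_pad_idx=0, device="cpu")
src = torch.tensor([[1, 5, 6, 0]])      # last token is padding
print(model.make_src_mask(src))         # shape (1, 1, 1, 4); False at the padded position
trg = torch.tensor([[1, 7, 4]])
print(model.make_trg_mask(trg))         # shape (1, 1, 3, 3); lower-triangular causal mask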
    

References

[1] https://www.youtube.com/watch?v=U0s0f995w14
[2] https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/more_advanced/transformer_from_scratch/transformer_from_scratch.py
[3] Vaswani et al., "Attention Is All You Need", https://arxiv.org/abs/1706.03762
[4] https://www.youtube.com/watch?v=pkVwUVEHmfI

Complete Code (Runnable)

import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (self.head_dim * heads == embed_size), "Embed size needs to be divisible by heads"
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads*self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0] # the number of training examples
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split embedding into self.heads pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        # queries shape: (N, query_len, heads, head_dim)
        # keys shape: (N, key_len, heads, head_dim)
        # energy shape: (N, heads, query_len, key_len)

        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))
            # set positions where mask == 0 to a large negative value so they vanish after softmax

        attention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)
        # Note: this follows the referenced implementation and scales by sqrt(embed_size);
        # the original paper scales by sqrt(head_dim).
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads*self.head_dim
        )
        # attention shape: (N, heads, query_len, key_len)
        # values shape: (N, value_len, heads, head_dim)
        # after einsum: (N, query_len, heads, head_dim), then flatten the last two dimensions

        out = self.fc_out(out)
        return out


class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion*embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion*embed_size, embed_size)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)

        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out
    
class Encoder(nn.Module):
    def __init__(self, 
                 src_vocab_size,
                 embed_size,
                 num_layers,
                 heads,
                 device,
                 forward_expansion,
                 dropout,
                 max_length
                 ):
        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion
                )
            for _ in range(num_layers)]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        for layer in self.layers:
            out = layer(out, out, out, mask)

        return out

class DecoderBlock(nn.Module):
    def __init__(self, embed_size, heads, forward_expansion, dropout, device):
        super(DecoderBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm = nn.LayerNorm(embed_size)
        self.transformer_block = TransformerBlock(embed_size, heads, dropout, forward_expansion)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, src_mask, trg_mask):
        attention = self.attention(x, x, x, trg_mask)
        query = self.dropout(self.norm(attention + x))
        out = self.transformer_block(value, key, query, src_mask)
        return out
    
class Decoder(nn.Module):
    def __init__(self, 
                 trg_vocab_size, 
                 embed_size, 
                 num_layers, 
                 heads, 
                 forward_expansion, 
                 dropout, 
                 device, 
                 max_length):
        super(Decoder, self).__init__()
        self.device = device
        self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
             for _ in range(num_layers)]
        )
        self.fc_out = nn.Linear(embed_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, trg_mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        x = self.dropout((self.word_embedding(x) + self.position_embedding(positions)))

        for layer in self.layers:
            x = layer(x, enc_out, enc_out, src_mask, trg_mask)

        out = self.fc_out(x)
        return out

class Transformer(nn.Module):
    def __init__(self,
                 src_vocab_size, 
                 trg_vocab_size,
                 src_pad_idx,
                 trg_pad_idx,
                 embed_size=256,
                 num_layers=6,
                 forward_expansion=4,
                 heads=8,
                 dropout=0,
                 device="cuda",
                 max_length=100
                 ):
        super(Transformer, self).__init__()
        self.encoder = Encoder(
            src_vocab_size,
            embed_size,
            num_layers,
            heads,
            device,
            forward_expansion,
            dropout,
            max_length
        )

        self.decoder = Decoder(
            trg_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            device,
            max_length
        )
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_src_mask(self, src):
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        #(N, 1, 1, src_len)
        return src_mask.to(self.device)
    
    def make_trg_mask(self, trg):
        N, trg_len = trg.shape
        # (N, 1, trg_len, trg_len) lower-triangular (causal) mask
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
            N, 1, trg_len, trg_len
        )
        return trg_mask.to(self.device)
    
    def forward(self, src, trg):
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_src = self.encoder(src, src_mask)
        out = self.decoder(trg, enc_src, src_mask, trg_mask)
        return out
    
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    x = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0], [1, 8, 7, 3, 4, 5, 6, 7, 2]]).to(
        device
    )
    trg = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0], [1, 5, 6, 2, 4, 7, 6, 2]]).to(device)

    src_pad_idx = 0
    trg_pad_idx = 0
    src_vocab_size = 10
    trg_vocab_size = 10
    model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device).to(device)
    out = model(x, trg[:, :-1])
    print(out.shape)  # torch.Size([2, 7, 10]): (N, trg_len - 1, trg_vocab_size)


