Transformer详解:从放弃到入门(完结)

news2024/12/27 13:59:50

  前几篇文章中,我们已经拆开并讲解了Transformer中的各个组件。现在我们尝试使用这些方法实现Transformer的编码器。

相关文章:
Transformer详解:从放弃到入门(一)
Transformer详解:从放弃到入门(二)
Transformer详解:从放弃到入门(三)

在这里插入图片描述  如图所示,编码器(Encoder)由N个编码器块(Encoder Block)堆叠而成,我们依次实现。

class EncoderBlock(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        d_ff: int,
        dropout: float,
        norm_first: bool = False,
    ) -> None:
        super().__init__()

        self.norm_first = norm_first

        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.norm1 = LayerNorm(d_model)

        self.ff = PositionWiseFeedForward(d_model, d_ff, dropout)
        self.norm2 = LayerNorm(d_model)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    # self attention sub layer
    def _sa_sub_layer(
        self, x: Tensor, attn_mask: Tensor, keep_attentions: bool
    ) -> Tensor:
        x = self.attention(x, x, x, attn_mask, keep_attentions)
        return self.dropout1(x)

    def _ff_sub_layer(self, x: Tensor) -> Tensor:
        x = self.ff(x)
        return self.dropout2(x)

    def forward(
        self, src: Tensor, src_mask: Tensor = None, keep_attentions: bool = False
    ) -> Tuple[Tensor, Tensor]:
        # pass througth multi-head attention
        # src (batch_size, seq_length, d_model)
        # attn_score (batch_size, n_heads, seq_length, k_length)
        x = src
        if self.norm_first:
            x = x + self._sa_sub_layer(self.norm1(x), src_mask, keep_attentions)
            x = x + self._ff_sub_layer(self.norm2(x))
        else:
            x = self.norm1(x + self._sa_sub_layer(x, src_mask, keep_attentions))
            x = self.norm2(x + self._ff_sub_layer(x))

        return x

  需要注意的是,层归一化的位置通过参数norm_first控制,默认norm_first=False,这种实现方式称为Post-LN,是Transformer的默认做法。但这种方式很难从零开始训练,把层归一化放到残差块之间,接近输出层的参数的梯度往往较大。然后在那些梯度上使用较大的学习率会使得训练不稳定。通常需要用到学习率预热(warm-up)技巧,在训练开始时学习率需要设成一个极小的值,但是一旦训练好之后的效果要优于Pre-LN的方式。而如果采用norm_first=True的方式,被称为Pre-LN,它的区别在于对于子层(*_sub_layer)的输入先进行层归一化,再输入到子层中。最后进行残差连接。
在这里插入图片描述  即实际上由上图左变成了图右,注意最后在每个Encoder或Decoder的输出上再接了一个层归一化。
  有了编码器块,我们再来实现编码器。

class Encoder(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_layers: int,
        n_heads: int,
        d_ff: int,
        dropout: float = 0.1,
        norm_first: bool = False,
    ) -> None:
        super().__init__()
        # stack n_layers encoder blocks
        self.layers = nn.ModuleList(
            [
                EncoderBlock(d_model, n_heads, d_ff, dropout, norm_first)
                for _ in range(n_layers)
            ]
        )

        self.norm = LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(
        self, src: Tensor, src_mask: Tensor = None, keep_attentions: bool = False
    ) -> Tensor:
        x = src
        # pass through each layer
        for layer in self.layers:
            x = layer(x, src_mask, keep_attentions)

        return self.norm(x)

  这里要注意的是,最后对编码器和输出进行一次层归一化。至此,我们的编码器完成了,在其forward()中src是词嵌入加上位置编码,那么src_mask是什么?它是用来指示非填充标记的。我们知道,对于文本序列批数据,一个批次内序列长短不一,因此需要以一个指定的最长序列进行填充,而我们的注意力不需要在这些填充标记上进行。
  创建src_mask很简单,假设输入是填充后的批数据:

def make_src_mask(src: Tensor, pad_idx: int = 0) -> Tensor:
    src_mask = (src != pad_idx).unsqueeze(1).unsqueeze(2)
    return src_mask

  输出维度变成(batch_size, 1, 1, seq_length)为了与缩放点积注意力分数适配维度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1657308.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【爬虫基础1.1课】——requests模块

目录索引 requests模块的作用:实例引入: 特殊情况:锦囊1:锦囊2: 这一个栏目,我会给出我从零开始学习爬虫的全过程。感兴趣的小伙伴可以关注一波,用于复习和新学都是不错的选择。 那么废话不多说&#xff0c…

AJAX知识点(前后端交互技术)

原生AJAX AJAX全称为Asynchronous JavaScript And XML,就是异步的JS和XML,通过AJAX可以在浏览器中向服务器发送异步请求,最大的优势:无需刷新就可获取数据。 AJAX不是新的编程语言,而是一种将现有的标准组合在一起使用的新方式 …

中小学校活动向媒体投稿报道宣传有哪些好方法

作为一所中小学校的教师,我肩负着向外界展示学校风采、宣传校园文化活动的重要使命。起初,每当学校举办特色活动或取得教学成果时,我都会满怀热情地撰写新闻稿,希望通过媒体的平台让更多人了解我们的故事。然而,理想丰满,现实骨感,我很快发现,通过电子邮件向媒体投稿的过程充满…

ICode国际青少年编程竞赛- Python-1级训练场-变量的计算

ICode国际青少年编程竞赛- Python-1级训练场-变量的计算 1、 a 2 for i in range(4):Spaceship.step(a-1)Dev.step(a)Dev.step(-a)a a 12、 a 2 for i in range(4):Dev.step(2 a)Dev.step(-a)Dev.turnRight()a a 13、 y 4 for i in range(3):Dev.step(y)Dev.turnRigh…

计算方法实验9:Romberg积分求解速度、位移

任务 输出质点的轨迹 ( x ( t ) , y ( t ) ) , t ∈ { 0.1 , 0.2 , 0.3 , . . . , 10 } (x(t), y(t)), t\in \{0.1, 0.2, 0.3, ..., 10\} (x(t),y(t)),t∈{0.1,0.2,0.3,...,10},并在二维平面中画出该轨迹.请比较M分别取4, 8, 12, 16, 20 时,Romberg积分达…

去除视频背景音乐或人物声音的4种方法,建议收藏

你是否曾想移除视频中令人分心的声音呢?对于需要裁剪声音或去除背景噪音的视频来说,消音是一种非常有用的技能。那么,视频怎么消除声音?看看下文就知道了。 方法一:使用 智优影 去除视频中的音频 在线转换工具不仅支持…

Python轻量级Web框架Flask(13)—— Flask个人博客项目

0、前言: ★这部分内容是基于之前Flask学习内容的一个实战项目梳理内容,没有可以直接抄下来跑的代码,是学习了之前Flask基础知识之后,再来看这部分内容,就会对Flask项目开发流程有更清楚的认知,对一些开发细节可以进一步的学习。项目功能,通过Flask制作个人博客。项目架…

又一个限时免费生成图片的AI平台

网址 https://jimeng.jianying.com/ai-tool/image/generate 抖音官方的文升图,用抖音登录就可以,每天送60积分,目前看文生图好像是限时免费。 随手试了一下,速度很快,质量也还可以,背靠大厂,…

福建 | 福建铭发用行动诠释“敢为天下先”的泉州精神

福建铭发 泉州TOP级企业 在福建,提到混凝土搅拌站,铭发是绕不开的一个存在。 他们是当地最早一批建成的商砼企业,也是如今发展规模最大的TOP级企业。 从2007年建站至今,近15年的发展,他们形成了一套铭发特色的企业经…

【qt】容器的用法

容器目录 一.QVertor1.应用场景2.增加数据3.删除数据4.修改数据5.查询数据6.是否包含7.数据个数8.交换数据9.移动数据10.嵌套使用 二.QList1.应用场景2.QStringList 三.QLinkedList1.应用场景2.特殊点3.用迭代器来变量 四.QStack1.应用场景2.基本用法 五.QQueue1.应用场景2.基本…

WPF之改变任务栏图标及预览

1&#xff0c;略缩图添加略缩按钮。 <Window.TaskbarItemInfo><TaskbarItemInfo x:Name"taskInfo" ProgressState"None" ProgressValue"0.6" ><TaskbarItemInfo.ThumbButtonInfos><ThumbButtonInfo x:Name"btiPlay&q…

分享10个高质量宝藏网站~

分享一波高质量宝藏网站~ 这10个宝藏网站&#xff0c;个个都好用到爆&#xff0c;娱乐、办公、学习都能在这里找到&#xff01; 1、Z-Library https://zh.zlibrary-be.se/ 世界最大的免费电子书下载网站&#xff01;电子书资源超千万&#xff0c;不过这个网站不太稳定&#…

[Kubernetes] KubeKey 部署 K8s v1.28.8

文章目录 1.K8s 部署方式2.操作系统基础配置3.安装部署 K8s4.验证 K8s 集群5.部署测试资源 1.K8s 部署方式 kubeadm: kubekey, sealos, kubespray二进制: kubeaszrancher 2.操作系统基础配置 主机名内网IP外网IPmaster192.168.66.2139.198.9.7node1192.168.66.3139.198.40.17…

Java_File

介绍&#xff1a; File对象表示路径&#xff0c;可以是文件&#xff0c;也可以是文件夹。这个路径可以是存在的&#xff0c;也可以是不存在的&#xff0c;带盘符的路径是绝对路径&#xff0c;不带盘符的路径是相对路径&#xff0c;相对路径默认到当前项目下去找 构造方法&…

MGRE 实验

需求&#xff1a;1、R2为ISP&#xff0c;其上只能配置IP地址。 2、R1-R2之间为HDLC封装 3、R2-R3之间为ppp封装&#xff0c;pap认证&#xff0c;R2为主认证方。 4、R2-R4之间为ppp封装&#xff0c;chap认证&#xff0c;R2为主认证方。 5、R1、R2、R3构建MGRE环境&#xff0…

汽车之家,如何在“以旧换新”浪潮中大展拳脚?

北京车展刚刚落幕&#xff0c;两重利好正主导汽车市场持续升温&#xff1a;新能源渗透率首破50%&#xff0c;以及以旧换新详细政策进入落地期。 图源&#xff1a;中国政府网 在政策的有力指引下&#xff0c;汽车产业链的各个环节正经历着一场深刻的“连锁反应”。在以旧换新的…

STM32——GPIO输出(点亮第一个LED灯)

代码示例&#xff1a; #include "stm32f10x.h" // Device headerint main() {RCC_APB2PeriphClockCmd(RCC_APB2Periph_GPIOA,ENABLE);//开启时钟GPIO_InitTypeDef GPIO_InitStructure;GPIO_InitStructure.GPIO_Mode GPIO_Mode_Out_PP;GPIO_InitSt…

【JAVA基础之Stream流】掌握Stream流

&#x1f525;作者主页&#xff1a;小林同学的学习笔录 &#x1f525;mysql专栏&#xff1a;小林同学的专栏 目录 1.Stream流 1.1 概述 2.2 常见方法 2.2.1 实现步骤 2.2.2 获取Stream的方式 2.2.3 中间方法 2.2.4 终结方法 1.Stream流 1.1 概述 Stream流可以让你通过链…

Spring AOP(2)

目录 Spring AOP详解 PointCut 切面优先级Order 切点表达式 execution表达式 切点表达式示例 annotation 自定义注解MyAspect 切面类 添加自定义注解 Spring AOP详解 PointCut 上面代码存在一个问题, 就是对于excution(* com.example.demo.controller.*.*(..))的大量重…

Cmake编译源代码生成库文件以及使用

在项目实战中&#xff0c;通过模块化设计能够使整个工程更加简洁明了。简单的示例如下&#xff1a; 1、项目结构 project_folder/├── CMakeLists.txt├── src/│ ├── my_library.cpp│ └── my_library.h└── app/└── main.cpp2、CMakeList文件 # CMake …