Chapter4.3:Implementing a feed forward network with GELU activations

news2025/1/7 22:28:04

4 Implementing a GPT model from Scratch To Generate Text

4.3 Implementing a feed forward network with GELU activations

  • 本节即将实现子模块,用于transformer block(变换器块)的一部分。为此,我们需要从激活函数开始。

    深度学习中,ReLU因其简单和有效而被广泛使用。但在大语言模型中,还使用了GELU和SwiGLU这两种更复杂、平滑的激活函数,它们结合了高斯和sigmoid门控,提升了模型性能,与ReLU的简单分段线性不同。

  • GELU(Hendrycks 和 Gimpel,2016)可以通过多种方式实现;其精确版本定义为 GELU(x)=x⋅Φ(x),其中 Φ(x) 是标准高斯分布的累积分布函数。

    在实践中,通常会实现一种计算成本更低的近似版本:
    GELU ( x ) ≈ 0.5 ⋅ x ⋅ ( 1 + tanh ⁡ [ 2 π ⋅ ( x + 0.044715 ⋅ x 3 ) ] ) \text{GELU}(x) \approx 0.5 \cdot x \cdot \left(1 + \tanh\left[\sqrt{\frac{2}{\pi}} \cdot \left(x + 0.044715 \cdot x^3\right)\right]\right) GELU(x)0.5x(1+tanh[π2 (x+0.044715x3)])
    (原始的 GPT-2 模型也是使用此近似版本进行训练的)。

    class GELU(nn.Module):
        def __init__(self):
            super().__init__()
    
        def forward(self, x):
            return 0.5 * x * (1 + torch.tanh(
                torch.sqrt(torch.tensor(2.0 / torch.pi)) * 
                (x + 0.044715 * torch.pow(x, 3))
            ))
        
    import matplotlib.pyplot as plt
    
    gelu, relu = GELU(), nn.ReLU()
    
    # Some sample data
    x = torch.linspace(-3, 3, 100)
    y_gelu, y_relu = gelu(x), relu(x)
    
    plt.figure(figsize=(8, 3))
    for i, (y, label) in enumerate(zip([y_gelu, y_relu], ["GELU", "ReLU"]), 1):
        plt.subplot(1, 2, i)
        plt.plot(x, y)
        plt.title(f"{label} activation function")
        plt.xlabel("x")
        plt.ylabel(f"{label}(x)")
        plt.grid(True)
    
    plt.tight_layout()
    plt.show()
    

    image-20241229212711708

    如上图所示

    1. ReLU:分段线性函数,正输入直接输出,负输入输出零。

      ReLU的局限性: 零处有尖锐拐角,可能增加优化难度;对负输入输出零,限制了负输入神经元的作用。

    2. GELU:平滑非线性函数,近似 ReLU,对负值具有非零梯度(除了在约-0.75处之外)

      GELU的优势:平滑性带来更好的优化特性;对负输入输出较小的非零值,使负输入神经元仍能贡献学习过程。

  • 接下来让我们使用 GELU 函数来实现小型神经网络模块 FeedForward,稍后我们将在 LLM 的转换器块中使用它:

    class FeedForward(nn.Module):
        def __init__(self, cfg):
            super().__init__()
            self.layers = nn.Sequential(
                nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
                GELU(),
                nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
            )
    
        def forward(self, x):
            return self.layers(x)
    

    上述的前馈模块是包含两个线性层和一个GELU激活函数的小神经网络,在1.24亿参数的GPT模型中,用于处理嵌入大小为768的令牌批次。

    print(GPT_CONFIG_124M["emb_dim"])
    
    """输出"""
    768
    

    初始化一个新的前馈网络(FeedForward)模块,其 token 嵌入大小为 768,并向其输入一个包含 2 个样本且每个样本有 3 个 token 的批次输入。我们可以看到,输出张量的形状与输入张量的形状相同:

    ffn = FeedForward(GPT_CONFIG_124M)
    
    # input shape: [batch_size, num_token, emb_size]
    x = torch.rand(2, 3, 768) 
    out = ffn(x)
    print(out.shape)
    
    """输出"""
    torch.Size([2, 3, 768])
    

    本节的前馈模块对模型学习和泛化至关重要。它通过内部扩展嵌入维度到更高空间,如下图所示,然后应用GELU激活,最后收缩回原维度,以探索更丰富的表示空间。

    前馈神经网络中层输出的扩展和收缩的图示。首先,输入值从 768 个值扩大到 4 倍,达到 3072 个值。然后,第二层将 3072 个值压缩回 768 维表示。

  • 本节至此,现在已经实现了 下图中LLM 的大部分构建块(打勾的部分)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2272876.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

常见加密算法附JAVA代码案例

1、对称加密算法(AES、DES、3DES) 对称加密算法是指加密和解密采用相同的密钥,是可逆的(即可解密)。 AES加密算法是密码学中的高级加密标准,采用的是对称分组密码体制,密钥长度的最少支持为128。AES加密算法是美国联邦政府采用的区块加密标准,这个标准用来替代原先的…

SwiftUI 撸码常见错误 2 例漫谈

概述 在 SwiftUI 日常撸码过程中,头发尚且还算茂盛的小码农们经常会犯这样那样的错误。虽然犯这些错的原因都很简单,但有时想要快速准确的定位它们却并不容易。 况且这些错误还可能在模拟器和 Xcode 预览(Preview)表现的行为不甚…

linux系统(ubuntu,uos等)连接鸿蒙next(mate60)设备

以前在linux上是用adb连接,现在升级 到了鸿蒙next,adb就不好用了。得用Hdc来了,在windows上安装了hisuit用的好好的,但是到了linux(ubuntu2204)下载安装了 下载中心 | 华为开发者联盟-HarmonyOS开发者官网,共建鸿蒙生…

2024年12月HarmonyOS应用开发者高级认证全新题库

注意事项:切记在考试之外的设备上打开题库进行搜索,防止切屏三次考试自动结束,题目是乱序,每次考试,选项的顺序都不同,作者已于2024年12月15日又更新了一波题库,题库正确率99%! 新版…

Anaconda安装(2024最新版)

安装新的anaconda需要卸载干净上一个版本的anaconda,不然可能会在新版本安装过程或者后续使用过程中出错,完全卸载干净anaconda的方法,可以参考我的博客! 第一步:下载anaconda安装包 官网:Anaconda | The O…

docker中使用Volume完成数据共享

情景概述 在一个docker中,部署两个MySQL容器,假如它们的数据都存储在自己容器内部的data目录中。这样的存储方式会有以下问题: 1.无法保证两个MySQL容器中的数据同步。 2.容器删除后,数据就会丢失。 基于以上问题,容…

Mac软件介绍之录屏软件Filmage Screen

软件介绍 Filmage Screen 是一款专业的视频录制和编辑软件,适用于 Mac 系统 可以选择4k 60fps,可以选择录制电脑屏幕,摄像头录制,可以选择区域录制。同时也支持,简单的视频剪辑。 可以同时录制电脑麦克风声音 标准…

【Redis经典面试题七】Redis的事务机制是怎样的?

目录 一、Redis的事务机制 二、什么是Redis的Pipeline?和事务有什么区别? 三、Redis的事务和Lua之间有哪些区别? 3.1 原子性保证 3.2 交互次数 3.3 前后依赖 3.4 流程编排 四、为什么Lua脚本可以保证原子性? 五、为什么R…

【C++面向对象——群体类和群体数据的组织】实现含排序功能的数组类(头歌实践教学平台习题)【合集】

目录😋 任务描述 相关知识 1. 相关排序和查找算法的原理 2. C 类与成员函数的定义 3. 数组作为类的成员变量的处理 4. 函数参数传递与返回值处理 编程要求 测试说明 通关代码 测试结果 任务描述 本关任务: 将直接插入排序、直接选择排序、冒泡…

[开源]自动化定位建图系统

系统状态机: 效果展示: 1、 机器人建图定位系统-基础重定位,定位功能演示 2、 机器人建图定位系统-增量地图构建,手动回环检测演示 3、敬请期待… 开源链接: 1、多传感器融合里程计 https://gitee.com/li-wenhao-lw…

简单的jmeter数据请求学习

简单的jmeter数据请求学习 1.需求 我们的流程服务由原来的workflow-server调用wfms进行了优化,将wfms服务操作并入了workflow-server中,去除了原来的webservice服务调用形式,增加了并发处理,现在想测试模拟一下,在一…

conda/pip基本常用命令理解与整理

最近配置了两轮pytorch环境,由于要频繁用到各种conda和pip命令,所以再此整理一下。 文章目录 前言:conda虚拟环境总结与解读Conda和pip的理解区别和联系命令格式 conda环境命令查看创建和删除导出与导入激活和退出 包管理命令安装和删除文件批…

Maven 详细配置:Maven settings 配置文件的详细说明

Maven settings 配置文件是 Maven 环境的重要组成部分,它用于定义用户特定的配置信息和全局设置,例如本地仓库路径、远程仓库镜像、代理服务器以及认证信息等。settings 文件分为全局配置文件(settings.xml)和用户配置文件&#x…

Unity-Mirror网络框架-从入门到精通之Chat示例

文章目录 前言Chat聊天室Authentication授权ChatAuthenticatorChat示例中的授权流程聊天Chat最后 前言 在现代游戏开发中,网络功能日益成为提升游戏体验的关键组成部分。Mirror是一个用于Unity的开源网络框架,专为多人游戏开发设计。它使得开发者能够轻…

复杂园区网基本分支的构建

目录 1、各主机进行网络配置。2、交换机配置。3、配置路由交换,进行测试。4、配置路由器接口和静态路由,进行测试。5、最后测试任意两台主机通信情况 模拟环境链接 拓扑结构 说明: VLAN标签在上面的一定是GigabitEthernet接口的&#xff0c…

这是什么操作?强制迁移?GitLab 停止中国区用户访问

大家好,我是鸭鸭! 全球知名代码托管平台 GitLab 发布通告,宣布不再为位于中国大陆、香港及澳门地区的用户提供访问服务,并且“贴心”建议,可以访问极狐 GitLab。 极狐 GitLab 是一家中外合资公司,宣称获得…

CDP集成Hudi实战-spark shell

[〇]关于本文 本文主要解释spark shell操作Hudi表的案例 软件版本Hudi1.0.0Hadoop Version3.1.1.7.3.1.0-197Hive Version3.1.3000.7.3.1.0-197Spark Version3.4.1.7.3.1.0-197CDP7.3.1 [一]使用Spark-shell 1-配置hudi Jar包 [rootcdp73-1 ~]# for i in $(seq 1 6); do s…

设计模式学习[15]---适配器模式

文章目录 前言1.引例2.适配器模式2.1 对象适配器2.2 类适配器 总结 前言 这个模式其实在日常生活中有点常见,比如我们的手机取消了 3.5 m m 3.5mm 3.5mm的接口,只留下了一个 T y p e − C Type-C Type−C的接口,但是我现在有一个 3.5 m m 3.…

数据挖掘——数据预处理

数据挖掘——数据预处理 数据预处理数据预处理 ——主要任务数据清洗如何处理丢失的数据如何处理噪声数据如何处理不一致数据 数据集成相关分析相关系数(也成为皮尔逊相关系数)协方差 数据规约降维法:PCA主成分分析降数据——抽样法数据压缩 数据预处理 数据预处理…

Unity-Mirror网络框架-从入门到精通之CCU示例

文章目录 前言什么是CCU?测试结果最后 前言 在现代游戏开发中,网络功能日益成为提升游戏体验的关键组成部分。Mirror是一个用于Unity的开源网络框架,专为多人游戏开发设计。它使得开发者能够轻松实现网络连接、数据同步和游戏状态管理。本文…