Gemma模型论文详解(附源码)

news2024/10/7 20:30:48

原文链接:Gemma模型论文详解(附源码)

1. 背景介绍

Gemma模型是在2023.2.21号Google新发布的大语言模型, Gemma复用了Gemini相同的技术(Gemini也是Google发布的多模态模型),Gemma这次发布了了2B和7B两个版本的参数,不仅提供了预训练的checkpoints,还提供了用于对话、指令跟随等fine-tune的checkpoints。在QA问答、常识。在11

在这里插入图片描述

2. 模型介绍

2.1 模型结构

Gemma模型使用了transformer decoder结构进行训练,训练的上下文大小为8192个token,模型参数如下:
在这里插入图片描述

相比原始transformer结构的区别:

  • Multi-Query Attention:7B模型使用了multi-head attention,2B模型使用了multi-query attention (with 𝑛𝑢𝑚_𝑘𝑣_ℎ𝑒𝑎𝑑𝑠 = 1)。对比llama2中用了group-query attention
    在这里插入图片描述

  • RoPE Embeddings: 不使用绝对位置编码,在每一层前加下RoPE Embedding,同时共享输入与输出层的embedding权重。

  • GeGLU Activations: ReLU的激活替换为GeGLU的激活。对比llama中用了swiglu。

  • Normalizer Location: 在transformer的每一层layer的前后都进行规一化,这里使用RMSNorm做为规一化层。

2.2 训练搭建

Gemma使用TPUv5e进行训练;一个pod中有256块TPUv5e芯片,256块芯片被设计为16X16的2D拓扑;Gemma-7B使用16个pods(4096块卡)进行训练,Gemma-2B使用2个pods(512块卡)。7B模型在一个pod内使用16路模型并行和16路数据并行,2B模型在一个pod内使用256路数据并行。优化器状态使用ZeRO-3进行切分,减少显存占用。在pod外使用类似Pathways的方式减少数据复制的成本。

和Gemini模型训练一样,综合了Jax和Pathways的单控制器single controller编程范式,使用单个python进程编排整个训练; 使用GSPMD partitioner用于训练step的计算,使用XLA compiler减少中间结果的大小。

2.3 训练数据

Gemma 2B和7B分别基于2T和6T个token进行训练,token来源于纯英文的文本,内容包括网页、数学、代码等。使用SentencePiece的tokenizer,字典大小有256K个token。数据过滤使用基于模型的分类器去除有害的、低质量的内容。最后采用类似Gemini的方式进行训练数据的混合,提升高质量数据的占比。

2.4 指令微调(Instruction Tuning)

2B和7B进行有监督微调(SFT)训练中使用混合生成数据和人工标注的prompt文本对,同时进行RLHF训练。在SFT阶段,基于给定的一个prompt,通过测试模型生成多个响应的回答结果,通过一个更大更好的模型进行结果的好坏判断。基于不同的侧重方向(指令跟随/事实/创造性/安全等)构建不同的prompt。使用多种基于LM的自动判断方法,比如chain-of-thought prompting

训练和推理过程中使用相同的数据格式,格式的设计重点在于两点,一个是确定多轮对话中的角色,一个是确定一轮对话的开始结束。对应格式标记和示例的训练数据如下:

在这里插入图片描述
在这里插入图片描述

3. 源码

  • Tensorflow实现的源码在github google-deepmind/gemma中,PyTorch实现的源码在github google/gemma_pytorch。

  • 模型的配置在gemma/config.py文件中, 7B与2B区别主要在于num_hidden_layers/num_attention_heads/num_key_value_heads/hidden_size/intermediate_size

@dataclasses.dataclass
class GemmaConfig:
    # The number of tokens in the vocabulary.
    vocab_size: int = 256000
    # The maximum sequence length that this model might ever be used with.
    max_position_embeddings: int = 8192
    # The number of blocks in the model.
    num_hidden_layers: int = 28
    # The number of attention heads used in the attention layers of the model.
    num_attention_heads: int = 16
    # The number of key-value heads for implementing attention.
    num_key_value_heads: int = 16
    # The hidden size of the model.
    hidden_size: int = 3072
    # The dimension of the MLP representations.
    intermediate_size: int = 24576
    # The number of head dimensions.
    head_dim: int = 256
    # The epsilon used by the rms normalization layers.
    rms_norm_eps: float = 1e-6
    # The dtype of the weights.
    dtype: str = 'bfloat16'
    # Whether a quantized version of the model is used.
    quant: bool = False
    # The path to the model tokenizer.
    tokenizer: Optional[str] = 'tokenizer/tokenizer.model'

    def get_dtype(self) -> Optional[torch.dtype]:
        """Gets the torch dtype from the config dtype string."""
        return _STR_DTYPE_TO_TORCH_DTYPE.get(self.dtype, None)


def get_config_for_7b() -> GemmaConfig:
    return GemmaConfig()


def get_config_for_2b() -> GemmaConfig:
    return GemmaConfig(
        num_hidden_layers=18,
        num_attention_heads=8,
        num_key_value_heads=1,
        hidden_size=2048,
        intermediate_size=16384
    )
  • 模型定义在gemma/model.py文件中,GemmaDecoderLayer的定义如下:
class GemmaDecoderLayer(nn.Module):

    def __init__(
        self,
        config: gemma_config.GemmaConfig,
    ):
        super().__init__()
        self.self_attn = GemmaAttention(
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            quant=config.quant,
        )
        self.mlp = GemmaMLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant=config.quant,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)
  • GeGLU的实现跟llama的swiglu不同,geglu相比glu区是采用了gelu的激活,以下是glu的计算示例图:
    在这里插入图片描述

代码参考如下,代码中self.gate_proj对应上图中的B矩阵,gate相当于 σ ( B ) \sigma(B) σ(B)self.up_proj对应上图中的A矩阵.

class GemmaMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        quant: bool,
    ):
        super().__init__()
        self.gate_proj = Linear(hidden_size, intermediate_size, quant)
        self.up_proj = Linear(hidden_size, intermediate_size, quant)
        self.down_proj = Linear(intermediate_size, hidden_size, quant)

    def forward(self, x):
        gate = self.gate_proj(x)
        gate = F.gelu(gate)
        up = self.up_proj(x)
        fuse = gate * up
        outputs = self.down_proj(fuse)
        return outputs

4. 参考

  • google-deepmind/gemma
  • Gemma 开放模型
  • Gemma: Open Models Based on Gemini Research and Technology
  • gemma-open-models
  • github google/gemma_pytorch
  • github google-deepmind/gemma
  • Grouped Query Attention论文阅读
  • SwiGLU论文阅读

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1466238.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

文件上传漏洞--Upload-labs--Pass10--双写绕过

一、什么是双写绕过 顾名思义,双写绕过就是双写文件后缀名来进行绕过,如:test.php 双写后为 test.pphphp。通常情况下双写绕过用于绕过源代码中的 str_ireplace()函数。 二、双写绕过原理 1、首先进行代码审计,源代码中有黑名单…

linux---防火墙拓展

目录 一、iptables 1.基本语法 2.四表五链——重点记忆 2.1四表 2.2五链 2.3总结 3.iptables选项示例 3.1 -Z 清空流量计数 3.2 -P 修改默认规则 3.3 -D 删除规则 3.4 -R 指定编号替换规则 4.白名单 5.通用匹配 6.示例 6.1添加回环网卡 6.2可以访问端口 6.3 主…

ERROR: No matching distribution found for json

问题描述 安装 json库 的时候,一直报错: 解决方案: 大多数博文分享是:①网络问题,换国内镜像;②更新pip. 少有人提及在Python 3.10.1中,它叫 simplejson 了 pip install simplejson 参考&am…

力扣日记2.21-【回溯算法篇】46. 全排列

力扣日记:【回溯算法篇】46. 全排列 日期:2023.2.21 参考:代码随想录、力扣 46. 全排列 题目描述 难度:中等 给定一个不含重复数字的数组 nums ,返回其 所有可能的全排列 。你可以 按任意顺序 返回答案。 示例 1&…

分布式应用:kylin 部署 zabbix 监控平台

目录 一、实验 1.环境 2. kylin 修改mysql数据库 3. kylin 部署 zabbix 监控平台 4. kylin 修改 zabbix 配置 5. kylin 修改zabbix web 二、问题 1. zabbix_server 查看版本报错 2.zabbix_server 文件如何去掉注释"#"和空行 3. zabbix图表显示异常 4.zabbi…

Docker基础篇(三) 容器数据卷(二) dockerfile

新建dockerfile文件 zenDockerfile from centos volume [“/containVolum-01”, “/containVolum-02”] CMD echo “zen”

YOLO v9 出世!

当今的深度学习方法专注于如何设计最合适的目标函数,以使模型的预测结果能够尽可能地接近真实值。同时,还需要设计一种适当的架构,以便为预测获取足够的信息。现有方法忽略了一个事实,即当输入数据经过逐层特征提取和空间转换时&a…

Java基于SpringBoot+Vue的体育用品库存管理系统,附源码

博主介绍:✌程序员徐师兄、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇…

什么品牌的洗地机好用?入门级智能洗地机

对于隔三差五就需要做全屋清洁的家庭,使用传统拖布洗地真的很吃不消,随着科技的不断进步,洗地机成为现代家庭清洁的必备工具之一。洗地机,用最贴合实际省事、省钱的方式去完成家务劳动,可以大大减少体力消耗&#xff0…

信号信号槽

三、信号槽 概念 信号和槽是两种函数,这是Qt在C基础上新增的特性,类似于其他技术中的回调的概念。 信号槽通过程序员提前设定的“约定”,可以实现对象之间的通信,有两个先决条件。 通信的对象都是在QOBject类中派生出来的。 QOBje…

Linux环境非root用户配置SSH免密登录,并解决登录仍提示输入密码

Linux环境非root用户配置SSH免密登录,并解决登录仍提示输入密码 ssh免密登录的简单理解 以A和B进行举例:A免密登录B (即在A服务器输入命令:ssh 非root用户名B的IP地址)可以直接免密码直接登录 A生成私钥和公钥&#…

​​​​​​​Sora:OpenAI的革命性AI视频模型与其对未来影像创作的影响

随着深度学习技术和计算能力的进步,人工智能不仅在图像识别、自然语言处理等领域取得了卓越成就,同时也在不断突破视频处理和生成的边界。在这一背景下,OpenAI推出了Sora——一种新型的AI视频模型,标志着AI在视频内容创作领域的又…

云呐智能维运技术有哪些?智能运维活动有哪些

智能运维(AIOps)技术是指利用人工智能、机器学习、大数据分析等先进技术手段,来提高IT运维效率和质量的一系列技术和工具。目前常见的智能运维技术核心功能和应用场景。一些具体的智能运维活动案例,包括但不限于故障预测、自动化修…

[hgame 2024 week3] crypto/pwn

第2周作完了不知道扔哪去了,先记录下第3周,因为官方WP已经出来,顺便把没出的题复现一下。最近的比赛都比较不错,相当于近期知识点的总结,有点心经的意思。 Crypto matrix_equation 题目很短,结了一个式子…

数据可视化在商业领域有哪些重要性?

数据可视化在商业领域的重要性体现在多个方面,它通过将复杂的数据集转化为直观、易于理解的图形和图表,帮助企业和组织做出更明智的决策。以下是数据可视化对商业的一些关键重要性: 提高决策效率:通过直观的图表和图形&#xff0c…

防御保护第八、九、十、十一天笔记

一、内容安全 1、DFI和DPI技术 --- 深度检测技术 DPI是一种基于应用层的流量检测和控制技术,它会对流量进行拆包,分析包头和应用层的内容,从而识别应用程序和应用程序的内容。这种技术增加了对应用层的分析,识别各种应用&#xf…

主流开发语言和开发环境:探索编程世界的基础

在当今这个快速发展的技术时代,软件开发已经成为推动创新的重要力量。无论是构建下一代应用、开发先进的算法还是创建复杂的系统,选择合适的编程语言和开发环境都是至关重要的。在本文中,我们将探讨当前流行的几种主流开发语言以及它们常用的…

敏捷项目管理在现代软件开发中的应用

在现代软件开发领域,项目管理起着至关重要的作用。随着技术的不断进步和市场需求的快速变化,传统的项目管理方法已逐渐无法满足软件开发的需求。因此,敏捷项目管理应运而生,成为许多软件开发团队的首选方法。本文将探讨敏捷项目管…

基于Python3的数据结构与算法 - 04 快速排序

一、快速排序思路 快速排序特点:快 步骤: 取一个元素p(第一个元素),使元素p归为;列表被p分成两部分,左边都比p小,右边都比p大;递归完成排序。 因此我们可以得到快速排…

数字化转型导师坚鹏:政府数字化转型案例研究(包括省市政府)

政府数字化转型案例研究(包括省市政府) 课程背景: 很多地方政府存在以下问题: 不清楚标杆省政府数字化转型的成功案例 不清楚直辖市政府数字化转型的成功案例 不清楚地级市政府数字化转型的成功案例 课程特色&#xff1a…