ChatGLM2 源码解析:`ChatGLMModel`

news2024/11/27 15:46:50

# 完整的 GLM 模型,包括嵌入层、编码器、输出层
class ChatGLMModel(ChatGLMPreTrainedModel):
    def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
        super().__init__(config)
        # 如果设置了`empty_init`,创建任何 PyTorch 模块时,不初始化参数
        if empty_init:
            init_method = skip_init
        else:
            init_method = default_init
        init_kwargs = {}
        if device is not None:
            init_kwargs["device"] = device
        # 单词嵌入层
        self.embedding = init_method(Embedding, config, **init_kwargs)
        # LC
        self.num_layers = config.num_layers
        # GC
        self.multi_query_group_num = config.multi_query_group_num
        # HS
        self.kv_channels = config.kv_channels

        # SL
        self.seq_length = config.seq_length
        rotary_dim = (
            config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
        )
        # 位置嵌入(PE)
        self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
                                              dtype=config.torch_dtype)
        # GLM 编码器
        self.encoder = init_method(GLMTransformer, config, **init_kwargs)
        # 输出层
        self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
                                        dtype=config.torch_dtype, **init_kwargs)
        self.pre_seq_len = config.pre_seq_len
        self.prefix_projection = config.prefix_projection
        if self.pre_seq_len is not None:
            # 如果设置了前缀序列长度(PSL)
            # 关闭所有参数的自动梯度
            for param in self.parameters():
                param.requires_grad = False
            # [0, 1, ..., PSL - 1]
            self.prefix_tokens = torch.arange(self.pre_seq_len).long()
            # 初始化前缀编码层和 Dropout
            self.prefix_encoder = PrefixEncoder(config)
            self.dropout = torch.nn.Dropout(0.1)

    def get_input_embeddings(self):
        return self.embedding.word_embeddings

    def get_prompt(self, batch_size, device, dtype=torch.half):
        # prefix_tokens = [0, 1, ..., PSL - 1]
        # [PSL] => [1, PSL] => [BS, PSL]
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
        # [BS, PSL, KVS=NL * HS * 2GC]
        past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
        # [BS, PSL, KVS=NL * HS * 2GC] => [BS, PSL, 2NL, GC, HS]
        past_key_values = past_key_values.view(
            batch_size,
            self.pre_seq_len,
            self.num_layers * 2,
            self.multi_query_group_num,
            self.kv_channels
        )
        
        past_key_values = self.dropout(past_key_values)
        # [BS, PSL, 2NL, GC, HS] => [2NL, PSL, BS, GC, HS] => NL * [2, PSL, BS, GC, HS]
        past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
        return past_key_values

    def forward(
            self,
            input_ids,
            position_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.BoolTensor] = None,
            full_attention_mask: Optional[torch.BoolTensor] = None,
            past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ):
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # 输入是单词 ID,的形状为 [BS, SL]
        batch_size, seq_length = input_ids.shape
        # 将单词 ID 传递给词嵌入层得到嵌入向量
        if inputs_embeds is None:
            inputs_embeds = self.embedding(input_ids)

        # 如果设置了 PSL
        if self.pre_seq_len is not None:
            # 如果没有提供 KV 缓存,初始化为前 PSL 个前缀的词嵌入
            if past_key_values is None:
                past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
                                                  dtype=inputs_embeds.dtype)
            if attention_mask is not None:
                attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)),
                                            attention_mask], dim=-1)

        if full_attention_mask is None:
            if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
                full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)

        # 计算 PE
        # 初始化位置编码层
        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
        # 如果提供了位置 ID 就是用它检索位置嵌入矩阵
        # 如果没有,就返回嵌入矩阵的前 SL 个向量
        if position_ids is not None:
            rotary_pos_emb = rotary_pos_emb[position_ids]
        else:
            rotary_pos_emb = rotary_pos_emb[None, :seq_length]
        # [BS, SL, ES] => [SL, BS, ES]
        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()

        # 将词嵌入和位置嵌入传给编码器得到编码器输出
        hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
            inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
            kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
        )

        # 返回 GLM 输出,每层的 KV 缓存和每层的输出
        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    def quantize(self, weight_bit_width: int):
        from .quantization import quantize
        quantize(self.encoder, weight_bit_width)
        return self

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/976459.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【软件测试】单元测试、集成测试、系统测试有什么区别?

单元测试、集成测试、系统测试有什么区别 1、粒度不同 集成测试bai粒度居中,单元测试粒度最小,系统du测试粒度最大。 2、测试方式不同 集成测试一般由开发zhi小组采用白盒加黑盒的方式来测试,单元测试一般由开发小组采用白盒方式来测试&a…

第六章:中华民族的抗日战争

1.日本发动灭亡中国的侵略斗争 关键字: 中国抗日战争的起点与全民族抗战阶段 2.中国人民奋起抗击日本侵略者 关键字: 1 国共第二次统一战线初步建立的标志:国民党五届三中全会 2 扭转时局的枢纽,国内和平初步实现:…

3.msfconle

目录 1 进入msfconsole 2 连接postgresql数据库 3 msfconsole基本用法 4 更新msf 5 搜索脚本 search 6 查看脚本信息 info 7 设置参数 8 重新设置参数与取消参数 9 退出当前模块 back 10 查看域名基本信息 dig 11 查看域名的详细信息 whois 1 进入msfco…

k8s部署redis 3主3从

k8s部署redis6节点,组成3主3从集群模式 一般来说,redis部署有三种模式。 单实例模式,一般用于测试环境。 哨兵模式 集群模式后两者用于生产部署 哨兵模式 在redis3.0以前,要实现集群一般是借助哨兵sentinel工具来监控master节点…

BeanUtils.copyProperties:曾经是我的女神,现在是我的毒药。

前言 BeanUtils.copyProperties十有八九是你这些年工作中用的很多的其中一个,不管是Apache的还是Spring的。 网上的解释浩如烟海,我这边用一个超简单的例子直观展示给你看。 以后就记住了,能不用就不用。 正文 1、网上的解释 我收纳了几个网…

HDFS 架构剖析

目录 一、HDFS 架构整体概述 二、HDFS 集群角色介绍 2.1 整体概述 2.2 主角色:namenode 2.3 从角色:datanode 2.4 主角色辅助角色: secondarynamenode 三、HDFS 重要特性 3.1 主从架构 3.2 分块存储机制 3.3 副本机制 3.4 …

基于React实现:弹窗组件与Promise的有机结合

背景 弹窗在现代应用中是最为常见的一种展示信息的形式,二次确认弹窗是其中最为经典的一种。当我们在React,Vue这种数据驱动视图的前端框架中渲染弹窗基本是固定的使用形式。 使用方式:创建新的弹窗组件,在需要弹窗的地方引用并…

百叶帘系统内置于玻璃内,分为手动和电动两种控制方式

百叶帘系统是一种在餐厅包厢隔断墙中常见的控制窗帘或遮光帘的方式。这种系统通常分为手动和电动两种控制方式,具体选择取决于您的需求和预算。 1. 手动控制:手动控制是传统的方式,通过手动操作绳子或杆来打开或关闭百叶帘。这是一种经济实惠…

力扣刷题49 字母 异位词分组

目录 题目描述代码实现基本实现优化代码 基础知识回溯集合 参考 题目描述 给你一个字符串数组,请你将 字母异位词 组合在一起。可以按任意顺序返回结果列表。 字母异位词 是由重新排列源单词的所有字母得到的一个新单词。 示例 1: 输入: strs [“eat”, “tea”…

19|返璞归真:王维佛系建议,万事不如吃好睡好

好诗相伴,千金不换。你好,我是天博。 今天我们的主题仍然是“见自己”。其实,诗词里并不是只有诗情画意的春花秋月,也充满了实实在在的人间烟火。这些现实的生活对我们平常人来说,往往比春花秋月更有借鉴意义。我们今…

基于Java+SpringBoot+Vue前后端分离在线考试系统设计和实现

博主介绍:✌全网粉丝30W,csdn特邀作者、博客专家、CSDN新星计划导师、Java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专…

蝶形运算法

蝶形运算法是一种基于FFT(Fast Fourier Transform)算法的计算方法,其基本思想是将长度为N的DFT分解成若干个长度为N/2的DFT计算,并通过不断的合并操作得到最终的结果。该算法也称为“蝴蝶算法”,因为它的计算过程中需要…

未加载Qt6Core.pdb

编译代码发现未加载.pdb。 问题是Qt6的bin目录下也没有.pdb文件啊? 有两种方法,一是把Qt路径放到环境变量下,这样系统就能找到Qt6需要的依赖项。 二是在生成目录下找到编译好的.exe文件,然后调用windeployqt.exe为其生成依赖项&…

通用策略06丨横截面因子在期货中的应用(2)

量化策略开发,高质量社群,交易思路分享等相关内容 大家好,今天为大家带来2023年度通用系列的收官之作——再议横截面因子。 在通用05策略中,我们以一种很简单的框架和复现方式,为大家展示了横截面因子在期货中的运用展…

5. 本地方法接口和本地方法栈

5.1. 什么是本地方法? 简单地讲,一个Native Method是一个Java调用非Java代码的接囗。一个Native Method是这样一个Java方法:该方法的实现由非Java语言实现,比如C。这个特征并非Java所特有,很多其它的编程语言都有这一…

云贝餐饮连锁独立版 v2.7.9+公众号+小程序端+全插件(免授权前端线传)安装教程

云贝餐饮连锁版主要基于目前比较流行小程序生态下的自助点单系统,一款非常不错的餐饮外卖小程序。播播资源测试云贝餐饮连锁独立版 v2.7.9该版本与上一版一样永久授权版,增加了小程序前端线传功能(通过其他第三方上传)&#xff0c…

jdk17下netty导致堆内存疯涨原因排查 | 京东云技术团队

背景: 介绍 天网风控灵玑系统是基于内存计算实现的高吞吐低延迟在线计算服务,提供滑动或滚动窗口内的count、distinctCout、max、min、avg、sum、std及区间分布类的在线统计计算服务。客户端和服务端底层通过netty直接进行tcp通信,且服务端…

JS判断对象是否发生变化,常用于监听页面表单是否修改并给出保存提示

本文主要封装方法,实现用户离开表单编辑页面时弹出提示框,若表单数据发生变化,则提示用户是否保存当前页面的信息,如图: 封装方法: /*** 比较俩个对象之间的差异,项目中多处用到监听表单数据是…

配电室能耗数据采集系统

随着社会的快速发展,能源消耗逐年增加,能源问题已成为制约我国经济社会发展的瓶颈。在此背景下,节能减排、绿色发展成为国家战略,而配电室作为电力系统的重要组成部分,其能耗管理对整个电力系统的能效有着举足轻重的影…

SQL sever中表数据管理

目录 一、插入数据: 二、更新数据: 三、删除数据: 四、清空数据: 4.1使用DELETE语句: 4.2 使用TRUNCATE TABLE语句: 4.3区别: 4.3.1DELETE FROM: 4.3.2TRUNCATE TABLE&am…