【HuggingFace Transformers】OpenAIGPTModel源码解析

news2024/9/28 13:21:09

OpenAIGPTModel源码解析

  • 1. GPT 介绍
  • 2. OpenAIGPTModel类 源码解析

说到ChatGPT,大家可能都使用过吧。2022年,ChatGPT的推出引发了广泛的关注和讨论。这款对话生成模型不仅具备了强大的语言理解和生成能力,还能进行非常自然的对话,给用户带来了全新的互动体验。然而,ChatGPT的成功背后离不开它的前身——GPT

1. GPT 介绍

GPT(Generative Pre-trained Transformer)是由OpenAI开发的一种基于Transformer架构的大型语言模型。它由多个堆叠的自注意力解码器层(Transformer Blocks)组成,每一层包含多头自注意力机制和前馈神经网络,并配有残差连接和层归一化以稳定训练。GPT采用自回归方式生成文本,通过在大规模互联网数据上进行预训练,具备强大的自然语言理解和生成能力,能够完成对话生成、文本补全等多种任务。其结构如下:

在这里插入图片描述

2. OpenAIGPTModel类 源码解析

源码地址:transformers/src/transformers/models/openai/modeling_openai.py

# -*- coding: utf-8 -*-
# @time: 2024/9/3 20:39
from typing import Optional, Union, Tuple

import torch

from torch import nn
from transformers import add_start_docstrings, OpenAIGPTPreTrainedModel
from transformers.modeling_outputs import BaseModelOutput
from transformers.models.openai.modeling_openai import OPENAI_GPT_START_DOCSTRING, Block, OPENAI_GPT_INPUTS_DOCSTRING, _CHECKPOINT_FOR_DOC, _CONFIG_FOR_DOC
from transformers.utils import add_start_docstrings_to_model_forward, add_code_sample_docstrings


@add_start_docstrings(
    "The bare OpenAI GPT transformer model outputting raw hidden-states without any specific head on top.",
    OPENAI_GPT_START_DOCSTRING,
)
class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.tokens_embed = nn.Embedding(config.vocab_size, config.n_embd)  # 定义 token 嵌入层
        self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)  # 定义 position 嵌入层
        self.drop = nn.Dropout(config.embd_pdrop)  # 定义 drop 层
        self.h = nn.ModuleList([Block(config.n_positions, config, scale=True) for _ in range(config.n_layer)]) # 定义多个 Block 层

        # 注册一个缓冲区用于存储position_ids,初始化为从 0 到 config.n_positions 的序列
        self.register_buffer("position_ids", torch.arange(config.n_positions), persistent=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.tokens_embed

    def set_input_embeddings(self, new_embeddings):
        self.tokens_embed = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        # 剪掉模型多头注意力机制中的一些头,heads_to_prune 是一个字典,键为layer_num,值为需要剪枝的 heads 列表。
        for layer, heads in heads_to_prune.items():
            self.h[layer].attn.prune_heads(heads)

    @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
        # 根据 config 配置设定 output_attentions, output_hidden_states, return_dict 的值
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 获取 input_ids 或者 inputs_embeds 以及 input_shape
        if input_ids is not None and inputs_embeds is not None:  # 当 input_ids 和 inputs_embeds 同时存在时,抛出错误
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:  # 如果存在 input_ids,将其形状调整为 (batch_size, sequence_length)
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:  # 如果存在 inputs_embeds,获取其形状
            input_shape = inputs_embeds.size()[:-1]
        else:  # 如果 input_ids 和 inputs_embeds 都不存在,抛出错误
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # 如果没有传入 position_ids,则生成默认的 position_ids
        if position_ids is None:
            # Code is different from when we had a single embedding matrix from position and token embeddings
            position_ids = self.position_ids[None, : input_shape[-1]]

        # ------------------------------------- 1. 获取 attention_mask -----------------------------#
        # Attention mask.
        if attention_mask is not None:
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is more simple than the triangular masking of causal attention
            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # 将 2D 掩码扩展为 3D 掩码,适用于批量输入

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and the dtype's smallest value for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            # 将注意力掩码转换为与模型参数相同的数据类型,并进行数值变换,torch.finfo(self.dtype).min 返回数据类型的最小值。
            attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
        # ----------------------------------------------------------------------------------------#

        # ------------------------------------- 2. 获取 head_mask ---------------------------------#
        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
        # ---------------------------------------------------------- -----------------------------#

        # ------------------------------------- 3. 获取 hidden_states -----------------------------#
        # 如果 inputs_embeds 为 None,则使用 tokens_embed 对 input_ids 计算
        if inputs_embeds is None:
            inputs_embeds = self.tokens_embed(input_ids)
        # 计算 position_embeds
        position_embeds = self.positions_embed(position_ids)
        # 如果存在 token_type_ids,使用 tokens_embed 计算;否则 token_type_embeds 为 0
        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
            token_type_embeds = self.tokens_embed(token_type_ids)
        else:
            token_type_embeds = 0
        # 计算 hidden_states,即inputs_embeds、position_embeds 和 token_type_embeds 之和,并使用 dropout
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        hidden_states = self.drop(hidden_states)
        # -------------------------------------------------------------------------------------#

        # 获取输出形状,以及初始化输出结果 all_attentions 和 all_hidden_states
        output_shape = input_shape + (hidden_states.size(-1),)
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        # -----------------------------------4. Block逐层计算处理(核心部分)--------------------#
        for i, block in enumerate(self.h):
            # 如果需要输出 hidden states,将当前 hidden_states 添加到 all_hidden_states
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # 通过当前 Block 处理 hidden_states,得到新的 hidden_states 和 attentions
            outputs = block(hidden_states, attention_mask, head_mask[i], output_attentions=output_attentions)
            hidden_states = outputs[0]
            # 如果需要输出 attentions,将当前 attentions 添加到 all_attentions
            if output_attentions:
                all_attentions = all_attentions + (outputs[1],)
        # ---------------------------------------------------------------------------------#

        # 将 hidden_states 的形状调整为输出形状
        hidden_states = hidden_states.view(*output_shape)

        # 如果需要输出 hidden states,将最后的 hidden_states 添加到 all_hidden_states
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # -----------------------------------5. 根据配置的输出方式输出结果-------------------------------#
        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2104115.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

手机免费录屏软件,这3款软件最佳选择

在数字化浪潮的推动下,智能手机已成为我们生活中不可或缺的一部分。而在这些小巧而强大的设备中,录屏功能逐渐崭露头角,成为记录屏幕精彩瞬间的得力助手。无论是游戏的高光时刻、APP的使用教程,还是进行远程会议,录屏功…

2024自动化测试面试真题(附答案)!

一、编程语法题 1 、 python 有哪些数据类型 python 数据类型有很多,基本数据类型有整型(数字)、字符串、元组、列表、字典和布尔类型等 2 、怎么将两个字典合并 调用字典的 update 方法,合并 2 个字典。 3 、 json.l python 如…

HarmonyOS NEXT 体验调用云数据库更新排行榜单

一、介绍 基于鸿蒙Next模拟一个排行帮单二、场景需求 1.目标用户 社交平台用户,尤其是热衷于获取和分享信息的年轻人和用户群体。 2. 功能描述 用户可以通过“排行帮单”功能查看某个主题或领域的热门内容,并能够向朋友或群体推荐特定的项目。 3. 需求…

数据治理与标准推动数据成为“金矿”

方案介绍: 数据治理是一个涉及组织、政策、流程和技术的综合性框架,旨在确保数据的质量、安全性、可用性、合规性和一致性。它涵盖了从数据产生到销毁的全生命周期管理,确保数据在组织内部得到正确、高效地使用。而数据标准是数据治理的基石…

OPenCV结构分析与形状描述符(2)计算轮廓周长的函数arcLength()的使用

操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 计算轮廓的周长或曲线的长度。 该函数计算曲线的长度或闭合轮廓的周长。 如果曲线是闭合的(即首尾相连),则计…

CSS解析:层叠、优先级和继承

CSS虽说不是编程语言,但是日常使用中经常有很多误解,发现样式不奏效的情况,所以需要加强下CSS基础。 CSS本质上就是声明规则,即在各种条件下,我们希望产生特定的效果。 如果某个元素有这个类,则应用这些样…

英文翻译哪家强?2024年3款热门工具大比拼

现在世界变得越来越“小”,英语几乎成了大家都懂的语言。但对那些天天忙工作的小伙伴们来说,一大堆英文的东西,比如文件、邮件、会议记录,看着就头大。好在,科技帮了大忙,出了好多翻译工具。2024年&#xf…

php邮箱服务器怎么搭建?如何构建服务器?

php邮箱服务器配置教程指南?php邮件服务器如何搭建? 搭建一个稳定高效的php邮箱服务器,不仅可以提升邮件传输的效率,还能增强数据的安全性。那么,如何着手搭建这样一个服务器呢?AokSend将详细探讨php邮箱服…

使用YOLOv10训练自定义数据集之一(环境部署)

0x00 前言 由清华大学的研究团队基于 Ultralytics Python 包研发的 YOLOv10,通过优化模型结构并去除非极大值抑制(NMS)环节,提出了一种创新的实时目标检测技术。这些改进不仅实现了行业领先的检测性能,还降低了对计算…

网络编程----网络基础ip地址

一丶IP地址 1.基本概念 1. IP地址是Internet中主机的标识 2. Internet中的主机要与别的机器通信必须具有一个IP地址 3. IP地址为32位(IPv4)或者128位(IPv6) NAT:公网转私网、私网转公网 4. IPV4表示形式&…

【简历】25届上海某一本JAVA简历:第一次看学校背景写一页的

注:为保证用户信息安全,姓名和学校等信息已经进行同层次变更,内容部分细节也进行了部分隐藏 简历说明 这是一份25 届上海某一本大学硕士的Java简历。这份简历写得比较偏,让人头疼。 这位同学的学校是重点一本,可以冲…

C++第四十五弹---深入理解包装器:提升代码复用性与安全性的利器

✨个人主页: 熬夜学编程的小林 💗系列专栏: 【C语言详解】 【数据结构详解】【C详解】 目录 1 包装器 1.1、function包装器 1.2、bind 1 包装器 1.1、function包装器 function包装器 也叫作适配器。C中的function本质是一个类模板&…

uniapp树洞烦恼分享系统 微信小程序设计与实现 80igt

目录 博主介绍技术栈系统设计🌟文末获取源码数据库🌟具体实现截图后端前端java类核心代码部分展示可行性论证个人心得系统测试操作可行性源码获取详细视频演示 博主介绍 👇🏻 博主介绍:👇🏻 专…

使用 WARP 和 Perf 测试对 MinIO 企业对象存储进行基准测试

AI/ML、高级分析和数据库等现代应用程序需要高性能对象存储。MinIO Enterprise Object Store 将可扩展性和高性能相结合,使每个工作负载(无论要求多么苛刻)触手可及。我们发布的基准测试表明,MinIO Enterprise Object Storage 是市…

泰克Tektronix MSO46 一款混合信号示波器

Tektronix MSO46 是一款混合信号示波器 (MSO),专为调试和分析复杂的电子电路而设计。FlexChannel 技术使每个通道输入都可以用作单个模拟通道、八个数字逻辑输入(使用 TLP058 逻辑探头)或同时使用模拟和频谱视图,每个域都有独立的…

前端进阶|一文理解柯里化的逆操作,什么是反柯里化

温故而知新 在说反柯里化之前,先来复习下柯里化的基础。之前文章,我们了解了什么是柯里化,以及柯里化的实现原理,同时我们也明白了什么情况下我们使用柯里化,详细阅读参见之前文章《前端进阶|由浅入深的理…

探索Python世界的趣味之旅:自制贪吃蛇游戏

通过本次贪吃蛇游戏的开发实践,不仅可以掌握Python编程语言的基础知识,还深入了解了游戏开发的基本流程和技术要点。这只是一个开始,Python的世界远不止于此。未来,你可以尝试开发更复杂、更有趣的游戏项目,甚至探索人…

Java详解String 字符串类以及String内存原理、StringBuilder类、StringJoiner类(附有代码+案例)

文章目录 九.String 字符串类型9.0 String概述9.1 字符串常用方法9.2 String内存图9.2.1直接赋值9.2.2new出来 9.3字符串比较9.4 字符串遍历9.4.1 统计字符串大小写及数字9.4.2 拼接字符串9.4.3字符串反转 9.5 StringBuilder类9.5.1StringBuilder 构造方法9.5.2StringBuilder常…

Spring全局异常处理HandlerExceptionResolver使用

1 引言 全局异常处理在项目中经常会用到,主要作用包括统一处理异常、提供友好的错误信息、避免应用程序崩溃、记录异常日志、避免异常信息泄露等等。下文将以实现HandlerExceptionResolver接口的方式,实现全局异常处理功能及常规用法。 2 代码 下面列…

Qt 字符串的编码方式,以及反斜杠加3个数字是什么编码\344\275\240,如何生成

Qt 字符串的编码方式 问题 总所周知,Qt的ui文件在编译时,会自动生成一个ui_xxxxx.h的头文件,打开一看,其实就是将摆放的控件new出来以及布局的代码。 只要用Qt提供的uic.exe工具,自己也可以将ui文件输出为代码文件…