Pytorch深度解析:Transformer嵌入层源码逐行解读

news2024/10/5 20:27:41

前言

本部分博客需要先阅读博客:
《Transformer实现以及Pytorch源码解读(一)-数据输入篇》
作为知识储备。

Embedding使用方式

如下面的代码中所示,embedding一般是先实例化nn.Embedding(vocab_size, embedding_dim)。实例化的过程中输入两个参数:vocab_size和embedding_dim。其中的vocab_size是指输入的数据集合中总共涉及多少个去重后的单词;embedding_dim是指,每个单词你希望用多少维度的向量表示。随后,实例化的embedding在forward中被调用self.embeddings(inputs)。

class Transformer(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_class,
                 dim_feedforward=512, num_head=2, num_layers=2, dropout=0.1, max_len=512, activation: str = "relu"):
        super(Transformer, self).__init__()
        # 词嵌入层
        self.embedding_dim = embedding_dim
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.position_embedding = PositionalEncoding(embedding_dim, dropout, max_len)
        # 编码层:使用Transformer
        encoder_layer = nn.TransformerEncoderLayer(hidden_dim, num_head, dim_feedforward, dropout, activation)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        # 输出层
        self.output = nn.Linear(hidden_dim, num_class)

    def forward(self, inputs, lengths):
        inputs = torch.transpose(inputs, 0, 1)
        hidden_states = self.embeddings(inputs)
        hidden_states = self.position_embedding(hidden_states)
        attention_mask = length_to_mask(lengths) == False
        hidden_states = self.transformer(hidden_states, src_key_padding_mask=attention_mask).transpose(0, 1)
        logits = self.output(hidden_states)
        log_probs = F.log_softmax(logits, dim=-1)
        return log_probs

数据被怎样变换了?

如下图所示,第一个tensor表示input,该input表示一个句子( sentence),只是该句子中的单词用整数进行了代替,相同的整数表示相同的单词。而每个1在embedding之后,变成了相同过的向量。

我们将以上的代码重新的运行一遍,发现表示1的向量改变了,这说明embedding 的过程不是确定的,而是随机的。

数据是怎样被变化的?

Embedding类在调用过程中主要涉及到以下几个核心方法:_
init
,rest_parameters,forward:

Embedding类的初始化过程如下所示。当_weight没有的情况下调用Parameter初始化一个空的向量,该向量的维度与输入数据中的去重单词个数(num_bembeddings)一样。然后调用reset_parameters方法。

 def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
                 max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
                 sparse: bool = False, _weight: Optional[Tensor] = None,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(Embedding, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        if padding_idx is not None:
            if padding_idx > 0:
                assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
            elif padding_idx < 0:
                assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
                padding_idx = self.num_embeddings + padding_idx
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        if _weight is None:
            self.weight = Parameter(torch.empty((num_embeddings, embedding_dim), **factory_kwargs))
            # print("===========================================1")
            # print(self.weight)
            #将self.weight进行nornal归一化
            self.reset_parameters()
            print("===========================================2")
            print(self.weight)
        else:
            assert list(_weight.shape) == [num_embeddings, embedding_dim], \
                'Shape of weight does not match num_embeddings and embedding_dim'
            self.weight = Parameter(_weight)

        self.sparse = sparse

reset_parameters的实现如下所示,主要是调用了init.norma_方法。

    def reset_parameters(self) -> None:
        init.normal_(self.weight)
        self._fill_padding_idx_with_zero()

init.normal_又调用了torch.nn.init中的normal方法。该方法将空的self.weight矩阵填充为一个符合 (0,1)正太分布的矩阵。

N

(

mean

,

std

2

)

.

\mathcal{N}(\text{mean}, \text{std}^2).

N

(

mean

,

std

2

)

.

def normal_(tensor: Tensor, mean: float = 0., std: float = 1.) -> Tensor:
    r"""Fills the input Tensor with values drawn from the normal
    distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.normal_(w)
    """
    return _no_grad_normal_(tensor, mean, std)

继续追踪_no_grad_normal_(tensor, mean, std)我们发现,该方法是通过c++实现,所在的源码文件目录为:

namespace torch {
namespace nn {
namespace init {
namespace {
struct Fan {
  explicit Fan(Tensor& tensor) {
    const auto dimensions = tensor.ndimension();
    TORCH_CHECK(
        dimensions >= 2,
        "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions");

    if (dimensions == 2) {
      in = tensor.size(1);
      out = tensor.size(0);
    } else {
      in = tensor.size(1) * tensor[0][0].numel();
      out = tensor.size(0) * tensor[0][0].numel();
    }
  }

  int64_t in;
  int64_t out;
};
Tensor normal_(Tensor tensor, double mean, double std) {
  NoGradGuard guard;
  return tensor.normal_(mean, std);
}

forward方法的c++实现如下所示。

torch::Tensor EmbeddingImpl::forward(const Tensor& input) {
  return F::detail::embedding(
      input,
      weight,
      options.padding_idx(),
      options.max_norm(),
      options.norm_type(),
      options.scale_grad_by_freq(),
      options.sparse());
}

继续追踪,发现weight中的每个变量被下面的c++代码填充了正太分布的随机数。

void normal_kernel(const TensorBase &self, double mean, double std, c10::optional<Generator> gen) {
  CPUGeneratorImpl* generator = get_generator_or_default<CPUGeneratorImpl>(gen, detail::getDefaultCPUGenerator());
  templates::cpu::normal_kernel(self, mean, std, generator);
}

随机数的生成调用如下的代码,首先询问:目前代码是在什么设备上运行,并调用cpu或者gup上的随机数生成方法。

template <typename T>
static inline T * check_generator(c10::optional<Generator> gen) {
  TORCH_CHECK(gen.has_value(), "Expected Generator but received nullopt");
  TORCH_CHECK(gen->defined(), "Generator with undefined implementation is not allowed");
  TORCH_CHECK(T::device_type() == gen->device().type(), "Expected a '", T::device_type(), "' device type for generator but found '", gen->device().type(), "'");
  return gen->get<T>();
}

/**
 * Utility function used in tensor implementations, which
 * supplies the default generator to tensors, if an input generator
 * is not supplied. The input Generator* is also static casted to
 * the backend generator type (CPU/CUDAGeneratorImpl etc.)
 */
template <typename T>
static inline T* get_generator_or_default(const c10::optional<Generator>& gen, const Generator& default_gen) {
  return gen.has_value() && gen->defined() ? check_generator<T>(gen) : check_generator<T>(default_gen);
}

至此,embedding的每个随机数的生成过程都清楚了。

总结

Embedding的过程,其实就是为每个单词对应一个向量的过程。该向量为(0,1)正太分布,该矩阵在Embedding的实例化过程就已经被初始化完成。在调用Embedding示例的时候即forward开始工作的时候,只是做了一个匹配的过程,也就是将<字典,向量>的对应关系应用到input上。前期解读该部分源码的困惑是一只找不到forward中的对应处理过程,以为embedding的处理逻辑是在forward的阶段展开的,显然这种想法是不对的。Pytorch的架构设计的的确优雅!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1831189.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Comfy UI使用最新SD3模型,并解决报错‘NoneType‘ object has no attribute ‘tokenize‘【实测可行】

解决Comfy UI使用最新SD3模型报错’NoneType’ object has no attribute ‘tokenize’ 前几天SD3发布了&#xff0c;所以想着尝尝鲜&#xff0c;便去下载了SD3来玩一玩。使用的是Comfy UI而不是Stable Diffusion UI&#xff0c;这是一个比SD UI更加灵活的UI界面&#xff0c;使用…

【Gradio】Building With Blocks 块中的状 态 + 动态应用程序与渲染装饰器

State in Blocks 块中的状态 我们介绍了接口中的状态&#xff0c;这个指南将看看块中的状态&#xff0c;其工作原理大致相同。 全局状态 块中的全局状态与接口中的工作原理相同。在函数调用外创建的任何变量都是所有用户共享的引用。 会话状态 Gradio 支持会话状态&#xff…

大模型-人类病理学的语言视觉AI助手

论文摘要翻译与评论 论文标题&#xff1a; A Multimodal Generative AI Copilot for Human Pathology 摘要翻译&#xff1a; 计算病理学领域已经在任务特定的预测模型和任务无关的自监督视觉编码器的发展方面取得了显著进展。然而&#xff0c;尽管生成性人工智能快速增长&a…

DBA常用论坛

1.ITPUB ITPUB技术论坛_专业的IT技术社区 2.ASKTOM Ask TOM

APP IOS

APP IOS苹果源生应用程序 APP Android-CSDN博客

编写乘法器求解算法表达式

描述 编写一个4bit乘法器模块&#xff0c;并例化该乘法器求解c12*a5*b&#xff0c;其中输入信号a,b为4bit无符号数&#xff0c;c为输出。注意请不要直接使用*符号实现乘法功能。 模块的信号接口图如下&#xff1a; 要求使用Verilog HDL语言实现以上功能&#xff0c;并编写tes…

scrapy模块的基础使用

scrapy模块是爬虫工作者最常用的一个模块之一&#xff0c;因它有许多好用的模板&#xff0c;和丰富的中间件&#xff0c;深受欢迎。 一&#xff0c;scrapy的安装 可以通过pypi的指引进行安装 在终端内输入以下代码&#xff1a; pip install scrapy 二&#xff0c;项目的建…

【学习笔记】MySQL(Ⅱ)

MySQL(Ⅱ) 7、 进阶篇 —— 存储引擎 7.1、MySQL 体系结构 7.2、存储引擎 7.2.1 InnoDB 7.2.2 MyISAM 7.2.3 Memory 7.2.4 InnoDB、MyISAM、Memory 的比较8、 拓展篇 —— 在 Linux 上安装数据库9、进阶篇 —— 索引 …

SmartEDA、Multisim、Proteus大比拼:电路设计王者之争?

在电路设计领域&#xff0c;SmartEDA、Multisim和Proteus无疑是三款备受瞩目的软件工具。它们各自拥有独特的功能和优势&#xff0c;但在这场电路设计王者的竞争中&#xff0c;谁才是真正的领跑者&#xff1f;让我们深入探究这三款软件的异同&#xff0c;揭示它们各自的魅力所在…

rabbitMQ的简单使用

rabbitMQ的介绍 RabbitMQ是一个开源的消息代理和队列服务器&#xff0c;主要用于在不同的应用程序之间传递消息。它基于AMQP&#xff08;Advanced Message Queuing Protocol&#xff09;协议&#xff0c;提供了一种可靠的方式来处理异步通信。RabbitMQ使用Erlang语言编写&…

【VUE3学习手札】

VUE3学习手札 vue3成长之路学习笔记 文章目录 VUE3学习手札前言一、markRaw1.1 代码示例1.2 应用场景1.3 拓展&#xff08;toRaw&#xff09;1.4 实际应用 前言 主要用于自己的一个备忘&#xff0c;对知识点的查缺补漏 一、markRaw 将一个对象标记为不可被转为代理。返回该对象…

北京大学数字普惠金融指数(2011-2022年)

北京大学数字普惠金融指数&#xff08;2011-2022年&#xff09;&#xff0c;包含省市县三级数据 数据年限&#xff1a;省级、地级市&#xff08;2011-2022年&#xff09;&#xff1b;区县&#xff08;2014-2022年&#xff09; 数据格式&#xff1a;excel、pdf 数据来源&#xf…

java-数据结构与算法-02-数据结构-01-数组

文章目录 1. 概述2. 动态数组3. 二维数组4. 局部性原理5. 越界检查6. 习题 1. 概述 定义 在计算机科学中&#xff0c;数组是由一组元素&#xff08;值或变量&#xff09;组成的数据结构&#xff0c;每个元素有至少一个索引或键来标识 In computer science, an array is a dat…

subversion

subversion Install # CentOS安装Subversion yum install subversion mkdir /var/svn/ systemctl restart svnserve# Docker安装Subversion&#xff08;参考&#xff1a;https://github.com/garethflowers/docker-svn-server&#xff09; docker run \--name my-svn-server \…

气体传感器的工作原理探究

气体传感器的工作原理主要基于其内部的感应元件与目标气体之间的相互作用。不同的气体传感器可能采用不同的工作原理&#xff0c;但其核心目的都是将气体的浓度或成分转化为可测量和处理的电信号。 PID气体传感器 以常见的电化学式气体传感器为例&#xff0c;其工作原理涉及气体…

T113 Tina5.0 添加板级支持包

文章目录 环境介绍Tina5.0 SDK说明添加buildroot板级支持包添加板级支持包修改配置文件验证 添加openwrt板级支持包添加板级支持包修改配置文件验证其它 总结 环境介绍 硬件&#xff1a;韦东山T113工业板 软件&#xff1a;全志Tina 5.0 Tina5.0 SDK说明 需要明确的是&#x…

深度解析量水堰:结构、分类与设计要点

量水堰&#xff0c;作为水工测量中的关键设施&#xff0c;其精确度和多样性对于水位和流量的测量至关重要。其工作原理基于通过堰顶断面上的进水口&#xff0c;将水位引导至堰体内部&#xff0c;从而实现水位和流量的平衡。量水堰通常采用高强度、耐久的材料构建&#xff0c;如…

算法:分治(快排)题目练习

目录 题目一&#xff1a;颜色分类 题目二&#xff1a;排序数组 题目三&#xff1a;数组中的第k个最大元素 题目四&#xff1a;库存管理III 题目一&#xff1a;颜色分类 给定一个包含红色、白色和蓝色、共 n 个元素的数组 nums &#xff0c;原地对它们进行排序&#xff0c;…

Linux_应用篇(19) V4L2 摄像头应用编程

ALPHA/Mini I.MX6U 开发板配套支持多种不同的摄像头&#xff0c;包括正点原子的 ov5640&#xff08;500W 像素&#xff09;、ov2640&#xff08;200W 像素&#xff09;以及 ov7725&#xff08;不带 FIFO、 30W 像素&#xff09;这三款摄像头&#xff0c;在开发板出厂系统上&…

Jupyter Notebook简介

目录 1.概述 2.诞生背景 3.历史版本 4.安装 5.卸载 6.如何使用 7.菜单和菜单项 8.示例 9.未来展望 10.总结 1.概述 Jupyter Notebook是一种基于Web的交互式计算环境&#xff0c;主要用于数据分析、数据科学、机器学习以及探索性编程等领域。允许用户在单个文档中编写…