Transformer实现以及Pytorch源码解读(二)-embedding源码分析

news2025/2/6 14:06:38

前言

本部分博客需要先阅读博客:《Transformer实现以及Pytorch源码解读(一)-数据输入篇》 作为知识储备。

Embedding使用方式

如下面的代码中所示,embedding一般是先实例化nn.Embedding(vocab_size, embedding_dim)。实例化的过程中输入两个参数:vocab_size和embedding_dim。其中的vocab_size是指输入的数据集合中总共涉及多少个去重后的单词;embedding_dim是指,每个单词你希望用多少维度的向量表示。随后,实例化的embedding在forward中被调用self.embeddings(inputs)。

class Transformer(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_class,
                 dim_feedforward=512, num_head=2, num_layers=2, dropout=0.1, max_len=512, activation: str = "relu"):
        super(Transformer, self).__init__()
        # 词嵌入层
        self.embedding_dim = embedding_dim
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.position_embedding = PositionalEncoding(embedding_dim, dropout, max_len)
        # 编码层:使用Transformer
        encoder_layer = nn.TransformerEncoderLayer(hidden_dim, num_head, dim_feedforward, dropout, activation)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        # 输出层
        self.output = nn.Linear(hidden_dim, num_class)

    def forward(self, inputs, lengths):
        inputs = torch.transpose(inputs, 0, 1)
        hidden_states = self.embeddings(inputs)
        hidden_states = self.position_embedding(hidden_states)
        attention_mask = length_to_mask(lengths) == False
        hidden_states = self.transformer(hidden_states, src_key_padding_mask=attention_mask).transpose(0, 1)
        logits = self.output(hidden_states)
        log_probs = F.log_softmax(logits, dim=-1)
        return log_probs

数据被怎样变换了?

如下图所示,第一个tensor表示input,该input表示一个句子( sentence),只是该句子中的单词用整数进行了代替,相同的整数表示相同的单词。而每个1在embedding之后,变成了相同过的向量。
在这里插入图片描述
我们将以上的代码重新的运行一遍,发现表示1的向量改变了,这说明embedding 的过程不是确定的,而是随机的。

在这里插入图片描述

数据是怎样被变化的?

Embedding类在调用过程中主要涉及到以下几个核心方法:_init,rest_parameters,forward:
在这里插入图片描述
Embedding类的初始化过程如下所示。当_weight没有的情况下调用Parameter初始化一个空的向量,该向量的维度与输入数据中的去重单词个数(num_bembeddings)一样。然后调用reset_parameters方法。

 def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
                 max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
                 sparse: bool = False, _weight: Optional[Tensor] = None,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(Embedding, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        if padding_idx is not None:
            if padding_idx > 0:
                assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
            elif padding_idx < 0:
                assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
                padding_idx = self.num_embeddings + padding_idx
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        if _weight is None:
            self.weight = Parameter(torch.empty((num_embeddings, embedding_dim), **factory_kwargs))
            # print("===========================================1")
            # print(self.weight)
            #将self.weight进行nornal归一化
            self.reset_parameters()
            print("===========================================2")
            print(self.weight)
        else:
            assert list(_weight.shape) == [num_embeddings, embedding_dim], \
                'Shape of weight does not match num_embeddings and embedding_dim'
            self.weight = Parameter(_weight)

        self.sparse = sparse

reset_parameters的实现如下所示,主要是调用了init.norma_方法。

    def reset_parameters(self) -> None:
        init.normal_(self.weight)
        self._fill_padding_idx_with_zero()

init.normal_又调用了torch.nn.init中的normal方法。该方法将空的self.weight矩阵填充为一个符合 (0,1)正太分布的矩阵。
N ( mean , std 2 ) . \mathcal{N}(\text{mean}, \text{std}^2). N(mean,std2).

def normal_(tensor: Tensor, mean: float = 0., std: float = 1.) -> Tensor:
    r"""Fills the input Tensor with values drawn from the normal
    distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.normal_(w)
    """
    return _no_grad_normal_(tensor, mean, std)

继续追踪_no_grad_normal_(tensor, mean, std)我们发现,该方法是通过c++实现,所在的源码文件目录为:

namespace torch {
namespace nn {
namespace init {
namespace {
struct Fan {
  explicit Fan(Tensor& tensor) {
    const auto dimensions = tensor.ndimension();
    TORCH_CHECK(
        dimensions >= 2,
        "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions");

    if (dimensions == 2) {
      in = tensor.size(1);
      out = tensor.size(0);
    } else {
      in = tensor.size(1) * tensor[0][0].numel();
      out = tensor.size(0) * tensor[0][0].numel();
    }
  }

  int64_t in;
  int64_t out;
};
Tensor normal_(Tensor tensor, double mean, double std) {
  NoGradGuard guard;
  return tensor.normal_(mean, std);
}

forward方法的c++实现如下所示。

torch::Tensor EmbeddingImpl::forward(const Tensor& input) {
  return F::detail::embedding(
      input,
      weight,
      options.padding_idx(),
      options.max_norm(),
      options.norm_type(),
      options.scale_grad_by_freq(),
      options.sparse());
}

总结

Embedding的过程,其实就是为每个单词对应一个向量的过程。该向量为(0,1)正太分布,该矩阵在Embedding的实例化过程就已经被初始化完成。在调用Embedding示例的时候即forward开始工作的时候,只是做了一个匹配的过程,也就是将<字典,向量>的对应关系应用到input上。前期解读该部分源码的困惑是一只找不到forward中的对应处理过程,以为embedding的处理逻辑是在forward的阶段展开的,显然这种想法是不对的。Pytorch的架构设计的的确优雅!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/102895.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

业务安全情报,预知发现黑灰产的企业攻击

业务遭遇欺诈风险&#xff0c;发起攻击的黑灰产主要是为了谋取利益。对于黑灰产利益目的甄别需要多方面情报&#xff0c;再辅助技术和专家经验&#xff0c;然后进行综合判断&#xff0c;进而帮助企业及时响应、精准布控。 安全情报帮助企业提前发现攻击 12月13日&#xff0c;“…

谁代表先进生产力?

互联网企业代表先进生产力方向 做软件项目的时候碰到三类企业 1、 传统企业&#xff0c;以卖货或卖服务为主 2、 互联网类&#xff0c;做个服务工具或平台 3、 分配模式类&#xff0c;以某分配模式为“宝贝” 毫无疑问&#xff1a; 互联网企业代表先进的生产力方向 互联网类…

即时通讯音视频开发之音频基础及编码原理

即时通讯应用中的实时音视频技术&#xff0c;几乎是IM开发中的最后一道高墙。原因在于&#xff1a;实时音视频技术 音视频处理技术 网络传输技术 的横向技术应用集合体&#xff0c;而公共互联网不是为了实时通信设计的。 比特率&#xff1a; 表示经过编码&#xff08;压缩&am…

C#读取Excel文件内容(WPS)

本地安装的WPS版本为 一、下载accessdatabaseengine_X64后安装 网址&#xff1a;https://www.microsoft.com/en-us/download/details.aspx?id54920 二、项目中引用OleDb包 三、代码部分 //excelFilePath为文件路径&#xff08;例如D:\Test.xslx&#xff09; > //strin…

Android---RecyclerView回收复用机制

一、RecyclerView回收复用 回收什么&#xff1f;复用什么&#xff1f; 回收&#xff1a;回收即缓存。当屏幕上的一个itemView滑出屏幕(即不可见了)&#xff0c;RecyclerView就利用回收机制&#xff0c;将该itemView放入内存。当其它itemView出现时&#xff0c;不用每次都去new…

JavaScript-Sass

Sass的基础使用 1.简介 1.1简介 Sass是世界上最成熟&#xff0c;最稳定&#xff0c;最强大的CSS扩展语言Sass是css预编译工具可以更加优雅的书写csssass写出来的东西浏览器不认识需要进行转换VSCode推荐使用Easy Sass插件Sass中可以使用加减乘除&#xff0c;条件分支以及循环…

【Three.js入门】处理动画、尺寸自适应、双击进入/退出全屏(Clock跟踪时间,Gsap动画库,自适应画面,进入/退出全屏)

个人简介 &#x1f440;个人主页&#xff1a; 前端杂货铺 &#x1f64b;‍♂️学习方向&#xff1a; 主攻前端方向&#xff0c;也会涉及到服务端 &#x1f4c3;个人状态&#xff1a; 在校大学生一枚&#xff0c;已拿多个前端 offer&#xff08;秋招&#xff09; &#x1f680;未…

Python -- 网络编程

目录 1.网络通信的概念 2.IP地址 3.网络通信方式 3.1 直接通信 3.2 使用集线通信 3.3 通用交换机通信 3.4 使用路由器连接多个网络 3.5 复杂的通信过程 4.端口 4.1 端口号 4.2 知名端口号 4.3 动态端口号 4.4 端口号作用 5.socker概念 5.1 不同电脑上的进程之间…

【二叉树经典习题讲解】

If you find a path with no obstacles, probably doesnt lead anywhere. 目录 1 前中后序遍历一颗二叉树 2 总的结点个数 3 求叶子节点个数 4 求树的高度 5 第k层结点个数 6 二叉树的层序遍历 7 判断一棵树是否为完全二叉树 1 二叉树的前序遍历 2 单值二叉树 3 翻转二…

2022卡塔尔世界杯的两个球员:一个吸螺,一个没吸

你好&#xff0c;我是YourBatman&#xff1a;一个俗人&#xff0c;贪财好色。 2022年12月18日&#xff0c;卢塞尔球场&#xff0c;太太太精彩了&#xff0c;这场世界杯决赛&#xff01;卡塔尔世界杯&#xff0c;已经离我们远去&#xff0c;阿根廷最终满载而归。 那一个个珍贵…

大脑的默认模式网络DMN

虽然默认模式网络DMN现在是rs-fMRI领域中的研究热点&#xff0c;但最初观察到默认模式网络的工具是PET&#xff0c;并且是从任务态过渡到静息态的 PET中大脑功能活动基线的定义&#xff1a; 基线是理解复杂系统的基础根据脑氧提取分数&#xff08;OEF值&#xff09;可以确定正…

前端CSS实现跳动的文字

效果图 首选来一个简单的布局 这里就不用多说&#xff0c;都是简单排版 <h1>一个爬坑的Coder</h1>html {height: 100%; }body {display: flex;justify-content: center;align-items: center;height: 100%; } h1 {font-size: 48px; }每个文字独立出来 每个文字都…

最全GIS开发编程语言汇总及分类

推荐查看>>>科研所需模型软件教程&#xff1a;水文水资源、大气科学、农林生态、地信遥感、统计分析、编程语言等... 最近总有很多人关心GIS开发语言的问题&#xff0c;这个确实很重要&#xff0c;毕竟学习一门编程语言需要花费不少时间和精力&#xff0c;找不到合适…

【数据库】并发控制理论

并发控制&#xff08;concurrency control&#xff09; 恢复&#xff08;recovery) 理论支持&#xff1a;基于事务的ACID Atomicity: All actions in the txn happend, or none happen. “All or nothing” Consistency: IF each txn is consistent and the DB starts consis…

【Google语音转文字】Speech to Text 超级好用的语音转文本API

前面有一篇博客说到了讯飞输入法&#xff0c;支持语音输入&#xff0c;也支持电脑内部音源输入&#xff0c;详细参考&#xff1a;【实时语音转文本】PC端实时语音转文本(麦克风外音&系统内部音源) 但是它只是作为一个工具来使用&#xff0c;如果我们想自己做一些好玩的东西…

CANoe-VN5000接口卡在Network-based模式下典型的应用场景

1、Network-based mode说明 CANoe软硬件都需要设置为Network-based mode 软件从CANoe12版本支持Network-based模式(CANoe12时称为Port-based mode,从13开始改为Network-based mode) 硬件从VN5000系列开始支持Network-based模式,VN5610A和VN5640设备需要确保切换到Network…

必读干货|使用Cmake管理C++项目简明教程

一、背景 Cmake是 kitware公司以及一些开源开发者在开发几个工具套件(VTK)的过程中衍生品&#xff0c;最终形成体系&#xff0c;成为一个独立的开源项目。其官方网站是 cmake.org&#xff0c;可以通过访问官方网站获得更多关于cmake的信息。 它是一个跨平台的编译(Build)工具…

【大数据存储技术】「#3」将数据从Hive导入到MySQL

文章目录准备工作安装Hive、MySQL和SqoopHive预操作启动MySQL、hadoop、hive创建临时表inner_user_log和inner_user_info使用Sqoop将数据从Hive导入MySQL启动hadoop集群、MySQL服务将前面生成的临时表数据从Hive导入到 MySQL 中查看MySQL中user_log或user_info表中的数据准备工…

网页爬虫的本质

1.网页结构分析 提取其中一部分核心介绍 &#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>Title</title> </head> <body> <div class"item masonry-brick&quo…

数据管理篇之计算管理

第13章 计算管理 目的&#xff1a;降低计算资源的消耗&#xff0c;提高任务执行的性能&#xff0c;提升任务产出的时间。 1.系统优化 HBO HBO &#xff08;History-Based Optimizer&#xff0c;基于历史的优化&#xff09;是根据任务历史执行情况为任务分配更合理的资源&…