nn.embedding函数详解(pytorch)

news2024/10/6 20:25:23

提示:文章附有源码!!!

文章目录

  • 前言
  • 一、nn.embedding函数解释
  • 二、nn.embedding函数使用方法
  • 四、模型训练与预测的权重变化探讨


前言

最近发现prompt工程(如sam模型),也有transform的detr模型等都使用了nn.Embedding函数,对points、boxes或learn query进行编码或解码。因此,我想写一篇文章作为记录,本想简单对其 介绍,但写着写着就想把所有与它相关东西作为记录。本文章探讨了nn.Embedding参数、使用方法、模型训练与预测的变化,并附有列子源码作为支撑 ,呈现一个较为完善的理解内容。

一、nn.embedding函数解释

Embedding实际是一个索引表或查找表,它是符合随机初始化生成的正太分布的表,将输入向量化,其结构如下:

nn.Embedding(num_embeddings, embedding_dim)

第1个参数 num_embeddings 就是生成num_embeddings个嵌入向量。
第2个参数 embedding_dim 就是嵌入向量的维度,即用embedding_dim值的维数来表示一个基本单位。

当然,该函数还有很多其它参数,解释如下:

参数源码注释如下:

num_embeddings (int): size of the dictionary of embeddings
embedding_dim (int): the size of each embedding vector
padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;
                             therefore, the embedding vector at :attr:`padding_idx` is not updated during training,
                             i.e. it remains as a fixed "pad". For a newly constructed Embedding,
                             the embedding vector at :attr:`padding_idx` will default to all zeros,
                             but can be updated to another value to be used as the padding vector.
max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
                            is renormalized to have norm :attr:`max_norm`.
norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of
                                        the words in the mini-batch. Default ``False``.
sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
                         See Notes for more details regarding sparse gradients.

参数中文解释:

num_embeddings (python:int) – 词典的大小尺寸,比如总共出现5000个词,那就输入5000。此时index为(0-4999embedding_dim (python:int) – 嵌入向量的维度,即用多少维来表示一个符号。
padding_idx (python:int, optional) – 填充id,比如,输入长度为100,但是每次的句子长度并不一样,后面就需要用统一的数字填充,而这里就是指定这个数字,这样,网络在遇到填充id时,就不会计算其与其它符号的相关性。(初始化为0max_norm (python:float, optional) – 最大范数,如果嵌入向量的范数超过了这个界限,就要进行再归一化。
norm_type (python:float, optional) – 指定利用什么范数计算,并用于对比max_norm,默认为2范数。
scale_grad_by_freq (boolean, optional) – 根据单词在mini-batch中出现的频率,对梯度进行放缩。默认为False.
sparse (bool, optional) – 若为True,则与权重矩阵相关的梯度转变为稀疏张量

注:该函数服从正太分布,该函数可参与训练,我将在后面做解释。

二、nn.embedding函数使用方法

该函数实际是对词的编码,假如你有2句话,每句话有四个词,那么你想对每个词使用6个维度表达,其代码如下:

import torch.nn as nn
import torch
if __name__ == '__main__':
   embedding = nn.Embedding(100, 6)  # 我设置100个索引,每个使用6个维度表达。
   input = torch.LongTensor([[1, 2, 4, 5],
   [4, 3, 2, 3]])  # a batch of 2 samples of 4 indices each
   e = embedding(input)
   print('输出尺寸', e.shape)
   print('输出值:\n',e)
   weights=embedding.weight
   print('embed权重输出值:\n', weights[:6])
     

输出结果:
在这里插入图片描述

从图上可看出,输入编码是通过索引查找已编号embedding的权重,并将其赋值替换表达。换句话说,nn.Embedding(100, 6)生成正太分布100行6列数据,行必须超过输入句子词语长度,而句子每个词使用整数编码成索引,该索引对应之前embedding行寻找,得到对应行
维度,即可转为表达该词的特征向量。

四、模型训练与预测的权重变化探讨

之前已说过nn.Embedding()在训练过程中会发生变化,但在预测中将不在变化,应该是被训练成最佳词的向量维度表达,也就是说每个词唯一对应索引,被Embedding特征表达训练成最佳特征表达,也可说训练词索引特征表达固定。为探讨此过程,我写了对应示列,如下:

import torch
from torch.nn import Embedding

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.emb = Embedding(5, 3)
    def forward(self,vec):
        input = torch.tensor([0, 1, 2, 3, 4])
        emb_vec1 = self.emb(input)
        # print(emb_vec1)  ### 输出对同一组词汇的编码
        output = torch.einsum('ik, kj -> ij', emb_vec1, vec)
        return output
def simple_train():
    model = Model()
    vec = torch.randn((3, 1))
    label = torch.Tensor(5, 1).fill_(3)
    loss_fun = torch.nn.MSELoss()
    opt = torch.optim.SGD(model.parameters(), lr=0.015)
    print('初始化emebding参数权重:\n',model.emb.weight)
    for iter_num in range(100):
        output = model(vec)
        loss = loss_fun(output, label)
        opt.zero_grad()
        loss.backward(retain_graph=True)
        opt.step()
        # print('第{}次迭代emebding参数权重{}:\n'.format(iter_num, model.emb.weight))
    print('训练后emebding参数权重:\n',model.emb.weight)
    torch.save(model.state_dict(),'./embeding.pth')
    return model

def simple_test():
   model = Model()
   ckpt = torch.load('./embeding.pth')
   model.load_state_dict(ckpt)

   model=model.eval()
   vec = torch.randn((3, 1))

   print('加载emebding参数权重:\n', model.emb.weight)
   for iter_num in range(100):
      output = model(vec)
   print('n次预测后emebding参数权重:\n', model.emb.weight)

if __name__ == '__main__':
   simple_train()  # 训练与保存权重
   simple_test()

结果如下:

在这里插入图片描述
训练代码参考博客:点击这里

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1178526.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

时序预测 | MATLAB实现基于SVM-Adaboost支持向量机结合AdaBoost时间序列预测

时序预测 | MATLAB实现基于SVM-Adaboost支持向量机结合AdaBoost时间序列预测 目录 时序预测 | MATLAB实现基于SVM-Adaboost支持向量机结合AdaBoost时间序列预测预测效果基本介绍模型描述程序设计参考资料 预测效果 基本介绍 1.Matlab实现SVM-Adaboost时间序列预测(风…

解决ERR: cURL error 77: Unable to initialize NSS: -8023

研发反映一个问题,上传文件时失败,日志内错误信息如下: ERR: cURL error 77: Unable to initialize NSS: -8023 (SEC_ERROR_PKCS11_DEVICE_ERROR) ... 这个功能使用了腾讯云的点播服务。因此立即联系了腾讯云客服。 搞了很久问题依旧。 反复测试,发现上传视频文件,错…

03【保姆级】-GO语言变量和数据类型和相互转换

03【保姆级】-GO语言变量和数据类型 一、变量1.1 变量的定义:1.2 变量的声明、初始化、赋值1.3 变量使用的注意事项 插播-关于fmt.Printf格式打印%的作用二、 变量的数据类型2.1整数的基本类型2.1.1 有符号类型 int8/16/32/642.1.2 无符号类型 int8/16/32/642.1.3 整…

涉及法律诉讼和负债670万美元的【工务园】申请纳斯达克IPO上市

来源:猛兽财经 作者:猛兽财经 猛兽财经获悉,总部位于东莞的人力资源SaaS平台Baiya International Group Inc(简称:工务园)近期已向美国证券交易委员会(SEC)提交招股书,申…

【pytorch源码分析--torch执行流程与编译原理】

背景 解读torch源码方便算子开发方便后续做torch 模型性能开发 基本介绍 代码库 https://github.com/pytorch/pytorch 模块介绍 aten: A Tensor Library的缩写。与Tensor相关的内容都放在这个目录下。如Tensor的定义、存储、Tensor间的操作(即算子/OP&#xff…

StripedFly恶意软件:悄无声息运行5年,感染百万设备

导语:最近,俄罗斯网络安全公司Kaspersky发布的一项调查显示,一种名为StripedFly的高级恶意软件伪装成加密货币挖矿程序,悄无声息地在全球范围内运行了超过5年,感染了100万台设备。这是一种复杂的模块化框架&#xff0c…

【Unity实战】最全面的库存系统(二)(附源码)

文章目录 先来看看最终效果前言箱子库存箱子存储物品玩家背包快捷栏满了,物品自动加入背包修复开着背包拾取物品不会刷新显示的问题将箱子库存和背包分开,可以同时打开源码完结先来看看最终效果 前言 本期紧跟着上期,继续来完善我们的库存系统,实现箱子库存和人物背包 箱…

RocketMq简介及安装、docker安装rocketmq、安装rocketmq可视化管理端

前言 本文主要简单介绍rocketmq及使用docker安装rocketmq的方法。 rocketmq简介 rocketmq有两部分,nameserver和broker,nameserver用来维护broker的地址、向生产者、消费者推送broker的最新地址;broker用来存储、转发消息;也就…

Java根据一个List内Object的两个字段去重

背景 在Java开发过程中,我们经常会遇到需要对List进行去重的需求。 其中常见的情况是,将数组去重,或者将对象依据某个字段去重。这两种方式均可用set属性进行处理。 今天讨论,有一个List,且其中的元素是自定义的对象&…

Linux学习第34天:Linux LCD 驱动实验(一):星星之火可以燎原

Linux版本号4.1.15 芯片I.MX6ULL 大叔学Linux 品人间百味 思文短情长 LCD显示屏是由一个一个的像素点构成的。当你能控制一个像素点的亮暗及颜色变化的时候,你就能让LCD显示瓶显示五颜六色的整幅图案。甚至可以让LCD屏幕…

uboot启动linux kernel的流程

目录 前言流程图autoboot_commandrun_command_listdo_bootmdo_bootm_statesdo_bootm_linuxboot_prep_linuxboot_jump_linux 前言 本文在u-boot启动流程分析这篇文章的基础上,简要梳理uboot启动linux kernel的流程。 流程图 其中, autoboot_command位于…

合成数据的被需要的5 个重要原因

若要训练机器学习模型,需要数据。数据科学任务通常不是 Kaggle 竞赛,在竞赛中,你有一个很好的大型策划数据集,并预先标记。有时,您必须收集、组织和清理自己的数据。在现实世界中收集和标记数据的过程可能非常耗时、繁…

手拿5份offer,最高18k! 95后艺术生转行后台网优,这个火花有点大!

当艺术生碰上理工科,会有怎样的火花?在大众的刻板认知里,艺术和理工科就像两条很少重合的平行线,双方从业者在自己的行业下按部就班,规划未来。 来自东北长春的W同学却打破了常人的认知,身为美术老师的他却…

沿面闪络放电测量装置中的真空度精密控制解决方案

摘要:针对现有低气压环境下沿面闪络测试中存在真空度无法精确控制所带来的一系列问题,特别是针对用户提出的对现有沿面闪络试验装置的真空控制系统进行技术改造要求,本文提出了相应的技改方案,技改方案采用基于动态平衡法的电动针…

民生银行与CRM系统的无代码开发集成,助力用户运营

连接民生银行与CRM系统的无代码开发集成 中国民生银行股份有限公司,成立于1996年,是一家全国性股份制商业银行。民生银行拥有强大的技术实力和丰富的业务经验,通过与各类企业进行深度合作,帮助企业实现财务管理和客服系统的优化运…

BUUCTF easycap 1

BUUCTF:https://buuoj.cn/challenges 题目描述: 下载附件,解压得到一个.pcap文件。 密文: 解题思路: 1、这道题和它的名字一样,真的很easy。双击easycap.pcap文件,打开Wireshark。在Wireshark中&#xf…

【软件工程】程序流程图之绘图工具和教程推荐

2023年11月6日,周一晚上 目录 绘图工具推荐教程推荐 绘图工具推荐 我推荐使用开源免费的draw.io要绘制程序流程图 draw.io网页版地址:Flowchart Maker & Online Diagram Software draw.io桌面版下载地址:GitHub - jgraph/drawio-desk…

MySQL的备份恢复

数据备份的重要性 1.生产环境中,数据的安全至关重要 任何数据的丢失都会导致非常严重的后果。 2.数据为什么会丢失 :程序操作,运算错误,磁盘故障,不可预期的事件(地震,海啸)&#x…

使用cpolar配合Plex搭建私人媒体站并实现远程访问

文章目录 1.前言2. Plex网站搭建2.1 Plex下载和安装2.2 Plex网页测试2.3 cpolar的安装和注册 3. 本地网页发布3.1 Cpolar云端设置3.2 Cpolar本地设置 4. 公网访问测试5. 结语 1.前言 用手机或者平板电脑看视频,已经算是生活中稀松平常的场景了,特别是各…

Nginx默认会自动忽略请求头Headers里带下划线_的参数

起因:该接口设置了必须要传送app_code和app_secret才能正常访问。实际我在本地环境测试中,发现该接口是正常访问的,但是部署到正式系统之后发现,该接口一直提示app_code和app_secret不能为空。 后续排查:发现正式系统…