PyTorch嵌入层(nn.Embedding)

news2025/4/4 11:49:00

在 PyTorch 中,nn.Embedding 层(即 model.user_embedding)除了 .weight 这个核心属性外,还有其他属性和方法。以下是完整的解析:


1. 主要属性

(1) weight(核心参数)
  • 作用:存储所有嵌入向量的可训练权重矩阵。
  • 形状(num_embeddings, embedding_dim)
  • 示例
    print(model.user_embedding.weight.shape)  # 输出:torch.Size([3, 4])
    
(2) num_embeddings
  • 作用:返回嵌入向量的总数(即用户/物品的数量)。
  • 示例
    print(model.user_embedding.num_embeddings)  # 输出:3
    
(3) embedding_dim
  • 作用:返回每个嵌入向量的维度。
  • 示例
    print(model.user_embedding.embedding_dim)  # 输出:4
    
(4) padding_idx(可选)
  • 作用:如果设置了 padding_idx,则对应的嵌入向量会被强制设为 0 且不参与训练。
  • 示例
    # 初始化时设置 padding_idx=0
    self.user_embedding = nn.Embedding(3, 4, padding_idx=0)
    print(model.user_embedding.padding_idx)  # 输出:0
    print(model.user_embedding.weight[0])    # 输出:tensor([0., 0., 0., 0.], grad_fn=<SelectBackward>)
    

2. 主要方法

(1) forward(input)
  • 作用:根据输入的 ID 返回对应的嵌入向量。
  • 示例
    input_ids = torch.tensor([0, 1, 2])  # 查询用户 0、1、2 的向量
    embeddings = model.user_embedding(input_ids)  # 返回 shape (3, 4)
    
(2) reset_parameters()
  • 作用:重新随机初始化权重(通常在训练前调用)。
  • 内部逻辑:默认使用均匀分布 U ( − k , k ) U(-\sqrt{k}, \sqrt{k}) U(k ,k ),其中 k = 1 embedding_dim k = \frac{1}{\text{embedding\_dim}} k=embedding_dim1
  • 示例
    model.user_embedding.reset_parameters()
    
(3) extra_repr()
  • 作用:返回层的额外信息(用于 print 时显示)。
  • 示例
    print(model.user_embedding.extra_repr())  
    # 输出:'num_embeddings=3, embedding_dim=4'
    

3. 其他底层属性(一般无需直接操作)

  • _parameters:存储所有可训练参数(包括 weight)。
  • _buffers:存储非可训练参数(如 BatchNorm 的 running_mean)。
  • training:布尔值,表示是否处于训练模式。

4. 完整属性/方法列表

可以通过 dir() 查看所有属性和方法:

print(dir(model.user_embedding))

输出示例:

['__class__', '__delattr__', '__dir__', ..., 'weight', 'num_embeddings', 'embedding_dim', 'padding_idx', 'forward', 'reset_parameters']

5. 关键总结

属性/方法用途示例值/调用方式
.weight核心权重矩阵shape=(3, 4)
.num_embeddings嵌入向量的总数(用户数)3
.embedding_dim每个向量的维度4
.padding_idx指定填充索引(可选)None0
.forward(input)查询嵌入向量model.user_embedding([0, 1])
.reset_parameters()重新初始化权重model.user_embedding.reset_parameters()

6. 常见问题

Q:如何修改嵌入向量?
  • 直接操作 .weight
    # 将用户 0 的向量置零
    model.user_embedding.weight.data[0] = torch.zeros(4)
    
Q:如何冻结嵌入层?
  • 禁用梯度:
    model.user_embedding.weight.requires_grad = False
    
Q:padding_idx 和普通索引有什么区别?
  • padding_idx 对应的向量会固定为 0,且不参与梯度更新。

掌握这些属性和方法后,你可以更灵活地操作嵌入层! 🚀

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2327885.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

《AI大模型应知应会100篇》加餐篇:LlamaIndex 与 LangChain 的无缝集成

加餐篇&#xff1a;LlamaIndex 与 LangChain 的无缝集成 问题背景&#xff1a;在实际应用中&#xff0c;开发者常常需要结合多个框架的优势。例如&#xff0c;使用 LangChain 管理复杂的业务逻辑链&#xff0c;同时利用 LlamaIndex 的高效索引和检索能力构建知识库。本文在基于…

元素三大等待

硬性等待&#xff08;强制等待&#xff09; 线程休眠&#xff0c;强制等待 Thread.sleep(long millis);这是最简单的等待方式&#xff0c;使用time.sleep()方法来实现。在代码中强制等待一定的时间&#xff0c;不论元素是否已经加载完成&#xff0c;都会等待指定的时间后才继…

【DY】信息化集成化信号采集与处理系统;生物信号采集处理系统一体机

MD3000-C信息化一体机生物信号采集处理系统 实验平台技术指标 01、整机外形尺寸&#xff1a;1680mm(L)*750mm(w)*2260mm(H)&#xff1b; 02、实验台操作面积&#xff1a;750(w)*1340(L&#xff09;&#xff08;长*宽&#xff09;&#xff1b; 03、实验台面离地高度&#xf…

康谋分享 | 仿真驱动、数据自造:巧用合成数据重构智能座舱

随着汽车向智能化、场景化加速演进&#xff0c;智能座舱已成为人车交互的核心承载。从驾驶员注意力监测到儿童遗留检测&#xff0c;从乘员识别到安全带状态判断&#xff0c;座舱内的每一次行为都蕴含着巨大的安全与体验价值。 然而&#xff0c;这些感知系统要在多样驾驶行为、…

Vue 数据传递流程图指南

今天&#xff0c;我们探讨一下 Vue 中的组件传值问题。这不仅是我们在日常开发中经常遇到的核心问题&#xff0c;也是面试过程中经常被问到的重要知识点。无论你是初学者还是有一定经验的开发者&#xff0c;掌握这些传值方式都将帮助你更高效地构建和维护 Vue 应用 目录 1. 父…

【C语言】strstr查找字符串函数

一、函数介绍 strstr 是 C 语言标准库 <string.h> 中的字符串查找函数&#xff0c;用于在主字符串中查找子字符串的首次出现位置。若找到子串&#xff0c;返回其首次出现的地址&#xff1b;否则返回 NULL。它是处理字符串匹配问题的核心工具之一。 二、函数原型 char …

机器学习、深度学习和神经网络

机器学习、深度学习和神经网络 术语及相关概念 在深入了解人工智能&#xff08;AI&#xff09;的工作原理以及它的各种应用之前&#xff0c;让我们先区分一下与AI密切相关的一些术语和概念&#xff1a;人工智能、机器学习、深度学习和神经网络。这些术语有时会被交替使用&#…

数字孪生在智慧城市中的前端呈现与 UI 设计思路

一、数字孪生技术在智慧城市中的应用与前端呈现 数字孪生技术通过创建城市的虚拟副本&#xff0c;实现了对城市运行状态的实时监控、分析与预测。在智慧城市中&#xff0c;数字孪生技术的应用包括交通流量监测、环境质量分析、基础设施管理等。其前端呈现主要依赖于Web3D技术、…

Android OpenGLES 360全景图片渲染(球体内部)

概述 360度全景图是一种虚拟现实技术&#xff0c;它通过对现实场景进行多角度拍摄后&#xff0c;利用计算机软件将这些照片拼接成一个完整的全景图像。这种技术能够让观看者在虚拟环境中以交互的方式查看整个周围环境&#xff0c;就好像他们真的站在那个位置一样。在Android设备…

LETTERS(DFS)

【题目描述】 给出一个rowcolrowcol的大写字母矩阵&#xff0c;一开始的位置为左上角&#xff0c;你可以向上下左右四个方向移动&#xff0c;并且不能移向曾经经过的字母。问最多可以经过几个字母。 【输入】 第一行&#xff0c;输入字母矩阵行数RR和列数SS&#xff0c;1≤R,S≤…

NVM 多版本Node.js 管理全指南(Windows系统)

&#x1f9d1; 博主简介&#xff1a;CSDN博客专家、全栈领域优质创作者、高级开发工程师、高级信息系统项目管理师、系统架构师&#xff0c;数学与应用数学专业&#xff0c;10年以上多种混合语言开发经验&#xff0c;从事DICOM医学影像开发领域多年&#xff0c;熟悉DICOM协议及…

C,C++语言缓冲区溢出的产生和预防

缓冲区溢出的定义 缓冲区是内存中用于存储数据的一块连续区域&#xff0c;在 C 和 C 里&#xff0c;常使用数组、指针等方式来操作缓冲区。而缓冲区溢出指的是当程序向缓冲区写入的数据量超出了该缓冲区本身能够容纳的最大数据量时&#xff0c;额外的数据就会覆盖相邻的内存区…

《Linux内存管理:实验驱动的深度探索》【附录】【实验环境搭建 2】【vscode搭建调试内核环境】

1. 如何调试我们的内核 1. GDB调试 安装gdb sudo apt-get install gdb-multiarchgdb-multiarch是多架构版本&#xff0c;可以通过set architecture aarch64指定架构 QEMU参数修改添加-s -S #!/usr/bin/shqemu-7.2.0-rc1/build/aarch64-softmmu/qemu-system-aarch64 \-nogr…

Flutter项目之登录注册功能实现

目录&#xff1a; 1、页面效果2、登录两种状态界面3、中间按钮部分4、广告区域5、最新资讯6、登录注册页联调6.1、网络请求工具类6.2、注册页联调6.3、登录问题分析6.4、本地缓存6.5、共享token6.6、登录页联调6.7、退出登录 1、页面效果 import package:flutter/material.dart…

ctfshow VIP题目限免 源码泄露

根据题目提示是源代码泄露&#xff0c;右键查看页面源代码发现了 flag

移动神器RAX3000M路由器变身家庭云之七:增加打印服务,电脑手机无线打印

系列文章目录&#xff1a; 移动神器RAX3000M路由器变身家庭云之一&#xff1a;开通SSH&#xff0c;安装新软件包 移动神器RAX3000M路由器变身家庭云之二&#xff1a;安装vsftpd 移动神器RAX3000M路由器变身家庭云之三&#xff1a;外网访问家庭云 移动神器RAX3000M路由器不刷固…

《函数基础与内存机制深度剖析:从 return 语句到各类经典编程题详解》

一、问答题 &#xff08;1&#xff09;使用函数的好处是什么&#xff1f; 1.提升代码的复用性 2.提升代码的可维护性 3.增强代码的可读性 4.提高代码的灵活性 5.方便进行单元测试 &#xff08;2&#xff09;如何定义一个函数&#xff1f;如何调用一个函数&#xff1f; 在Pytho…

Python | 使用Matplotlib绘制Swarm Plot(蜂群图)

Swarm Plot&#xff08;蜂群图&#xff09;是一种数据可视化图表&#xff0c;它用于展示分类数据的分布情况。这种图表通过将数据点沿着一个或多个分类变量轻微地分散&#xff0c;以避免它们之间的重叠&#xff0c;从而更好地显示数据的分布密度和分布趋势。Swarm Plot特别适用…

新版本Xmind结合DeepSeek快速生成美丽的思维导图

前言 我的上一篇博客&#xff08;https://quickrubber.blog.csdn.net/article/details/146518898&#xff09;中讲到采用Python编程可以实现和Xmind的互动&#xff0c;并让DeepSeek来生成相应的代码从而实现对内容的任意修改。但是&#xff0c;那篇博客中提到的Xmind有版本的限…

set和map封装

目录 set和map区别 set和map的插入 set和map的实现 修改红黑树的模板参数 修改比较时使用的变量 迭代器的实现 迭代器的定义 *解引用重载 ->成员访问重载 自增重载 重载 封装迭代器 RBTree迭代器封装 封装set迭代器 对set迭代器进行修改 封装map迭代器 修改…