在 PyTorch 中,nn.Embedding
层(即 model.user_embedding
)除了 .weight
这个核心属性外,还有其他属性和方法。以下是完整的解析:
1. 主要属性
(1) weight
(核心参数)
- 作用:存储所有嵌入向量的可训练权重矩阵。
- 形状:
(num_embeddings, embedding_dim)
。 - 示例:
print(model.user_embedding.weight.shape) # 输出:torch.Size([3, 4])
(2) num_embeddings
- 作用:返回嵌入向量的总数(即用户/物品的数量)。
- 示例:
print(model.user_embedding.num_embeddings) # 输出:3
(3) embedding_dim
- 作用:返回每个嵌入向量的维度。
- 示例:
print(model.user_embedding.embedding_dim) # 输出:4
(4) padding_idx
(可选)
- 作用:如果设置了
padding_idx
,则对应的嵌入向量会被强制设为 0 且不参与训练。 - 示例:
# 初始化时设置 padding_idx=0 self.user_embedding = nn.Embedding(3, 4, padding_idx=0) print(model.user_embedding.padding_idx) # 输出:0 print(model.user_embedding.weight[0]) # 输出:tensor([0., 0., 0., 0.], grad_fn=<SelectBackward>)
2. 主要方法
(1) forward(input)
- 作用:根据输入的 ID 返回对应的嵌入向量。
- 示例:
input_ids = torch.tensor([0, 1, 2]) # 查询用户 0、1、2 的向量 embeddings = model.user_embedding(input_ids) # 返回 shape (3, 4)
(2) reset_parameters()
- 作用:重新随机初始化权重(通常在训练前调用)。
- 内部逻辑:默认使用均匀分布 U ( − k , k ) U(-\sqrt{k}, \sqrt{k}) U(−k,k),其中 k = 1 embedding_dim k = \frac{1}{\text{embedding\_dim}} k=embedding_dim1。
- 示例:
model.user_embedding.reset_parameters()
(3) extra_repr()
- 作用:返回层的额外信息(用于
print
时显示)。 - 示例:
print(model.user_embedding.extra_repr()) # 输出:'num_embeddings=3, embedding_dim=4'
3. 其他底层属性(一般无需直接操作)
_parameters
:存储所有可训练参数(包括weight
)。_buffers
:存储非可训练参数(如 BatchNorm 的 running_mean)。training
:布尔值,表示是否处于训练模式。
4. 完整属性/方法列表
可以通过 dir()
查看所有属性和方法:
print(dir(model.user_embedding))
输出示例:
['__class__', '__delattr__', '__dir__', ..., 'weight', 'num_embeddings', 'embedding_dim', 'padding_idx', 'forward', 'reset_parameters']
5. 关键总结
属性/方法 | 用途 | 示例值/调用方式 |
---|---|---|
.weight | 核心权重矩阵 | shape=(3, 4) |
.num_embeddings | 嵌入向量的总数(用户数) | 3 |
.embedding_dim | 每个向量的维度 | 4 |
.padding_idx | 指定填充索引(可选) | None 或 0 |
.forward(input) | 查询嵌入向量 | model.user_embedding([0, 1]) |
.reset_parameters() | 重新初始化权重 | model.user_embedding.reset_parameters() |
6. 常见问题
Q:如何修改嵌入向量?
- 直接操作
.weight
:# 将用户 0 的向量置零 model.user_embedding.weight.data[0] = torch.zeros(4)
Q:如何冻结嵌入层?
- 禁用梯度:
model.user_embedding.weight.requires_grad = False
Q:padding_idx
和普通索引有什么区别?
padding_idx
对应的向量会固定为 0,且不参与梯度更新。
掌握这些属性和方法后,你可以更灵活地操作嵌入层! 🚀