文章目录
-
- 1. nn.Embedding的简单介绍
-
- 1.1 基本用法
- 1.2 示例代码
- 1.3 注意事项
- 2. 通俗的理解num_embeddings和embedding_dim
-
- 2.1 num_embeddings
- 2.2 embedding_dim
- 2.3 使用场景举例
- 结合示例
1. nn.Embedding的简单介绍
nn.Embedding
是 PyTorch 中的一个模块,用于创建一个嵌入层
。这个层的主要作用是将离散的数值(通常是代表单词的索引)映射到一个连续的
、固定大小
的向量空间,这些向量也称为嵌入向量
。在自然语言处理和其他类型的机器学习任务中,嵌入层是非常常用的,它可以帮助模型理解和处理类别型输入数据。
1.1 基本用法
在 PyTorch 中,nn.Embedding
需要两个主要的参数:num_embeddings
和 embedding_dim
:
num_embeddings
(整数): 嵌入层中的嵌入向量的数量,通常是词汇表的大小。embedding_dim
(整数): 每个嵌入向量的维度。
1.2 示例代码
以下是一个简单的使用 nn.Embedding
的例子:
import torch
impo