Embedding与EmbeddingBag详解
●🍨 本文为🔗365天深度学习训练营 中的学习记录博客
●🍖 原作者:K同学啊 | 接辅导、项目定制
●🚀 文章来源:K同学的学习圈子
1、Embedding详解
Embedding是Pytorch中最基本的词嵌入操作,TensorFlow中也有相同的函数,功能是一样的。Embedding是将每个离散的词汇映射到一个低维的连续向量空间中,并且保持了词汇直接的语义关系。在Pytorch中,Embedding的输入是一个整数张量,每个整数都代表着一个词汇的索引,输出是一个浮点型的张量,每个浮点数都代表着对应词汇的词嵌入向量。
嵌入层使用随机权重初始化,并将学习数据集中的所有词嵌入。它是一个灵活的层,可以以各种方式使用,如:
- 作为深度学习模型的一部分,其中嵌入与模型本身一起被学习。
- 用于加载训练好的词嵌入模型
嵌入层被定义为网络的第一个隐藏层。
函数原型:
torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, _freeze=False, device=None, dtype=None)
常用参数:
num_embeddings: #词汇表大小,最大整数index + 1
embedding_dim: #词向量的维度
简单示例,用Embedding将两个句子转换为词嵌入向量
import torch
import torch.nn as nn
vocab_size = 12 #词汇表大小
embedding_dim = 4 #嵌入向量的维度
#创建一个embedding层
embedding = nn.Embedding(vocab_size, embedding_dim)
#假设我们有一个包含两个单词索引的输入序列
input_sequence1 = torch.tensor([1,5,8], dtype = torch.long)
input_sequence2 = torch.tensor([2,4], dtype = torch.long)
#使用Embedding层将输入序列转换为词嵌入
embedded_sequence1 = embedding(input_sequence1)
embedded_sequence2 = embedding(input_sequence2)
print(embedded_sequence1)
print(embedded_sequence2)
输出:
上例中,我们定义了简单的词嵌入模型
- 将大小为12 的词汇表中的每个词映射到了一个4维的向量空间中。
- 输入了两个句子,分别是[1, 5, 8]和[2, 4],每个数字代表着词汇表中的一个词汇的索引。
- 将这两个句子通过Embedding转换为词嵌入向量,并输出结果。
- 结果中,每个句子中的每个词汇都被映射成了4维的向量。
2. EmbeddingBag详解
EmbeddingBag是在Embedding基础上进一步优化的工具。主要优化点在于:它可以直接处理不定长的句子,并可计算句子中所有词汇的词嵌入向量的均值或总和。前者可以简化使用,后者则可即时评估向量生成效果。
在Pytorch中,EmbeddingBag的输入是一个整数张量和一个偏移量张量,每个整数都代表着一个词汇的索引,偏移量则表示句子中每个词汇的位置,输出是一个浮点型的张量。每个浮点数都代表这对应句子的词嵌入向量的均值或总和。
示例:用EmbeddingBag将两个句子转换为词嵌入向量并计算它们的均值。
import torch
import torch.nn as nn
vocab_size = 12 #词汇表大小
embedding_dim = 4 #嵌入向量维度
#创建一个EmbeddingBag层
embedding_bag = nn.EmbeddingBag(vocab_size, embedding_dim, mode = 'mean')
#假设我们有两个不同长度的输入序列
input_sequence1 = torch.tensor([1, 5, 8], dtype = torch.long)
input_sequence2 = torch.tensor([2, 4], dtype = torch.long)
#将两个输入序列拼接在一起,并创建一个偏移张量
input_sequences = torch.cat([input_sequence1, input_sequence2])
offsets = torch.tensor([0,len(input_sequence1)], dtype = torch.long)
#使用EmbeddingBag层计算序列汇总(这里使用平均值)
embedded_bag = embedding_bag(input_sequences, offsets)
print(embedded_bag)
输出:
在该示例中,我们的模型构建步骤如下:
- 定义一个大小为12的词汇表,并将每个词汇映射到一个4维的向量空间中。
- 输入两个句子,分别为[1, 5, 8]和[2, 4],每个数字代表词汇表中的一个词汇的索引。
- 通过EmbeddingBag将每个句子中的每个词汇转换为词嵌入向量,并计算它们的均值。
- 结果表明,每个句子的词嵌入向量的均值都是一个4维的向量。
EmbeddingBag层中的mode参数用于指定如何对每个序列中的嵌入向量进行汇总。常用模式主要有3种:sum、mean、max。模式选择主要取决于具体的任务和数据集。
- 文本分类任务中,通常使用mean模式,因为它可以捕捉到每个序列的平均嵌入,反映出序列的整体含义。
- 序列标注任务中,通常使用sum模式,因为它可以捕捉到每个序列的所有信息,不会丢失任何关键信息。
3. 任务补充
任务要求:
加载
one-hot篇附件中的.txt文件,并使用EmbeddingBag和Embedding完成词嵌入
代码:
# ------文本处理,将句子转换为整数序列 ------
import torch
import jieba
# 确认打开的文本文件路径和文件名
file_name = "D:\Personal Data\Learning Data\DL Learning Data\Embedding.txt"
# 从文件中读取文本行,并替代预定义的Sentences
with open(file_name, "r", encoding = "utf-8") as file:
Context = file.read()
Sentences = Context.split()
print("==== 文本分句: ====\n", Sentences) # 打印核对结果
# 使用jieba.cut()函数逐句进行分词,结果输出为一个列表
tokenized_texts = [list(jieba.lcut(sentence)) for sentence in Sentences]
print("==== 分词结果: ====\n", tokenized_texts) # 打印核对结果
# 构建词汇表
word_index = {}
index_word = {}
for i, word in enumerate(set([word for text in tokenized_texts for word in text])):
word_index[word] = i
index_word[i] = word
print("==== 词汇表: ====\n", word_index) # 打印核对结果
# 将文本转化为整数序列
sequences = [[word_index[word] for word in text] for text in tokenized_texts]
print("==== 文本序列: ====\n",sequences) # 打印核对结果
# 获取词汇表大小, 并+1
vocab_size = len(word_index) + 1
输出:
==== 文本分句: ====
['比较直观的编码方式是采用上面提到的字典序列。例如,对于一个有三个类别的问题,可以用1、2和3分别表示这三个类别。但是,这种编码方式存在一个问题,就是模型可能会错误地认为不同类别之间存在一些顺序或距离关系,而实际上这些关系可能是不存在的或者不具有实际意义的。', '为了避免这种问题,引入了one-hot编码(也称独热编码)。one-hot编码的基本思想是将每个类别映射到一个向量,其中只有一个元素的值为1,其余元素的值为0。这样,每个类别之间就是相互独立的,不存在顺序或距离关系。例如,对于三个类别的情况,可以使用如下的one-hot编码:']
Loading model cost 0.753 seconds.
Prefix dict has been built successfully.
==== 分词结果: ====
[['比较', '直观', '的', '编码方式', '是', '采用', '上面', '提到', '的', '字典', '序列', '。', '例如', ',', '对于', '一个', '有', '三个', '类别', '的', '问题', ',', '可以', '用', '1', '、', '2', '和', '3', '分别', '表示', '这', '三个', '类别', '。', '但是', ',', '这种', '编码方式', '存在', '一个', '问题', ',', '就是', '模型', '可能', '会', '错误', '地', '认为', '不同', '类别', '之间', '存在', '一些', '顺序', '或', '距离', '关系', ',', '而', '实际上', '这些', '关系', '可能', '是', '不', '存在', '的', '或者', '不', '具有', '实际意义', '的', '。'], ['为了', '避免', '这种', '问题', ',', '引入', '了', 'one', '-', 'hot', '编码', '(', '也', '称', '独热', '编码', ')', '。', 'one', '-', 'hot', '编码', '的', '基本', '思想', '是', '将', '每个', '类别', '映射', '到', '一个', '向量', ',', '其中', '只有', '一个', '元素', '的', '值', '为', '1', ',', '其余', '元素', '的', '值', '为', '0', '。', '这样', ',', '每个', '类别', '之间', '就是', '相互', '独立', '的', ',', '不', '存在', '顺序', '或', '距离', '关系', '。', '例如', ',', '对于', '三个', '类别', '的', '情况', ',', '可以', '使用', '如下', '的', 'one', '-', 'hot', '编码', ':']]
==== 词汇表: ====
{'具有': 0, '就是': 1, '元素': 2, '引入': 3, '如下': 4, '思想': 5, '但是': 6, '模型': 7, '0': 8, '、': 9, '对于': 10, '为了': 11, '独立': 12, '其余': 13, '(': 14, '提到': 15, '这样': 16, '三个': 17, '采用': 18, '其中': 19, '表示': 20, '使用': 21, '到': 22, '存在': 23, '和': 24, ',': 25, '一些': 26, '这种': 27, '有': 28, '向量': 29, '例如': 30, '字典': 31, '编码': 32, '或': 33, '会': 34, 'hot': 35, '映射': 36, '比较': 37, '3': 38, '可以': 39, '。': 40, '了': 41, '序列': 42, '将': 43, '情况': 44, '2': 45, '是': 46, '或者': 47, '上面': 48, '这': 49, '编码方式': 50, '用': 51, '避免': 52, '实际意义': 53, '直观': 54, ')': 55, '实际上': 56, '值': 57, '这些': 58, '-': 59, '分别': 60, '而': 61, '相互': 62, '之间': 63, '也': 64, '只有': 65, 'one': 66, '认为': 67, '1': 68, '距离': 69, '问题': 70, '一个': 71, '可能': 72, '独热': 73, '称': 74, '类别': 75, '的': 76, '顺序': 77, '基本': 78, '不同': 79, '关系': 80, ':': 81, '地': 82, '错误': 83, '每个': 84, '不': 85, '为': 86}
==== 文本序列: ====
[[37, 54, 76, 50, 46, 18, 48, 15, 76, 31, 42, 40, 30, 25, 10, 71, 28, 17, 75, 76, 70, 25, 39, 51, 68, 9, 45, 24, 38, 60, 20, 49, 17, 75, 40, 6, 25, 27, 50, 23, 71, 70, 25, 1, 7, 72, 34, 83, 82, 67, 79, 75, 63, 23, 26, 77, 33, 69, 80, 25, 61, 56, 58, 80, 72, 46, 85, 23, 76, 47, 85, 0, 53, 76, 40], [11, 52, 27, 70, 25, 3, 41, 66, 59, 35, 32, 14, 64, 74, 73, 32, 55, 40, 66, 59, 35, 32, 76, 78, 5, 46, 43, 84, 75, 36, 22, 71, 29, 25, 19, 65, 71, 2, 76, 57, 86, 68, 25, 13, 2, 76, 57, 86, 8, 40, 16, 25, 84, 75, 63, 1, 62, 12, 76, 25, 85, 23, 77, 33, 69, 80, 40, 30, 25, 10, 17, 75, 76, 44, 25, 39, 21, 4, 76, 66, 59, 35, 32, 81]]
使用EmbeddingBag进行词嵌入:
# 创建一个EmbeddingBag层
embedding_dim = 100 # 定义嵌入向量的维度
embedding_bag = torch.nn.EmbeddingBag(vocab_size, embedding_dim, mode = "mean")
# 将多个输入序列拼接在一起,并创建一个偏移量张量
# 首先需要创建空张量和空列表
input = torch.tensor([], dtype = torch.long)
offset = []
# 逐句处理,进行张量拼接、向列表中添加偏移量数值
for sequence in sequences:
offset.append(len(input))
input = torch.cat([input, torch.tensor(sequence, dtype= torch.long)])
# 将列表形式的偏移量转换为张量形式,用于embedding_bag()函数的输入
offset = torch.tensor(offset, dtype = torch.long)
# 检查序列张量拼接和索引生成结果
print("-"*80)
print(offset)
print(input)
print("-"*80)
# 使用Embedding层将输入序列转换为词嵌入
embedded_bag = embedding_bag(input, offset)
# 打印输出结果
print("词嵌入结果: \n", embedded_bag)
输出:
tensor([ 0, 75])
tensor([37, 54, 76, 50, 46, 18, 48, 15, 76, 31, 42, 40, 30, 25, 10, 71, 28, 17,
75, 76, 70, 25, 39, 51, 68, 9, 45, 24, 38, 60, 20, 49, 17, 75, 40, 6,
25, 27, 50, 23, 71, 70, 25, 1, 7, 72, 34, 83, 82, 67, 79, 75, 63, 23,
26, 77, 33, 69, 80, 25, 61, 56, 58, 80, 72, 46, 85, 23, 76, 47, 85, 0,
53, 76, 40, 11, 52, 27, 70, 25, 3, 41, 66, 59, 35, 32, 14, 64, 74, 73,
32, 55, 40, 66, 59, 35, 32, 76, 78, 5, 46, 43, 84, 75, 36, 22, 71, 29,
25, 19, 65, 71, 2, 76, 57, 86, 68, 25, 13, 2, 76, 57, 86, 8, 40, 16,
25, 84, 75, 63, 1, 62, 12, 76, 25, 85, 23, 77, 33, 69, 80, 40, 30, 25,
10, 17, 75, 76, 44, 25, 39, 21, 4, 76, 66, 59, 35, 32, 81])
--------------------------------------------------------------------------------
词嵌入结果:
tensor([[ 5.5692e-02, -7.1316e-02, -6.6372e-02, -1.6054e-02, -1.5890e-02,
9.3022e-03, -4.8811e-02, -4.7034e-02, -8.3843e-03, -1.9857e-02,
1.1352e-02, 1.4714e-01, 1.2782e-01, 1.2540e-01, 2.2769e-01,
-2.1690e-01, 1.2728e-02, -1.9718e-01, 4.8604e-02, -8.8129e-02,
1.4818e-01, -3.2952e-01, -3.2805e-02, -1.6356e-01, 1.1112e-01,
-1.0095e-01, -5.1578e-02, -1.1523e-01, 3.2936e-01, -3.6964e-01,
-8.3445e-02, -6.9567e-02, 1.1665e-01, 1.6558e-01, 2.6067e-01,
-1.4318e-01, 6.1249e-02, -3.4959e-02, -4.9525e-02, 5.4743e-02,
-3.6878e-02, 6.7813e-02, 7.3439e-02, -6.0280e-03, 7.6804e-02,
3.0789e-02, 1.6987e-01, -6.2407e-02, 8.7294e-02, 1.1892e-01,
-1.7377e-01, -2.3485e-04, 4.0972e-02, 1.6278e-02, 1.0198e-01,
-1.4946e-01, -2.4754e-01, -1.2399e-01, 4.3227e-02, 4.9916e-02,
-2.0984e-01, 1.5504e-01, -1.7622e-01, 1.1868e-01, 1.7071e-01,
-2.5039e-02, 1.0324e-01, -1.2662e-02, 2.5191e-01, 8.0460e-02,
7.6614e-02, -2.7530e-01, 6.2472e-02, 1.6579e-01, 8.8133e-02,
1.3551e-01, -2.7536e-01, 3.3397e-02, -1.5716e-01, -2.0973e-01,
-1.2795e-01, -2.1313e-01, -2.1758e-02, 1.2416e-01, 1.7992e-01,
-3.7501e-01, -4.4248e-02, 1.2105e-01, -3.8175e-02, 2.2732e-01,
-6.5679e-02, 8.8733e-02, -1.3111e-01, 5.4078e-02, 4.8781e-03,
2.1772e-01, 1.9874e-01, 2.0958e-03, -9.0118e-02, -2.5515e-01],
[ 1.3323e-01, 3.4331e-02, -8.2839e-02, 1.2361e-01, -2.8169e-01,
-4.4598e-02, -3.2946e-02, -1.4866e-01, 1.6180e-01, -2.7819e-01,
-6.4126e-02, -4.7323e-02, 1.3915e-02, 1.6032e-01, 1.3364e-01,
-7.4780e-02, -1.5139e-01, -3.3980e-01, 1.9030e-01, 1.9205e-02,
2.0883e-01, -2.5076e-01, -5.6575e-02, 6.3742e-02, 1.1623e-01,
4.9935e-03, 1.2527e-01, -4.2433e-01, 7.7243e-02, -1.4197e-01,
-6.0625e-02, -1.9400e-01, 2.1380e-02, 9.1898e-02, 2.0745e-01,
-3.2166e-01, 6.8514e-02, -1.3532e-02, -1.5256e-01, 5.6108e-02,
-1.3043e-01, 4.6693e-02, 4.1241e-02, 2.1686e-01, 2.1685e-01,
1.0330e-01, 1.9492e-01, 4.1955e-02, 1.1217e-01, 2.3970e-01,
-8.2919e-02, -2.8210e-01, 1.9147e-01, -4.3560e-02, -3.4606e-01,
6.1696e-02, -2.7092e-01, -2.3062e-01, 4.1380e-02, 1.0824e-01,
-2.8622e-01, 3.2143e-03, -1.1270e-01, 4.5128e-02, 1.0355e-01,
9.4424e-03, -1.0250e-02, -3.4805e-01, 2.8352e-01, 1.2028e-01,
-4.0606e-02, -6.6366e-02, 8.5532e-02, 3.8679e-01, 2.7721e-01,
2.1182e-01, -3.1888e-01, -5.2964e-02, 7.4347e-03, -4.4626e-02,
-5.2423e-02, -2.4242e-01, 3.4623e-02, -1.0850e-01, 6.3444e-02,
-1.3703e-01, 1.8642e-01, 1.6533e-01, -1.4030e-01, 1.9591e-01,
-3.6116e-02, 2.3281e-01, -8.3522e-03, 1.1427e-01, -4.5963e-03,
1.5077e-01, -2.7178e-02, -3.9486e-02, -1.0071e-01, -1.1353e-01]],
grad_fn=<EmbeddingBagBackward0>)