nn.Embedding()的原理:
定义一个Embedding:
embeddings = nn.Embedding(num_embeddings=10, embedding_dim=3)
vocab_size : 10
输出维度为: 3
假定输入inputs如下:
inputs = torch.tensor([
[1,3,6, 8],
[9,1,3,5]
],dtype=torch.long)
max_num 为 9
vocab_size: 10
下面讲述他的原理:
首先假定inputs.shape = (batch_size , sentence_len)
即(2,4)
现在我们用
first = F.one_hot(inputs,num_classes=10)
去做第一次one_hot的输出,即shape = (2,4,10)
10 代表 v
即shape = (2,4 , v)
embeddings的weight.shape = (10,3) =>>> (v,s)
那么怎么得到(2,4,s)呢?
torch.matmul(torch.tensor(first, dtype=torch.float),embeddings.weight)
即可得到embeddings(inputs)相同的结果!!!
下面为代码: