word2vector训练代码详解

news2024/9/27 5:03:30

目录

1.代码实现

2.知识点 


 

1.代码实现

#导包
import math
import torch
from torch import nn
import dltools
#加载PTB数据集  ,需要把PTB数据集的文件夹放在代码上一级目录的data文件中,不用解压
#批次大小、窗口大小、噪声词大小
batch_size, max_window_size, num_noise_words = 512, 5, 5  
#获取数据集迭代器、词汇表
data_iter, vocab = dltools.load_data_ptb(batch_size, max_window_size, num_noise_words)
#讲解嵌入层embedding的用法(此行代码无用)

#嵌入层
#通过嵌入层来获取skip—gram的中心词向量和上下文词向量
embed = nn.Embedding(num_embeddings=20, embedding_dim=4)  
# num_embeddings就是词表大小
# X的shape=(batch_size, num_steps)
# --one_hot编码--->(batch_size, num_steps, num_embedding(vocab_size))
# --点乘中心词矩阵-->(batch_size, num_steps, embed_size)
embed.weight.shape   #讲解嵌入层embedding的用法(此行代码无用)
torch.Size([20, 4])

embedding层先one_hot编码,再进行与embedding层的矩阵(num_embeddings,embedding_dim)乘法 

#构造skip_gram的前向传播
def skip_gram(center, contexts_and_negatives, embed_v, embed_u):
    """
    embed_v:表示对中心词进行embedding层
    embed_u:对上下文词进行embedding层 
    """
    v = embed_v(center)                 #中心词的词向量表达
    u = embed_u(contexts_and_negatives) #上下文词的词向量表达
    #用中心词来预测上下文词
    #u_shape = (batch_size, num_steps, embed_size)---->(batch_size, embed_size, num_steps)进行矩阵乘法
    pred = torch.bmm(v, u.permute(0, 2, 1))  #矩阵乘法(bmm三维乘法),不用管batch_size维度
    return pred
#假设数据
skip_gram(torch.ones((2, 1), dtype=torch.long), torch.ones((2, 4), dtype=torch.long), embed, embed)
tensor([[[3.1980, 3.1980, 3.1980, 3.1980]],

        [[3.1980, 3.1980, 3.1980, 3.1980]]], grad_fn=<BmmBackward0>)
#假设数据
skip_gram(torch.ones((2, 1), dtype=torch.long), torch.ones((2, 4), dtype=torch.long), embed, embed).shape

 torch.Size([2, 1, 4])

#带掩码的二元交叉熵损失
class SigmoidBCELoss(nn.Module):
    def __init__(self):
        super().__init__()  #直接继承父类的初始化属性和方法
    
    def forward(self, inputs, target, mask=None):
        #nn.functional.binary_cross_entropy_with_logits表示返回的不是转化后的概率,是原始计算的数据结果
        #weight=mask权重将掩码带上
        #reduction='none'表示不将计算结果聚合,算损失时(默认聚合)
        out = nn.functional.binary_cross_entropy_with_logits(inputs, target, weight=mask, reduction='none')
        return out.mean(dim=1)  #计算结果是二维的,在索引1维度上聚合求平均
loss = SigmoidBCELoss()
[[1.1, -2.2, 3.3, -4.4]] * 2
[[1.1, -2.2, 3.3, -4.4], [1.1, -2.2, 3.3, -4.4]]
torch.tensor([[1.1, -2.2, 3.3, -4.4]] * 2).shape

 torch.Size([2, 4])

#假设数据测试
pred = torch.tensor([[1.1, -2.2, 3.3, -4.4]] * 2)
label = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
#mask每一行都有4个数值,所以* mask.shape[1]=4
#但是mask中的数值0表示权重,是补充步长的,不重要,需要计算有效序列的损失平均值,所以 / mask.sum(axis=1)
loss(pred, label, mask) * mask.shape[1] / mask.sum(axis=1)

 tensor([0.9352, 1.8462])

#初始化模型参数,定义两个嵌入层
#一开始,embed_weights会标准正态分布的数据初始化
#两个embedding层的参数不一样,不能重复使用,需要初始化定义两个
embed_size = 100
net = nn.Sequential(nn.Embedding(num_embeddings=len(vocab), embedding_dim=embed_size),
                    nn.Embedding(num_embeddings=len(vocab), embedding_dim=embed_size))

 

#定义训练过程
def train(net, data_iter, lr, num_epochs, device=dltools.try_gpu()):
    #修改embedding层的初始化方法,使用nn.init.xavier_uniform_初始化embed.weight权重,在NLP中不使用标准正态分布的额数据初始化权重
    def init_weights(m):
        if type(m) == nn.Embedding:
            nn.init.xavier_uniform_(m.weight)

    net.apply(init_weights)  
    net = net.to(device)
    #设置梯度下降的优化器
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    #设置绘制可视化的动图(epoch——loss)
    animator = dltools.Animator(xlabel='epoch', ylabel='loss', xlim=[1, num_epochs])
    
    #设置累加
    metric = dltools.Accumulator(2)   #2种数据需要累加
    for epoch in range(num_epochs):  #遍历训练次数
        #设置计时器, 赋值批次数量
        timer, num_batches = dltools.Timer(), len(data_iter)    #data_iter是分好批次的数据集,长度就是批次数量num_batches
        for i, batch in enumerate(data_iter):   #i是索引, batch是取出的一批批数据
            #梯度清零
            optimizer.zero_grad()
            #接收中心词, 上下文词_噪声词, 掩码, 标记目标值 
            center, context_negative, mask, label = [data.to(device) for data in batch]
            #调用skip_gram模型预测
            pred = skip_gram(center, context_negative, embed_v=net[0], embed_u=net[1])
            #计算损失
            l = loss(pred.reshape(label.shape).float(), label.float(), mask) / mask.shape[1] * mask.sum(dim=1)
            #用loss反向传播  ,loss先sum()聚合变成标量(合并成一个数值), 只有标量才能反向传播
            l.sum().backward()
            #梯度更新
            optimizer.step()
            #累加
            metric.add(l.sum(), l.numel())   #l.sum()数值求和累加, l.numel()数量累加
            #   %  取余数      
            #  //  商向下取整
            #迭代到总数据量的5%的倍数时 或者 处理到最后一批数据时,执行下面操作
            #  i+1是因为i是从0开始遍历的
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:  
                #epoch + (i+1) / num_batches当前迭代次数占整个数据集的比例
                animator.add(epoch + (i+1) / num_batches, (metric[0] / metric[1]))
    print(f'loss {metric[0] / metric[1]:.3f}', f'{metric[1] / timer.stop():.1f} tokens/sec on {str(device)}')      
lr, num_epochs = 0.002, 50
train(net, data_iter, lr, num_epochs)

#如果能够找到词的近义词, 就说明训练的不错
def get_similar_tokens(query_token, k, embed):
    """
    query_token:需要预测的词
    k:最高相似度的词数量
    embed:embedding层的哪一层
    """
    #获取词向量权重    (词向量权重*词的one_hot编码,就是词向量)
    W = embed.weight.data
    print(f'W的shape:{W.shape}')
    x = W[vocab[query_token]]     #embedding层是按照索引查表查词对应的权重-->优点
    print(f'x的shape:{x.shape}')
    #计算余弦相似度
    #torch.mv两个向量的点乘
    cos = torch.mv(W, x) / torch.sqrt(torch.sum(W * W, dim=1) * torch.sum(x * x) + 1e-9)
    print(f'cos的shape:{cos.shape}')
    #排序选择前k个对应的索引
    topk = torch.topk(cos, k=k+1)[1].cpu().numpy().astype('int32')
    for i in topk[1:]:   #排除query_token他本身,自己与自己余弦相似度最高
        print(f'cosine sim={float(cos[i]):.3f}:{vocab.to_tokens(i)}')
        
get_similar_tokens('food', 3, net[0])

 

W的shape:torch.Size([6719, 100])
x的shape:torch.Size([100])
cos的shape:torch.Size([6719])
cosine sim=0.430:feed
cosine sim=0.418:precious
cosine sim=0.412:drink

2.知识点 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2168880.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

堆的数组实现

目录 一、堆 二叉树的顺序结构 堆的概念及结构 1.概念 2.堆的分类 (1)大堆 (2)小堆 二、利用数组(顺序结构)实现堆的过程 1.利用数组实现堆的思路 2.堆是用数组实现的&#xff0c;在数组中通过双亲找自己左右孩子、通过左右孩子找自己双亲的思路 2.1.思路 2.2.孩子与…

认知杂谈84《菜鸟的自我修炼:知易行难与行难知易》

内容摘要&#xff1a; 理解与行动之间的差距是日常生活的常见挑战。"知易行难"体现在理解简单但执行困难&#xff0c;例如知道蔬菜有益但难以坚持食用。而"行难知易"则是开始时困难但后来容易的任务&#xff0c;如学习骑自行车。 这种差异源于心理惰性和习…

使用 Llama-index 实现的 Agentic RAG-Router Query Engine

前言 你是否也厌倦了我在博文中经常提到的老式 RAG(Retrieval Augmented Generation | 检索增强生成) 系统&#xff1f;反正我是对此感到厌倦了。但我们可以做一些有趣的事情&#xff0c;让它更上一层楼。接下来就跟我一起将 agents 概念引入传统的 RAG 工作流&#xff0c;重新…

OnlyOffice 打开文档时提示下载失败

OnlyOffice 下载失败问题 问题概述 OnlyOffice前端界面出现“下载失败” 问题定位&#xff08;0&#xff1a;docker内不能够访问&#xff09; 很常见的一种情况是后端服务地址错误&#xff0c;在docker内无法访问。 请在docker容器中确定这个地址是可以访问的&#xff0c;鉴…

electron 设置界面右下角打开

功能需求场景 写一个可以下载各种平台的小工具&#xff0c;需要右下角打开方便做其它事情 实现基础 要在屏幕的右下角设置窗口&#xff0c;可以调整mainWindow的创建参数&#xff0c;特别是通过使用x和y坐标来定位窗口 &#xff1b; 需要获取屏幕的尺寸&#xff0c;并据此计算…

不透明物体的投射和接收阴影

1、Fallback的作用 新建一个材质球&#xff0c;将其的Shader设置为之前编写的多种光源综合实现Shader 并将该材质球赋值给较大的立方体使用&#xff0c;我们会发现该立方体不再投射阴影也不再接受阴影 &#xff08;1&#xff09;不投射阴影的原因 该Shader中没有LightMo…

Rust编程的if选择语句

【图书介绍】《Rust编程与项目实战》-CSDN博客 《Rust编程与项目实战》(朱文伟&#xff0c;李建英)【摘要 书评 试读】- 京东图书 (jd.com) Rust编程与项目实战_夏天又到了的博客-CSDN博客 Rust语言实现选择结构时&#xff0c;根据某种条件的成立与否而采用不同的程序段进行…

【Kubernetes】日志平台EFK+Logstash+Kafka【实战】

一&#xff0c;环境准备 &#xff08;1&#xff09;下载镜像包&#xff08;共3个&#xff09;&#xff1a; elasticsearch-7-12-1.tar.gz fluentd-containerd.tar.gz kibana-7-12-1.tar.gz &#xff08;2&#xff09;在node节点导入镜像&#xff1a; ctr -nk8s.io images i…

解决sortablejs+el-table表格内限制回撤和拖拽回撤失败问题

应用场景&#xff1a; table内同一类型可拖拽&#xff0c;不支持不同类型拖拽&#xff08;主演可拖拽交换位置&#xff0c;非主演和主演不可交换位置&#xff09;,类型不同拖拽效果需还原&#xff0c;试了好几次el-table数据更新了&#xff0c;但是表格样式和数据不能及时保持…

Java面试题之JVM面试题

JVM 的主要作用是什么&#xff1f; JVM 就是 Java Virtual Machine&#xff08;Java虚拟机&#xff09;的缩写&#xff0c;JVM 屏蔽了与具体操作系统平台相关的信息&#xff0c;使 Java 程序只需生成在 Java 虚拟机上运行的目标代码 &#xff08;字节码&#xff09;&#xff0…

uniapp 常用高度状态栏,导航栏,tab栏,底部安全高度

实际效果 使用 //使用 let posConfig this.getPosConfig(); // 传false返回值为 px大小 console.log(posConfig.safeBottomH) // 入参 是否转换为rpxgetPosConfig(toRpx true) {const systemInfo uni.getSystemInfoSync();// #ifdef MPconst menuButtonInfo uni.getMenuBu…

Hello Algorithm:Capture 1,2 初识算法

大家好 :) 自学完sklearn的基本使用后&#xff0c;颇感无趣。虽有阅文几篇&#xff0c;却无所获。遂于24年9月26日决习hello algorithm。 &#xff1a;&#xff09; 好了&#xff0c;不开玩笑了。其实开设这篇专栏我也不知道有没有什么意义。其实是因为最近在读TaskWeaver&…

关于最小二乘法

最小二乘法的核心思想简单而优雅&#xff1a;我们希望找到一条最佳的曲线&#xff0c;使其尽可能贴近所有的数据点。想象一下&#xff0c;当你在画布上描绘一条线&#xff0c;目标是让这条线与点的距离最小。数学上&#xff0c;这可以表示为&#xff1a; 在这个公式中&#xff…

Eclipse Memory Analyzer (MAT)提示No java virtual machine was found ...解决办法

1&#xff0c;下载mat后安装&#xff0c;打开时提示 jdk版本低&#xff0c;需要升级到jdk17及以上版本&#xff0c;无奈就下载了jdk17&#xff0c;结果安装后提示没有jre环境&#xff0c;然后手动生成jre目录&#xff0c;命令如下&#xff1a; 进入jdk17目录&#xff1a;执行&…

SpringBoot的基础(自动配置)

SpringBootApplication注解 是一个组合注解&#xff0c;其中EnableAutoConfiguration让SpringBoot根据类路径中的jar包依赖为当前项目进行自动配置 例如&#xff1a;添加了spring-boot-starter-web依赖&#xff0c;会自动添加Tomcat和SpringMVC的依赖 添加了spring-boot-start…

【UE5】将2D切片图渲染为体积纹理,最终实现使用RT实时绘制体积纹理【第四篇-着色器投影-接收阴影部分】

上一章中实现了体积渲染的光照与自阴影&#xff0c;那我们这篇来实现投影 回顾 勘误 在开始本篇内容之前&#xff0c;我已经对上一章中的内容的错误进行了修改。为了确保不会错过这些更正&#xff0c;同时也避免大家重新阅读一遍&#xff0c;我将在这里为大家演示一下修改的…

LeetCode - 850 矩形面积 II

题目来源 850. 矩形面积 II - 力扣&#xff08;LeetCode&#xff09; 题目描述 给你一个轴对齐的二维数组 rectangles 。 对于 rectangle[i] [x1, y1, x2, y2]&#xff0c;其中&#xff08;x1&#xff0c;y1&#xff09;是矩形 i 左下角的坐标&#xff0c; (xi1, yi1) 是该…

灵当CRM index.php接口SQL注入漏洞复现 [附POC]

文章目录 灵当CRM index.php接口SQL注入漏洞复现 [附POC]0x01 前言0x02 漏洞描述0x03 影响版本0x04 漏洞环境0x05 漏洞复现1.访问漏洞环境2.构造POC3.复现 0x06 修复建议 灵当CRM index.php接口SQL注入漏洞复现 [附POC] 0x01 前言 免责声明&#xff1a;请勿利用文章内的相关技…

数据治理003-数据域

数据仓库是面向主题&#xff08;数据综合、归类并进行分析利用的抽象&#xff09;的应用。 数据仓库模型设计除横向的分层外&#xff0c;通常也需要根据业务情况进行纵向划分数据域。数据域是联系较为紧密的数据主题的集合&#xff0c;通常是根据业务类别、数据来源、数据用途…

001、视频添加字幕

1. 腾讯智影 (可用) https://zenvideo.qq.com/ 1.1 操作步骤 https://zenvideo.qq.com/ https://zenvideo.qq.com/my/material?typeexport 上传资源 自动字幕识别 修改字幕 下载字幕 上传字幕 https://zenvideo.qq.com/my/material?typeexport 2. 秒剪–手机版app &a…