Python Skip-Gram代码实战,Skip-Gram代码超简单讲解和步骤拆解,Word2vec代码构建思路,Skip-Gram代码实例,模板套用

news2024/11/20 14:20:00

1. Skip-Gram介绍

        Skip-gram模型是Word2Vec模型的一种训练方法,它的目标是通过目标词预测上下文词。Skip-gram模型通过神经网络结构来学习每个单词的向量表示。

        在Skip-gram模型中,每个单词被表示为一个固定维度的向量,该向量称为嵌入向量或词向量。模型通过对训练语料中的每个中心词进行预测,来学习得到这些词向量。

        训练过程中,Skip-gram模型的输入是一个中心词,目标是预测该词周围窗口内的上下文词。例如,给定一个句子"the cat sat on the mat"和窗口大小为1,中心词"sat"会被输入到模型中,而目标是预测"the"和"on"这两个上下文词。为了完成这个预测任务,模型通过调整词向量的参数来优化预测结果。

        Skip-gram模型的训练目标是最大化预测上下文词的概率。具体而言,模型通过在中心词和上下文词之间计算余弦相似度,将预测问题转化为一个二分类问题。通过多次迭代训练,模型会学习到每个单词的向量表示,使得上下文语境相似的单词在向量空间中的距离更近。

        Skip-gram模型具有以下优势:能够处理大规模的语料库、能捕捉到词之间的语义关系、词向量具有相对较低的维度、可以处理生僻词等。因此,Skip-gram模型被广泛应用于自然语言处理领域的词嵌入任务中。

2. Skip-Gram代码实战

2.1 定义一个句子列表,后面会用这些句子来训练 CBOW 和 Skip-Gram 模型

这个其实和CBOW的处理过程一样的

# 定义一个句子列表,后面会用这些句子来训练 CBOW 和 Skip-Gram 模型
sentences = ["Kage is Teacher", "Mazong is Boss", "Niuzong is Boss",
             "Xiaobing is Student", "Xiaoxue is Student",]
# 将所有句子连接在一起,然后用空格分隔成多个单词
words = ' '.join(sentences).split()
# 构建词汇表,去除重复的词
word_list = list(set(words))
# 创建一个字典,将每个词映射到一个唯一的索引
word_to_idx = {word: idx for idx, word in enumerate(word_list)}
# 创建一个字典,将每个索引映射到对应的词
idx_to_word = {idx: word for idx, word in enumerate(word_list)}
voc_size = len(word_list) # 计算词汇表的大小
print(" 词汇表:", word_list) # 输出词汇表
print(" 词汇到索引的字典:", word_to_idx) # 输出词汇到索引的字典
print(" 索引到词汇的字典:", idx_to_word) # 输出索引到词汇的字典
print(" 词汇表大小:", voc_size) # 输出词汇表大小

2.2 生成 Skip-Gram 训练数据

# 生成 Skip-Gram 训练数据
def create_skipgram_dataset(sentences, window_size=2):
    data = [] # 初始化数据
    for sentence in sentences: # 遍历句子
        sentence = sentence.split()  # 将句子分割成单词列表
        for idx, word in enumerate(sentence):  # 遍历单词及其索引
            # 获取相邻的单词,将当前单词前后各 N 个单词作为相邻单词
            for neighbor in sentence[max(idx - window_size, 0): 
                        min(idx + window_size + 1, len(sentence))]:
                if neighbor != word:  # 排除当前单词本身
                    # 将相邻单词与当前单词作为一组训练数据
                    data.append((neighbor, word))
    return data
# 使用函数创建 Skip-Gram 训练数据
skipgram_data = create_skipgram_dataset(sentences)
# 打印未编码的 Skip-Gram 数据样例(前 3 个)
print("Skip-Gram 数据样例(未编码):", skipgram_data[:3])

2.3 定义 One-Hot 编码函数

# 定义 One-Hot 编码函数
import torch # 导入 torch 库
def one_hot_encoding(word, word_to_idx):    
    tensor = torch.zeros(len(word_to_idx)) # 创建一个长度与词汇表相同的全 0 张量  
    tensor[word_to_idx[word]] = 1  # 将对应词的索引设为 1
    return tensor  # 返回生成的 One-Hot 向量
# 展示 One-Hot 编码前后的数据
word_example = "Teacher"
print("One-Hot 编码前的单词:", word_example)
print("One-Hot 编码后的向量:", one_hot_encoding(word_example, word_to_idx))
# 展示编码后的 Skip-Gram 训练数据样例
print("Skip-Gram 数据样例(已编码):", [(one_hot_encoding(context, word_to_idx), 
          word_to_idx[target]) for context, target in skipgram_data[:3]])

2.4 定义 Skip-Gram 类

# 定义 Skip-Gram 类
import torch.nn as nn # 导入 neural network
class SkipGram(nn.Module):
    def __init__(self, voc_size, embedding_size):
        super(SkipGram, self).__init__()
        # 从词汇表大小到嵌入层大小(维度)的线性层(权重矩阵)
        self.input_to_hidden = nn.Linear(voc_size, embedding_size, bias=False)  
        # 从嵌入层大小(维度)到词汇表大小的线性层(权重矩阵)
        self.hidden_to_output = nn.Linear(embedding_size, voc_size, bias=False)  
    def forward(self, X): # 前向传播的方式,X 形状为 (batch_size, voc_size)      
         # 通过隐藏层,hidden 形状为 (batch_size, embedding_size)
        hidden = self.input_to_hidden(X) 
        # 通过输出层,output_layer 形状为 (batch_size, voc_size)
        output = self.hidden_to_output(hidden)  
        return output    
embedding_size = 2 # 设定嵌入层的大小,这里选择 2 是为了方便展示
skipgram_model = SkipGram(voc_size, embedding_size)  # 实例化 Skip-Gram 模型
print("Skip-Gram 模型:", skipgram_model)

2.5 训练 Skip-Gram 类

# 训练 Skip-Gram 类
learning_rate = 0.001 # 设置学习速率
epochs = 1000 # 设置训练轮次
criterion = nn.CrossEntropyLoss()  # 定义交叉熵损失函数
import torch.optim as optim # 导入随机梯度下降优化器
optimizer = optim.SGD(skipgram_model.parameters(), lr=learning_rate)  
# 开始训练循环
loss_values = []  # 用于存储每轮的平均损失值
for epoch in range(epochs):
    loss_sum = 0 # 初始化损失值
    for context, target in skipgram_data:        
        X = one_hot_encoding(target, word_to_idx).float().unsqueeze(0) # 将中心词转换为 One-Hot 向量  
        y_true = torch.tensor([word_to_idx[context]], dtype=torch.long) # 将周围词转换为索引值 
        y_pred = skipgram_model(X)  # 计算预测值
        loss = criterion(y_pred, y_true)  # 计算损失
        loss_sum += loss.item() # 累积损失
        optimizer.zero_grad()  # 清空梯度
        loss.backward()  # 反向传播
        optimizer.step()  # 更新参数
    if (epoch+1) % 100 == 0: # 输出每 100 轮的损失,并记录损失
      print(f"Epoch: {epoch+1}, Loss: {loss_sum/len(skipgram_data)}")  
      loss_values.append(loss_sum / len(skipgram_data))
# 绘制训练损失曲线
import matplotlib.pyplot as plt # 导入 matplotlib
# 绘制二维词向量图
plt.rcParams["font.family"]=['SimHei'] # 用来设定字体样式
plt.rcParams['font.sans-serif']=['SimHei'] # 用来设定无衬线字体样式
plt.rcParams['axes.unicode_minus']=False # 用来正常显示负号
plt.plot(range(1, epochs//100 + 1), loss_values) # 绘图
plt.title(' 训练损失曲线 ') # 图题
plt.xlabel(' 轮次 ') # X 轴 Label
plt.ylabel(' 损失 ') # Y 轴 Label
plt.show() # 显示图

2.6 输出 Skip-Gram 习得的词嵌入 

# 输出 Skip-Gram 习得的词嵌入
print("Skip-Gram 词嵌入:")
for word, idx in word_to_idx.items(): # 输出每个词的嵌入向量
    print(f"{word}: {skipgram_model.input_to_hidden.weight[:,idx].detach().numpy()}")

 

2.7 向量可视化

fig, ax = plt.subplots() 
for word, idx in word_to_idx.items():
    # 获取每个单词的嵌入向量
    vec = skipgram_model.input_to_hidden.weight[:,idx].detach().numpy() 
    ax.scatter(vec[0], vec[1]) # 在图中绘制嵌入向量的点
    ax.annotate(word, (vec[0], vec[1]), fontsize=12) # 点旁添加单词标签
plt.title(' 二维词嵌入 ') # 图题
plt.xlabel(' 向量维度 1') # X 轴 Label
plt.ylabel(' 向量维度 2') # Y 轴 Label
plt.show() # 显示图

3. 总结

Word2Vec是一种用于学习词向量的算法模型,它能够将单词转换为密集的向量表示,并捕捉单词之间的语义关系。Word2Vec模型由Google于2013年提出,是一种基于神经网络的词嵌入技术。

Word2Vec模型包括两种主要的训练方法:Skip-gram和CBOW。Skip-gram模型的目标是通过目标词预测上下文词,而CBOW模型的目标是通过上下文词预测目标词。这两种模型均采用神经网络结构,在大规模文本语料上进行训练,学习得到每个单词的向量表示。

Word2Vec的核心思想是通过单词在上下文中的分布来学习单词的语义信息。具体而言,相似上下文中的单词会拥有相似的词向量表示,这样就能够捕捉到单词之间的语义关系。通过将单词表示为稠密的向量,Word2Vec模型可以表示单词之间的相似度,进而应用于词义相似度计算、文本分类、语言建模等多个自然语言处理任务中。

Word2Vec模型的训练速度快、效果好,因此在自然语言处理领域得到了广泛的应用。它为计算机更好地理解和处理自然语言提供了有效的工具,被认为是自然语言处理领域的重要突破之一。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1475346.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

AStar算法(大物件寻路)

前言 A星(物件大小为一格)寻路,都很熟悉了吧,网上源码一堆,随便抄; 这章需要讲述 大物件的A星寻路,何为大物件,就是 比如 物件 为4个格子; 这样,原来的A星 没法直接用了,必须得改装…

【Java程序员面试专栏 算法思维】四 高频面试算法题:回溯算法

一轮的算法训练完成后,对相关的题目有了一个初步理解了,接下来进行专题训练,以下这些题目就是汇总的高频题目,本篇主要聊聊回溯算法,主要就是排列组合问题,所以放到一篇Blog中集中练习 题目关键字解题思路时间空间岛屿数量网格搜索分别向上下左右四个方向探索,遇到海洋…

微信小程序引入Vant插件

Vant官网:Vant Weapp - 轻量、可靠的小程序 UI 组件库 先查看官网的版本 新建一个package.json页面,代码写上:(我先执行的npm安装没出package页面,所以先自己创建了一个才正常) {"dependencies"…

【Spring底层原理高级进阶】基于Spring Boot和Spring WebFlux的实时推荐系统的核心:响应式编程与 WebFlux 的颠覆性变革

🎉🎉欢迎光临🎉🎉 🏅我是苏泽,一位对技术充满热情的探索者和分享者。🚀🚀 🌟特别推荐给大家我的最新专栏《Spring 狂野之旅:底层原理高级进阶》 &#x1f680…

Bicycles(变形dijkstra,动态规划思想)

Codeforces Round 918 (Div. 4) G. Bicycles G. Bicycles 题意: 斯拉夫的所有朋友都打算骑自行车从他们住的地方去参加一个聚会。除了斯拉维奇,他们都有一辆自行车。他们可以经过 n n n 个城市。他们都住在城市 1 1 1 ,想去参加位于城市…

c++实现栈和队列类

c实现栈和队列类 栈(Stack)Stack示意图Stack.cpp 队列(queue)queue 示意图queue.cpp 栈(Stack) Stack示意图 Stack.cpp #pragma once #include "ListStu.cpp"template<typename T> class Stack { public: /* * void push(T& tDate)* 参数一 &#xff1a;…

Android和Linux的开发差异

最近开始投入Android的怀抱。说来惭愧&#xff0c;08年就听说这东西&#xff0c;当时也有同事投入去看&#xff0c;因为恶心Java&#xff0c;始终对这玩意无感&#xff0c;没想到现在不会这个嵌入式都快要没法搞了。为了不中年失业&#xff0c;所以只能回过头又来学。 首先还是…

硬盘销毁:如何彻底销毁硬盘 文件销毁 数据销毁 物料销

大家对于办公这个词并不陌生&#xff0c;大多数人都知道办公就是对公事的处理&#xff0c;不是私人的事情处理。办公中会出现很多文件&#xff0c;很多的资料&#xff0c;那么办公室的文件销毁是如何处理的呢&#xff1f; 办公文件对于办公的人来说都是非常重视的&#xff0c;…

AI学习(5):PyTorch-核心模块(Autograd):自动求导

1.介绍 在深度学习中&#xff0c;自动求导是一项核心技术&#xff0c;它使得我们能够方便地计算梯度并优化模型参数。PyTorch 提供了一个强大的自动求导模块(Autograd)&#xff0c;它可以自动计算张量的导数得出梯度信息&#xff0c;同时也支持高阶导数计算。 1.1 概念词 在学…

微服务-商城订单服务项目

文章目录 一、需求二、分析三、设计四、编码4.1 商品服务4.2 订单服务4.3 分布式事务4.4 订单超时 商品、购物车 商品服务&#xff1a; 1.全品类购物平台 SPU:Standard Product Unit 标准化产品单元。是商品信息聚合的最小单位。是一组可复用、易检索的标准化信息的集合&#x…

EMR StarRocks实战——Mysql数据实时同步到SR

文章摘抄阿里云EMR上的StarRocks实践&#xff1a;《基于实时计算Flink使用CTAS&CDAS功能同步MySQL数据至StarRocks》 前言 CTAS可以实现单表的结构和数据同步&#xff0c;CDAS可以实现整库同步或者同一库中的多表结构和数据同步。下文主要介绍如何使用Flink平台和E-MapRed…

【程序员英语】【美语从头学】初级篇(入门)(笔记)Lesson 16 At the Shoe Store 在鞋店

《美语从头学初级入门篇》 注意&#xff1a;被 删除线 划掉的不一定不正确&#xff0c;只是不是标准答案。 文章目录 Lesson 16 At the Shoe Store 在鞋店对话A对话B笔记会话A会话B替换 Lesson 16 At the Shoe Store 在鞋店 对话A A: Do you have these shoes in size 8? B:…

如何运行github上的项目

为了讲明白这个过程&#xff0c;特意做了一个相当来说比较好读懂的原理图&#xff0c;希望和我一样初学的小伙伴也能很快上手哈&#x1f60a; 在Github中找到想要部署的项目&#xff0c;这里以BartoszJarocki/CV&#xff08;线上简历&#x1f4c4;&#xff09;项目为例 先从头…

经典Go知识点总结

开篇推荐 来来来,老铁们,男人女人都需要的技术活 拿去不谢:远程调试,发布网站到公网演示,远程访问内网服务,游戏联机 推荐链接 1.无论sync.Mutex还是其衍生品都会提示不能复制,但是能够编译运行 加锁后复制变量&#xff0c;会将锁的状态也复制&#xff0c;所以 mu1 其实是已…

4核8G服务器并发数多少?性能如何?

腾讯云4核8G服务器支持多少人在线访问&#xff1f;支持25人同时访问。实际上程序效率不同支持人数在线人数不同&#xff0c;公网带宽也是影响4核8G服务器并发数的一大因素&#xff0c;假设公网带宽太小&#xff0c;流量直接卡在入口&#xff0c;4核8G配置的CPU内存也会造成计算…

React Switch用法及手写Switch实现

问&#xff1a;如果注册的路由特别多&#xff0c;找到一个匹配项以后还会一直往下找&#xff0c;我们想让react找到一个匹配项以后不再继续了&#xff0c;怎么处理&#xff1f;答&#xff1a;<Switch>独特之处在于它只绘制子元素中第一个匹配的路由元素。 如果没有<Sw…

[极客大挑战 2019]LoveSQL1 题目分析与详解

一、题目简介&#xff1a; 二、通关思路&#xff1a; 1、首先查看页面源代码&#xff1a; 我们发现可以使用工具sqlmap来拿到flag&#xff0c;我们先尝试手动注入。 2、 打开靶机&#xff0c;映入眼帘的是登录界面&#xff0c;首先尝试万能密码能否破解。 username: 1 or 11…

IDEA如何开启Dashboard

普通的面板 Run Dashboard面板 修改配置文件 找到项目的.idea文件夹 点击编辑workspace.xml文件 添加下方代码 <component name"RunDashboard"><option name"ruleStates"><list><RuleState><option name"name" valu…

雾锁王国服务器配置怎么选择?阿里云和腾讯云

雾锁王国/Enshrouded服务器CPU内存配置如何选择&#xff1f;阿里云服务器网aliyunfuwuqi.com建议选择8核32G配置&#xff0c;支持4人玩家畅玩&#xff0c;自带10M公网带宽&#xff0c;1个月90元&#xff0c;3个月271元&#xff0c;幻兽帕鲁服务器申请页面 https://t.aliyun.com…

通过QScrollArea寻找最后一个弹簧并且设置弹簧大小

项目原因&#xff0c;最近需要通过QScrollArea寻找其中最后一个弹簧并且设置大小和策略&#xff0c;因为无法直接调用UI指针&#xff0c;所以只能用代码寻找。 直接上代码&#xff1a; if (m_scrollArea){int iScrollWidth m_labelSelectedTitle->width();m_scrollArea-&g…