pytorch基于GloVe实现的词嵌入

news2025/2/4 22:08:11

PyTorch 实现 GloVe(Global Vectors for Word Representation) 的完整代码,使用 中文语料 进行训练,包括 共现矩阵构建、模型定义、训练和测试


 1. GloVe 介绍

基于词的共现信息(不像 Word2Vec 使用滑动窗口预测)
 适合较大规模的数据(比 Word2Vec 更稳定)
学习出的词向量能捕捉语义信息(如类比关系)

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import jieba
from collections import Counter
from scipy.sparse import coo_matrix

# ========== 1. 数据预处理 ==========
corpus = [
    "我们 喜欢 深度 学习",
    "自然 语言 处理 是 有趣 的",
    "人工智能 改变 了 世界",
    "深度 学习 是 人工智能 的 重要 组成部分"
]

# 分词
tokenized_corpus = [list(jieba.cut(sentence)) for sentence in corpus]
vocab = set(word for sentence in tokenized_corpus for word in sentence)
word2idx = {word: idx for idx, word in enumerate(vocab)}
idx2word = {idx: word for word, idx in word2idx.items()}

# 计算共现矩阵
window_size = 2
co_occurrence = Counter()

for sentence in tokenized_corpus:
    indices = [word2idx[word] for word in sentence]
    for center_idx in range(len(indices)):
        center_word = indices[center_idx]
        for offset in range(-window_size, window_size + 1):
            context_idx = center_idx + offset
            if 0 <= context_idx < len(indices) and context_idx != center_idx:
                context_word = indices[context_idx]
                co_occurrence[(center_word, context_word)] += 1

# 转换为稀疏矩阵
rows, cols, values = zip(*[(c[0], c[1], v) for c, v in co_occurrence.items()])
X = coo_matrix((values, (rows, cols)), shape=(len(vocab), len(vocab)))


# ========== 2. 定义 GloVe 模型 ==========
class GloVe(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(GloVe, self).__init__()
        self.w_embeddings = nn.Embedding(vocab_size, embedding_dim)  # 中心词嵌入
        self.c_embeddings = nn.Embedding(vocab_size, embedding_dim)  # 上下文词嵌入
        self.w_bias = nn.Embedding(vocab_size, 1)  # 中心词偏置
        self.c_bias = nn.Embedding(vocab_size, 1)  # 上下文词偏置
        nn.init.xavier_uniform_(self.w_embeddings.weight)
        nn.init.xavier_uniform_(self.c_embeddings.weight)

    def forward(self, center, context, co_occur):
        w_emb = self.w_embeddings(center)
        c_emb = self.c_embeddings(context)
        w_bias = self.w_bias(center).squeeze()
        c_bias = self.c_bias(context).squeeze()
        dot_product = (w_emb * c_emb).sum(dim=1)
        loss = (dot_product + w_bias + c_bias - torch.log(co_occur + 1e-8)) ** 2
        return loss.mean()


# 初始化模型
embedding_dim = 10
model = GloVe(len(vocab), embedding_dim)

# ========== 3. 训练 GloVe ==========
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
num_epochs = 100

# 转换数据
co_occurrence_tensor = torch.tensor(X.data, dtype=torch.float)
pairs = list(zip(X.row, X.col, co_occurrence_tensor))

for epoch in range(num_epochs):
    total_loss = 0
    np.random.shuffle(pairs)
    for center, context, co_occur in pairs:
        optimizer.zero_grad()
        loss = model(
            torch.tensor([center], dtype=torch.long),
            torch.tensor([context], dtype=torch.long),
            torch.tensor([co_occur], dtype=torch.float)  # 修正数据类型
        )
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss:.4f}")

# ========== 4. 获取词向量 ==========
word_vectors = model.w_embeddings.weight.data.numpy()


# ========== 5. 计算相似度 ==========
def most_similar(word, top_n=3):
    if word not in word2idx:
        return "单词不在词汇表中"

    word_vec = word_vectors[word2idx[word]].reshape(1, -1)
    similarities = np.dot(word_vectors, word_vec.T).squeeze()
    similar_idx = similarities.argsort()[::-1][1:top_n + 1]
    return [(idx2word[idx], similarities[idx]) for idx in similar_idx]


# 测试
test_words = ["深度", "学习", "人工智能"]
for word in test_words:
    print(f"【{word}】的相似单词:", most_similar(word))

数据预处理

  • 分词(使用 jieba.cut()
  • 构建共现矩阵(计算窗口内的单词共现频率)
  • 使用稀疏矩阵存储(提高计算效率)

GloVe 模型

  • Embedding 训练词向量(中心词和上下文词分开)
  • Bias 变量 用于调整预测值
  • 损失函数 最小化 log(共现次数) 与词向量点积的差值

 计算词向量相似度

  • 使用 cosine similarity
  • 找出 top_n 最相似的单词

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2292002.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

OPENPPP2 —— VMUX_NET 多路复用原理剖析

在阅读本文之前&#xff0c;必先了解以下几个概念&#xff1a; 1、MUX&#xff08;Multiplexer&#xff09;&#xff1a;合并多个信号到单一通道。 2、DEMUX&#xff08;Demultiplexer&#xff09;&#xff1a;从单一通道分离出多个信号。 3、单一通道&#xff0c;可汇聚多个…

语言月赛 202412【正在联系教练退赛】题解(AC)

》》》点我查看「视频」详解》》》 [语言月赛 202412] 正在联系教练退赛 题目背景 在本题中&#xff0c;我们称一个字符串 y y y 是一个字符串 x x x 的子串&#xff0c;当且仅当从 x x x 的开头和结尾删去若干个&#xff08;可以为 0 0 0 个&#xff09;字符后剩余的字…

【数据结构】_链表经典算法OJ:复杂链表的复制

目录 1. 题目链接及描述 2. 解题思路 3. 程序 1. 题目链接及描述 题目链接&#xff1a;138. 随机链表的复制 - 力扣&#xff08;LeetCode&#xff09; 题目描述&#xff1a; 给你一个长度为 n 的链表&#xff0c;每个节点包含一个额外增加的随机指针 random &#xff0c;…

python的pre-commit库的使用

在软件开发过程中&#xff0c;保持代码的一致性和高质量是非常重要的。pre-commit 是一个强大的工具&#xff0c;它可以帮助我们在提交代码到版本控制系统&#xff08;如 Git&#xff09;之前自动运行一系列的代码检查和格式化操作。通过这种方式&#xff0c;我们可以确保每次提…

【C语言入门】解锁核心关键字的终极奥秘与实战应用(三)

目录 一、auto 1.1. 作用 1.2. 特性 1.3. 代码示例 二、register 2.1. 作用 2.2. 特性 2.3. 代码示例 三、static 3.1. 修饰局部变量 3.2. 修饰全局变量 3.3. 修饰函数 四、extern 4.1. 作用 4.2. 特性 4.3. 代码示例 五、volatile 5.1. 作用 5.2. 代码示例…

音标-- 02-- 重音 音节 变音

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 国际音标1.重音2.音节3.变音 国际音标 1.重音 2.音节 3.变音

[STM32 标准库]EXTI应用场景 功能框图 寄存器

一、EXTI 外部中断在嵌入式系统中有广泛的应用场景&#xff0c;如按钮开关控制&#xff0c;传感器触发&#xff0c;通信接口中断等。其原理都差不多&#xff0c;STM32会对外部中断引脚的边沿进行检测&#xff0c;若检测到相应的边沿会触发中断&#xff0c;在中断中做出相应的处…

C语言练习【互斥锁、信号量线程同步、条件变量实现生产者消费者模型】

练习1 请使用互斥锁 和 信号量分别实现5个线程之间的同步 互斥锁实现同步 #include <stdio.h> #include <string.h> #include <unistd.h> #include <stdlib.h> #include <sys/types.h> #include <sys/stat.h> #include <fcntl.h>…

w190工作流程管理系统设计与实现

&#x1f64a;作者简介&#xff1a;多年一线开发工作经验&#xff0c;原创团队&#xff0c;分享技术代码帮助学生学习&#xff0c;独立完成自己的网站项目。 代码可以查看文章末尾⬇️联系方式获取&#xff0c;记得注明来意哦~&#x1f339;赠送计算机毕业设计600个选题excel文…

linux下ollama更换模型路径

Linux下更换Ollama模型下载路径指南   在使用Ollama进行AI模型管理时&#xff0c;有时需要根据实际需求更改模型文件的存储路径。本文将详细介绍如何在Linux系统中更改Ollama模型的下载路径。 一、关闭Ollama服务   在更改模型路径之前&#xff0c;需要先停止Ollama服务。…

编程题-电话号码的字母组合(中等)

题目&#xff1a; 给定一个仅包含数字 2-9 的字符串&#xff0c;返回所有它能表示的字母组合。答案可以按 任意顺序 返回。 给出数字到字母的映射如下&#xff08;与电话按键相同&#xff09;。注意 1 不对应任何字母。 解法一&#xff08;哈希表动态添加&#xff09;&#x…

浅谈《图解HTTP》

感悟 滑至尾页的那一刻&#xff0c;内心突兀的涌来一阵畅快的感觉。如果说从前对互联网只是懵懵懂懂&#xff0c;但此刻却觉得她是如此清晰而可爱的呈现在哪里。 介绍中说&#xff0c;《图解HTTP》适合作为第一本网络协议书。确实&#xff0c;它就像一座桥梁&#xff0c;连接…

架构知识整理与思考(其四)

书接上回 建议&#xff0c;没有看过上一章的可以看一下&#xff0c;上一章“架构知识整理与思考&#xff08;其二&#xff09;” 感觉这都成链表了。 三生万物 软件架构 终于&#xff0c;我们进入了具体的软件架构讨论中。 软件架构是什么&#xff1f;相关定义如下&#xf…

【C++】B2124 判断字符串是否为回文

博客主页&#xff1a; [小ᶻ☡꙳ᵃⁱᵍᶜ꙳] 本文专栏: C 文章目录 &#x1f4af;前言&#x1f4af;题目描述输入格式&#xff1a;输出格式&#xff1a;样例&#xff1a; &#x1f4af;方法一&#xff1a;我的第一种做法思路代码实现解析 &#x1f4af;方法二&#xff1a;我…

基于Spring Security 6的OAuth2 系列之八 - 授权服务器--Spring Authrization Server的基本原理

之所以想写这一系列&#xff0c;是因为之前工作过程中使用Spring Security OAuth2搭建了网关和授权服务器&#xff0c;但当时基于spring-boot 2.3.x&#xff0c;其默认的Spring Security是5.3.x。之后新项目升级到了spring-boot 3.3.0&#xff0c;结果一看Spring Security也升级…

算法题(48):反转链表

审题&#xff1a; 需要我们将链表反转并返回头结点地址 思路&#xff1a; 一般在面试中&#xff0c;涉及链表的题会主要考察链表的指向改变&#xff0c;所以一般不会允许我们改变节点val值。 这里是单向链表&#xff0c;如果要把指向反过来则需要同时知道前中后三个节点&#x…

梯度、梯度下降、最小二乘法

在求解机器学习算法的模型参数&#xff0c;即无约束优化问题时&#xff0c;梯度下降是最常采用的方法之一&#xff0c;另一种常用的方法是最小二乘法。 1. 梯度和梯度下降 在微积分里面&#xff0c;对多元函数的参数求∂偏导数&#xff0c;把求得的各个参数的偏导数以向量的形式…

独立开发者小程序开发变现思路

随着移动互联网的发展&#xff0c;小程序已成为许多独立开发者展示才能和实现收入的重要平台。作为一种轻量级的应用形态&#xff0c;小程序具有开发成本低、用户体验好、传播效率高等优势&#xff0c;为独立开发者提供了多种变现方式。然而&#xff0c;要想实现真正的盈利&…

软件测试 - 概念篇

目录 1. 需求 1.1 用户需求 1.2 软件需求 2. 开发模型 2.1 软件的生命周期 2.2 常见开发模型 2.2.1 瀑布模型 2.2.2 螺旋模型 1. 需求 对于软件开发而言, 需求分为以下两种: 用户需求软件需求 1.1 用户需求 用户需求, 就是用户提出的需求, 没有经过合理的评估, 通常…

使用SpringBoot发送邮件|解决了部署时连接超时的bug|网易163|2025

使用SpringBoot发送邮件 文章目录 使用SpringBoot发送邮件1. 获取网易邮箱服务的授权码2. 初始化项目maven部分web部分 3. 发送邮件填写配置EmailSendService [已解决]部署时连接超时附&#xff1a;Docker脚本Dockerfile创建镜像启动容器 1. 获取网易邮箱服务的授权码 温馨提示…