AI学习记录 - Word2Vec 超详细解析

news2024/9/21 12:47:03

创作不易,点个赞

我们有一堆文本,词汇拆分

sentences = ["jack like dog", "jack like cat", "jack like animal",
  "dog cat animal", "banana apple cat dog like", "dog fish milk like",
  "dog cat animal like", "jack like apple", "apple like", "jack like banana",
  "apple banana jack movie book music like", "cat dog hate", "cat dog like"]
sentence_list = " ".join(sentences).split() # ['jack', 'like', 'dog']
vocab = list(set(sentence_list))
word2idx = {w:i for i, w in enumerate(vocab)}
print("word2idx", word2idx)
vocab_size = len(vocab)
print("vocab_size", vocab_size)

打印如下,一共有13个词汇:

word2idx {'banana': 0, 'animal': 1, 'hate': 2, 'like': 3, 'jack': 4, 'dog': 5, 'fish': 6, 'apple': 7, 'book': 8, 'milk': 9, 'music': 10, 'cat': 11, 'movie': 12}
vocab_size 13

生成训练集合,生成规则是【目标词,上一个词】,【目标词,下一个词】,下面使用的是下标

skip_grams = []
for idx in range(C, len(sentence_list) - C):
  center = word2idx[sentence_list[idx]]
  context_idx = list(range(idx - C, idx)) + list(range(idx + 1, idx + C + 1))
  context = [word2idx[sentence_list[i]] for i in context_idx]

  for w in context:
    skip_grams.append([center, w])

print(skip_grams)
print(len(skip_grams))

skip_grams变量打印如下:

[[5, 4], [5, 3], [5, 4], [5, 3], [4, 3], [4, 5], [4, 3], [4, 11], [3, 5], [3, 4], [3, 11], [3, 4], [11, 4], [11, 3], [11, 4], [11, 3], [4, 3], [4, 11], [4, 3], [4, 1], [3, 11], [3, 4], [3, 1], [3, 5], [1, 4], [1, 3], [1, 5], [1, 11], [5, 3], [5, 1], [5, 11], [5, 1], [11, 1], [11, 5], [11, 1], [11, 0], [1, 5], [1, 11], [1, 0], [1, 7], [0, 11], [0, 1], [0, 7], [0, 11], [7, 1], [7, 0], [7, 11], [7, 5], [11, 0], [11, 7], [11, 5], [11, 3], [5, 7], [5, 11], [5, 3], [5, 5],

我们现在有了训练集,那么现在需要构造模型结构,假设原one-hot编码是13位,训练出来的词向量位2位,这样子我们就减少了词汇的维度并建立词与词的关联,如下图:

在这里插入图片描述

代码为: W矩阵就是我们需要训练的矩阵,V矩阵在训练完成之后我们就丢弃了。

class Word2Vec(nn.Module):
  def __init__(self):
    super(Word2Vec, self).__init__()
    self.W = nn.Parameter(torch.randn(vocab_size, m).type(dtype))
    self.V = nn.Parameter(torch.randn(m, vocab_size).type(dtype))

  def forward(self, X):
    # X : [batch_size, vocab_size]
    hidden = torch.mm(X, self.W) # [batch_size, m]
    output = torch.mm(hidden, self.V) # [batch_size, vocab_size]
    return output

反向传播阶段照样是使用 CrossEntropyLoss 计算误差,关于 CrossEntropyLoss 函数在上一个章节有介绍,利用实际输出和标签的损失进行反向传播,如果你想了解权重怎么调整,继续往前翻找也可查看。

在这里插入图片描述

完整的word2Vec代码


import torch
import numpy as np
import torch.nn as nn
import torch.optim as optimizer
import torch.utils.data as Data

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.FloatTensor

sentences = ["jack like dog", "jack like cat", "jack like animal",
  "dog cat animal", "banana apple cat dog like", "dog fish milk like",
  "dog cat animal like", "jack like apple", "apple like", "jack like banana",
  "apple banana jack movie book music like", "cat dog hate", "cat dog like"]
sentence_list = " ".join(sentences).split() # ['jack', 'like', 'dog']
vocab = list(set(sentence_list))
word2idx = {w:i for i, w in enumerate(vocab)}
vocab_size = len(vocab)

# model parameters
C = 2 # window size
batch_size = 8
m = 2 # word embedding dim

skip_grams = []
for idx in range(C, len(sentence_list) - C):
  center = word2idx[sentence_list[idx]]
  context_idx = list(range(idx - C, idx)) + list(range(idx + 1, idx + C + 1))
  context = [word2idx[sentence_list[i]] for i in context_idx]

  for w in context:
    skip_grams.append([center, w])

def make_data(skip_grams):
  input_data = []
  output_data = []
  for a, b in skip_grams:
    input_data.append(np.eye(vocab_size)[a])
    output_data.append(b)
  return input_data, output_data

input_data, output_data = make_data(skip_grams)
input_data, output_data = torch.Tensor(input_data), torch.LongTensor(output_data)
dataset = Data.TensorDataset(input_data, output_data)
loader = Data.DataLoader(dataset, batch_size, True)

class Word2Vec(nn.Module):
  def __init__(self):
    super(Word2Vec, self).__init__()
    self.W = nn.Parameter(torch.randn(vocab_size, m).type(dtype))
    self.V = nn.Parameter(torch.randn(m, vocab_size).type(dtype))

  def forward(self, X):
    # X : [batch_size, vocab_size]
    hidden = torch.mm(X, self.W) # [batch_size, m]
    output = torch.mm(hidden, self.V) # [batch_size, vocab_size]
    return output

model = Word2Vec().to(device)
loss_fn = nn.CrossEntropyLoss().to(device)
optim = optimizer.Adam(model.parameters(), lr=1e-3)

for epoch in range(2000):
  for i, (batch_x, batch_y) in enumerate(loader):
    batch_x = batch_x.to(device)
    batch_y = batch_y.to(device)
    pred = model(batch_x)
    loss = loss_fn(pred, batch_y)

    if (epoch + 1) % 1000 == 0:
      print(epoch + 1, i, loss.item())
    
    optim.zero_grad()
    loss.backward()
    optim.step()

import matplotlib.pyplot as plt
for i, label in enumerate(vocab):
  W, WT = model.parameters()
  x,y = float(W[i][0]), float(W[i][1])
  plt.scatter(x, y)
  plt.annotate(label, xy=(x, y), xytext=(5, 2), textcoords='offset points', ha='right', va='bottom')
plt.show()


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2051735.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

URP平面阴影合批处理 shadow

闲谈 相信大家在日常工作中发现了一个问题 , urp下虽然可以做到3个Pass 去写我们想要的效果,但是,不能合批(不能合批,那不是我们CPU要干冒烟~!) 好家伙,熊猫老师的偏方来了 &#x…

【数值方法-Python实现】Crout分解+追赶法实现

涉及Crout分解、追赶法的线性方程组求解方法的Python实现。 原文链接:https://www.cnblogs.com/aksoam/p/18366119 Codes def CroutLU(A:np.ndarray)->Tuple[np.ndarray,np.ndarray]:"""Crout LU分解算法,ALUinput:A: (n,n) np.ndarray,方阵out…

DrissionPage自动化获取城市数据内容

一、获取页面内容 二、最终结果 上海市 约收录140个指标 查看98075次 人均GDP 153299元 公交车 17899辆 户籍人口 1469.3万人 三、代码 from DrissionPage._pages.chromium_page import ChromiumPage import time page ChromiumPage() page.get(https://www.swguancha.com/…

【Delphi】中多显示器操作基本知识点

提要: 目前随着计算机的发展,4K显示器已经逐步在普及,笔记本的显示器分辨率也都已经超过2K,多显示器更是普及速度很快。本文介绍下Delphi中操作多显示器的基本知识点(Windows系统),这些知识点在…

UniFab 是一款由人工智慧驅動的視訊增強器+ crack

UniFab 是一款功能强大的视频处理工具,包括 10 个基于 AI 的功能。使用 UniFab,您可以提高视频和音频质量、将视频转换为不同的格式、根据自己的喜好编辑视频等等。以下是适用于 Windows 的 UniFab 程序的简要说明: 视频转换器。UniFab 支持 1000 多种视频格式的转换,包括 …

构建自己的图数据集

代码: import warnings warnings.filterwarnings("ignore") import torch from torch_geometric.data import Datax torch.tensor([[2,1],[5,6],[3,7],[12,0]],dtypetorch.float) y torch.tensor([0,1,0,1],dtypetorch.float)#定义边 edge_index torc…

⌈ 传知代码 ⌋ DETR[端到端目标检测]

💛前情提要💛 本文是传知代码平台中的相关前沿知识与技术的分享~ 接下来我们即将进入一个全新的空间,对技术有一个全新的视角~ 本文所涉及所有资源均在传知代码平台可获取 以下的内容一定会让你对AI 赋能时代有一个颠覆性的认识哦&#x…

Leetcode3232. 判断是否可以赢得数字游戏

Every day a Leetcode 题目来源:3232. 判断是否可以赢得数字游戏 解法1:3232. 判断是否可以赢得数字游戏 用一个 sum1 统计个位数的和,sum2 统计十位数的和。 只要 sum1 和 sum2 不相等,Alice 拿大的就能赢得这场游戏。 代码…

【论文阅读】HuatuoGPT-II, One-stage Training for Medical Adaption of LLMs

总体概要 本文深入探讨了一款专为医疗领域设计的大规模语言模型——HuatuoGPT-II的创新、性能与应用。HuatuoGPT-II采用统一的单阶段训练流程,将传统的继续预训练和监督微调整合,有效解决了医疗数据的异质性问题,包括语言、体裁和格式差异&a…

【STM32单片机_(HAL库)】3-2-1【中断EXTI】【电动车报警器项目】继电器定时开闭

1.硬件 STM32单片机最小系统继电器模块 2.软件 继电器模块alarm驱动文件添加GPIO常用函数main.c程序 #include "sys.h" #include "delay.h" #include "led.h" #include "alarm.h"int main(void) {HAL_Init(); …

硬件面试经典 100 题(71~90 题)

71、请问下图电路的作用是什么? 该电路实现 IIC 信号的电平转换(3.3V 和 5V 电平转换),并且是双向通信的。 上下两路是一样的,只分析 SDA 一路: 1) 从左到右通信(SDA2 为输入状态&…

同一台电脑同时连接使用Gitee(码云)和Github

1、添加对应的密钥 ssh-keygen -t rsa -C "your_emailexample.com" -f ~/.ssh/github_id-rsa //生成github秘钥 ssh-keygen -t rsa -C "your_emailexample.com" -f ~/.ssh/gitee_id-rsa //生成码云秘钥 2、在 ~/.ssh 文件里会生成对应的文件 文件夹里会…

[k8s源码]12.远程调试dlv

在Windows/Mac宿主机上,使用GoLand的IDE进行开发,但是如何将这些代码直接运行在k8s集群中并看到运行效果呢,这里有一个远程调试工具dlv。 图中展示了dlv的工作方式。GoLand IDE中包含Editor(编辑器)和Debugger(调试器)组件,其中De…

深度学习基础之前馈神经网络

目录 基本结构和工作原理 神经元和权重 激活函数 深度前馈网络 应用场景 优缺点 深度前馈神经网络与卷积神经网络(CNN)和循环神经网络(RNN)的具体区别和联系是什么? 具体区别 联系 如何有效解决前馈神经网络…

探索Python的工业通信之光:pymodbus的奇妙之旅

文章目录 探索Python的工业通信之光:pymodbus的奇妙之旅背景:为何选择pymodbus?pymodbus是什么?如何安装pymodbus?5个简单的库函数使用方法3个场景使用示例常见bug及解决方案总结 探索Python的工业通信之光&#xff1a…

炒作将引发人工智能寒冬

我们似乎经常看到人工智能的进步被吹捧为机器真正变得智能的一大飞跃。我将在这里挑选其中的一个例子,并确切解释为什么这种态度会为人工智能的未来埋下隐患。 这很酷,这是一个非常困难且非常具体的问题,这个团队花了3 年时间才解决。他们一定…

结合GPT与Python实现端口检测工具(含多线程)

端口检测器是一个非常实用的网络工具,它主要用于检测服务器或本地计算机上的特定端口是否处于开放状态。通过这个工具,你可以快速识别和诊断网络连接问题,确保关键服务的端口能够正常接收和处理数据。这对于网络管理员和开发者来说是一个不可…

【Linux修行路】基础I/O——重定向的实现原理

目录 ⛳️推荐 一、再来理解重定向 1.1 输出重定向效果演示 1.2 重定向的原理 1.3 dup2 1.4 输入重定向效果演示 1.5 输入重定向代码实现 二、再来理解标准输出和标准错误 2.1 同时对标准输出和标准错误进行重定向 2.2 将标准输出和标准错误重定向到同一个文件 三、…

版本更新 《坚持学习计时器》软件V3.1 更新内容:自动实时显出

🌟 嗨,我是命运之光! 🌍 2024,每日百字,记录时光,感谢有你一路同行。 🚀 携手启航,探索未知,激发潜能,每一步都意义非凡。 版本更新 《坚持学习…

【统计字符数量】统计出每种字符的数量

输入一行字符&#xff0c;分别统计出其中英文字母、空格、数字和其他字符的个数&#xff0c;使用C语言实现&#xff0c; 具体代码&#xff1a; #include<stdio.h>int main(){char c;int letters0,space0,digit0,others0;printf("请输入一行字符&#xff1a; "…