训练自己的中文word2vec(词向量)--skip-gram方法

news2024/9/21 13:32:47

训练自己的中文word2vec(词向量)–skip-gram方法

什么是词向量

​ 将单词映射/嵌入(Embedding)到一个新的空间,形成词向量,以此来表示词的语义信息,在这个新的空间中,语义相同的单词距离很近。

Skip-Gram方法(本次使用方法)

​ 以某个词为中心,分别计算该中心词前后可能出现其他词的各个概率,即给定input word来预测上下文。

Image Name

CBOW(Continous Bags Of Words,CBOW)

​ CBOW根据某个词前面的n个词、或者前后各n个连续的词,来计算某个词出现的概率,即给定上下文,来预测input word。相比Skip-Gram,CBOW更快一些。

本次使用 Skip-Gram方法和三国演义第一章作为数据,训练32维中文词向量。

数据代码下载链接见文末

导入库

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader
import re
import collections
import numpy as np
import jieba
#指定设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

导入数据

因为算力问题,我在此出只截取三国演义的第一章作为示例数据。

training_file = '/home/mw/input/sanguo5529/三国演义.txt'

#读取text文件,并选择第一章作为输入文本
def get_ch_lable(txt_file):  
    labels= ""
    with open(txt_file, 'rb') as f:
        for label in f:
            labels =labels+label.decode('utf-8')
        text = re.findall('第1章.*?第2章', labels,re.S)
    return text[0]
training_data =get_ch_lable(training_file)
# print(training_data)
print("总字数",len(training_data))

总字数 4945

分词

#jieba分词
def fenci(training_data):
    seg_list = jieba.cut(training_data)  # 默认是精确模式  
    training_ci = " ".join(seg_list)
    training_ci = training_ci.split()
    #以空格将字符串分开
    training_ci = np.array(training_ci)
    training_ci = np.reshape(training_ci, [-1, ])
    return training_ci
training_ci =fenci(training_data)
print("总词数",len(training_ci))

总词数 3053

构建词表

def build_dataset(words, n_words):
  count = [['UNK', -1]]
  count.extend(collections.Counter(words).most_common(n_words - 1))
  dictionary = dict()
  for word, _ in count:
    dictionary[word] = len(dictionary)
  data = list()
  unk_count = 0
  for word in words:
    if word in dictionary:
      index = dictionary[word]
    else:
      index = 0  # dictionary['UNK']
      unk_count += 1
    data.append(index)
  count[0][1] = unk_count
  reversed_dictionary = dict(zip(dictionary.values(), dictionary.keys()))
  
  return data, count, dictionary, reversed_dictionary

training_label, count, dictionary, words = build_dataset(training_ci, 3053)
#计算词频
word_count = np.array([freq for _,freq in count], dtype=np.float32)
word_freq = word_count / np.sum(word_count)#计算每个词的词频
word_freq = word_freq ** (3. / 4.)#词频变换
words_size = len(dictionary)
print("字典词数",words_size) 
print('Sample data', training_label[:10], [words[i] for i in training_label[:10]])

字典词数 1456
Sample data [100, 305, 140, 306, 67, 101, 307, 308, 46, 27]
[‘第’, ‘1’, ‘章’, ‘宴’, ‘桃园’, ‘豪杰’, ‘三’, ‘结义’, ‘斩’, ‘黄巾’]

制作数据集

C = 3 
num_sampled = 64  # 负采样个数   
BATCH_SIZE = 32  
EMBEDDING_SIZE = 32  #想要的词向量长度

class SkipGramDataset(Dataset):
    def __init__(self, training_label, word_to_idx, idx_to_word, word_freqs):
        super(SkipGramDataset, self).__init__()
        self.text_encoded = torch.Tensor(training_label).long()
        self.word_to_idx = word_to_idx
        self.idx_to_word = idx_to_word
        self.word_freqs = torch.Tensor(word_freqs)

    def __len__(self):
        return len(self.text_encoded)

    def __getitem__(self, idx):
        idx = min( max(idx,C),len(self.text_encoded)-2-C)#防止越界
        center_word = self.text_encoded[idx]
        pos_indices = list(range(idx-C, idx)) + list(range(idx+1, idx+1+C))
        pos_words = self.text_encoded[pos_indices] 
        #多项式分布采样,取出指定个数的高频词
        neg_words = torch.multinomial(self.word_freqs, num_sampled+2*C, False)#True)
        #去掉正向标签
        neg_words = torch.Tensor(np.setdiff1d(neg_words.numpy(),pos_words.numpy())[:num_sampled]).long()
        return center_word, pos_words, neg_words


print('制作数据集...')
train_dataset = SkipGramDataset(training_label, dictionary, words, word_freq)
dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE,drop_last=True, shuffle=True)

制作数据集…

#将数据集转化成迭代器
sample = iter(dataloader)	
#从迭代器中取出一批次样本				
center_word, pos_words, neg_words = sample.next()				
print(center_word[0],words[np.compat.long(center_word[0])],[words[i] for i in pos_words[0].numpy()])

模型构建

class Model(nn.Module):
    def __init__(self, vocab_size, embed_size):
        super(Model, self).__init__()
        self.vocab_size = vocab_size
        self.embed_size = embed_size

        initrange = 0.5 / self.embed_size
        self.in_embed = nn.Embedding(self.vocab_size, self.embed_size, sparse=False)
        self.in_embed.weight.data.uniform_(-initrange, initrange)

    def forward(self, input_labels, pos_labels, neg_labels):
        input_embedding = self.in_embed(input_labels)
                
        pos_embedding = self.in_embed(pos_labels)
        neg_embedding = self.in_embed(neg_labels)
        
        log_pos = torch.bmm(pos_embedding, input_embedding.unsqueeze(2)).squeeze()
        log_neg = torch.bmm(neg_embedding, -input_embedding.unsqueeze(2)).squeeze()

        log_pos = F.logsigmoid(log_pos).sum(1)
        log_neg = F.logsigmoid(log_neg).sum(1)
        loss = log_pos + log_neg
        return -loss

model = Model(words_size, EMBEDDING_SIZE).to(device)
model.train()

valid_size = 32
valid_window = words_size/2  # 取样数据的分布范围.
valid_examples = np.random.choice(int(valid_window), valid_size, replace=False)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
NUM_EPOCHS = 10

开始训练

for e in range(NUM_EPOCHS):
    for ei, (input_labels, pos_labels, neg_labels) in enumerate(dataloader):
        input_labels = input_labels.to(device)
        pos_labels = pos_labels.to(device)
        neg_labels = neg_labels.to(device)

        optimizer.zero_grad()
        loss = model(input_labels, pos_labels, neg_labels).mean()
        loss.backward()
        optimizer.step()

        if ei % 20 == 0:
            print("epoch: {}, iter: {}, loss: {}".format(e, ei, loss.item()))
    if e %40 == 0:           
        norm = torch.sum(model.in_embed.weight.data.pow(2),-1).sqrt().unsqueeze(1)
        normalized_embeddings = model.in_embed.weight.data / norm
        valid_embeddings = normalized_embeddings[valid_examples]
        
        similarity = torch.mm(valid_embeddings, normalized_embeddings.T)
        for i in range(valid_size):
            valid_word = words[valid_examples[i]]
            top_k = 8  # 取最近的排名前8的词
            nearest = (-similarity[i, :]).argsort()[1:top_k + 1]  #argsort函数返回的是数组值从小到大的索引值
            log_str = 'Nearest to %s:' % valid_word  
            for k in range(top_k):
                close_word = words[nearest[k].cpu().item()]
                log_str = '%s,%s' % (log_str, close_word)
            print(log_str)
epoch: 0, iter: 0, loss: 48.52019500732422
epoch: 0, iter: 20, loss: 48.51792526245117
epoch: 0, iter: 40, loss: 48.50772476196289
epoch: 0, iter: 60, loss: 48.50897979736328
epoch: 0, iter: 80, loss: 48.45783996582031
Nearest to 伏山:,,,,遥望,操忽心生,江渚上,,提刀
Nearest to 与:,,将次,盖地,,下文,,妖术,玄德
Nearest to 必获:,听调,直取,各处,一端,奏帝,必出,遥望,入帐
Nearest to 郭胜十人:,南华,因为,,贼战,告变,,,玄德遂
Nearest to 官军:,秀才,统兵,说起,军齐出,四散,放荡,一彪,云长
Nearest to 碧眼:,名备,刘焉然,大汉,卷入,大胜,老人,重枣,左有
Nearest to 天书:,泛溢,约期,而进,龚景,因起,车盖,遂解,震
Nearest to 转头:,近因,直取,,七尺,,备下,大汉,齐声
Nearest to 张宝:,侯览,惊告嵩,誓毕,帝惊,呼风唤雨,狂风,大将军,曾
Nearest to 玄德幼:,直取,近闻,刘焉令,几度,,临江仙,右有翼德,左右两
Nearest to 不祥:,无数,调兵,,刘备,玄德谢,原来,八尺,共
Nearest to 靖:,操忽心生,此病,赵忠,刘焉然,,庄田,,传至
Nearest to 五千:,岂可,丹凤眼,北行,听罢,,性命,,之囚
Nearest to 日非:,赵忠,,闻得,,破贼,早丧,书报,忽起
Nearest to 徐:,震怒,我气,卢植,,结为,,燕颔虎须,。
Nearest to 五百余:,帝惊,,本部,,神功,桓帝,滚滚,左右两
Nearest to 而:,转头,卷入,,近因,大商,,人公,天子
Nearest to 去:,封谞,夏恽,周末,,嵩信,广宗,人氏,民心
Nearest to 上:,,,陷邕,四年,关羽,直赶,九尺,伏山
Nearest to ::,,,,,兄弟,,来代,我答
Nearest to 后:,必获,阁下,,手起,祭礼,侍奉,各处,奏帝
Nearest to 因起:,帝览奏,,,汝可引,夺路,一把,是非成败,卷入
Nearest to 骤起:,挟恨,张宝称,明公宜,,一统天下,,,玄德请
Nearest to 汉室:,六月,临江仙,今汉运,手起,威力,抹额,讹言,提刀
Nearest to 云游四方:,背义忘恩,,渔樵,地公,扬鞭,,故冒姓,截住
Nearest to 桓:,,赵忠,刘焉然,左有,刘备,名备,二帝,游荡
Nearest to 二字于:,操故,,白土,左右两,张角本,赏劳,当时,梁上
Nearest to 人出:,,五十匹,,奏帝,梁上,九尺,六月,大汉
Nearest to 大浪:,卷入,临江仙,听调,汉武时,左有,束草,围城,及
Nearest to 青:,夺路,,贩马,师事,围城,卷入,大胜,客人
Nearest to 郎蔡邕:,浊酒,近闻,六月,角战于,中郎将,,转头,众大溃
Nearest to 二月:,马舞刀,国谯郡,只见,内外,郎蔡邕,,落到,汝得
epoch: 1, iter: 0, loss: 48.46757888793945
epoch: 1, iter: 20, loss: 48.42853546142578
epoch: 1, iter: 40, loss: 48.35804748535156
epoch: 1, iter: 60, loss: 48.083805084228516
epoch: 1, iter: 80, loss: 48.1635856628418
epoch: 2, iter: 0, loss: 47.89817428588867
epoch: 2, iter: 20, loss: 48.067501068115234
epoch: 2, iter: 40, loss: 48.6464729309082
epoch: 2, iter: 60, loss: 47.825260162353516
epoch: 2, iter: 80, loss: 48.07224655151367
epoch: 3, iter: 0, loss: 48.15058898925781
epoch: 3, iter: 20, loss: 47.26418685913086
epoch: 3, iter: 40, loss: 47.87504577636719
epoch: 3, iter: 60, loss: 48.74541473388672
epoch: 3, iter: 80, loss: 48.01288986206055
epoch: 4, iter: 0, loss: 47.257896423339844
epoch: 4, iter: 20, loss: 48.337745666503906
epoch: 4, iter: 40, loss: 47.70765686035156
epoch: 4, iter: 60, loss: 48.57493591308594
epoch: 4, iter: 80, loss: 48.206268310546875
epoch: 5, iter: 0, loss: 47.139137268066406
epoch: 5, iter: 20, loss: 48.70667266845703
epoch: 5, iter: 40, loss: 47.97750473022461
epoch: 5, iter: 60, loss: 48.098899841308594
epoch: 5, iter: 80, loss: 47.778892517089844
epoch: 6, iter: 0, loss: 47.86349105834961
epoch: 6, iter: 20, loss: 47.77979278564453
epoch: 6, iter: 40, loss: 48.67324447631836
epoch: 6, iter: 60, loss: 48.117042541503906
epoch: 6, iter: 80, loss: 48.69907760620117
epoch: 7, iter: 0, loss: 47.63265609741211
epoch: 7, iter: 20, loss: 47.82151794433594
epoch: 7, iter: 40, loss: 48.54405212402344
epoch: 7, iter: 60, loss: 48.06487274169922
epoch: 7, iter: 80, loss: 48.67494583129883
epoch: 8, iter: 0, loss: 48.053466796875
epoch: 8, iter: 20, loss: 47.872459411621094
epoch: 8, iter: 40, loss: 47.462432861328125
epoch: 8, iter: 60, loss: 48.10865783691406
epoch: 8, iter: 80, loss: 46.380184173583984
epoch: 9, iter: 0, loss: 47.2872314453125
epoch: 9, iter: 20, loss: 48.553428649902344
epoch: 9, iter: 40, loss: 47.00652313232422
epoch: 9, iter: 60, loss: 47.970741271972656
epoch: 9, iter: 80, loss: 48.159828186035156

查看训练好的词向量

final_embeddings = normalized_embeddings
labels = words[10]
print(labels)
print(final_embeddings[10])

玄德
tensor([-0.2620, 0.0660, 0.0464, 0.2948, -0.1974, 0.2471, -0.0893, 0.1720,
-0.1488, 0.0283, -0.1165, 0.2156, -0.1642, -0.2376, -0.0356, -0.0607,
0.1985, -0.2166, 0.2222, 0.2453, -0.1414, -0.0526, 0.1153, -0.1325,
-0.2964, 0.2775, -0.0637, -0.0716, 0.2672, 0.0539, 0.1697, 0.0489])

with open('skip-gram-sanguo.txt', 'a') as f:    
    for i in range(len(words)):
        f.write(words[i] + str(list(final_embeddings.numpy()[i])) + '\n')
f.close()
print('word vectors have written done.')

word vectors have written done.

按照路径/home/mw/project/skip-gram-sanguo.txt查看保存的文件,不一定要保存为txt,我们平常加载的词向量更多是vec格式

Image Name

数据代码下载链接

数据及代码右上角fork后可以免费获取

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/352040.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

双塔多目标MVKE

MVKE:Mixture of Virtual-Kernel Experts for Multi-Objective User Profile Modeling MVKE论文中是给用户打tag标记,构建用户画像。使用的也是经典的双塔模型,另外在双塔的基础上面叠加了ctr和cvr的多个目标。但是论文最大的创新点是在用户…

基于龙芯 CPU 的气井控制器的软件设计(三)

4.1 系统软件的总体设计 基于龙芯 CPU 的气井控制器的设计需要开发测试硬件模块的测试软件,主要对 RTC 模块、存储器模块、4G 通信、以太网通信、UART 串口以及 AI 模块进行了驱动程序和 应用程序设计。将各个模块设计为不同的任务,龙芯 RTU 软件设计流程…

Redis 监听过期的key(KeyExpirationEventMessageListener)

目录一、简介二、maven依赖三、编码实现3.1、application.properties3.2、Redis配置类3.3、监听器3.4、服务类3.5、工具类四、测试4.1、测试类4.2、单实例4.3、多实例结语一、简介 本文今天主要是讲Redis中对过期key的监听,可能很多小伙伴不会,或者使用会…

day15_常用类

今日内容 上课同步视频:CuteN饕餮的个人空间_哔哩哔哩_bilibili 同步笔记沐沐霸的博客_CSDN博客-Java2301 零、 复习昨日 一、作业 二、代码块[了解] 三、API 四、Object 五、包装类 六、数学和随机 零、 复习昨日 抽象接口修饰符abstractinterface是不是类类接口属性正常属性没…

Leetcode(每日一题)——1139. 最大的以 1 为边界的正方形

摘要 1139. 最大的以 1 为边界的正方形 一、以1为边界的最大正方形 1.1 动态规划 第530题需要正方形所有网格中的数字都是1,只要搞懂动态规划的原理,代码就非常简洁。而这题只要正方形4条边的网格都是1即可,中间是什么数字不用管。 这题…

Hive的安装与配置

一、配置Hadoop环境先看看伪分布式下的集群环境有没有错误的情况:输入命令:start-all.sh jps查看伪分布式的所有进程是否完善二、解压并配置HiveHive压缩包→ https://pan.baidu.com/s/1eOF_ICZV8rV-CEh3nX-7Xw 提取码: m31e 复制这段内容后打开百度网盘…

逆向 xx音乐 aversionid

逆向 xx音乐 aversionid 版本 7.2.0 版本 7.22.0 第一步,charles 抓包 目标字段 aversionid 加固平台 com.stub.StubApp 360加固s.h.e.l.l.S 爱加密com.secneo.apkwrapper.ApplicationWrapper 梆梆加固com.tencent.StubShell.TxAppEntry 腾讯加固 第二步&…

【网络编程】Java快速上手InetAddress类

概念 Java具有较好的网络编程模型/库,其中非常重要的一个API便是InetAddress。在Java.net 网络编程中中有许多类都使用到了InetAddress 这个类代表一个互联网协议(IP)地址。 IP地址是一个32(IPV4)位或128(…

求职季哪种 Python 程序员能拿高薪?

本文以Python爬虫、数据分析、后端、数据挖掘、全栈开发、运维开发、高级开发工程师、大数据、机器学习、架构师这10个岗位,从拉勾网上爬取了相应的职位信息和任职要求,并通过数据分析可视化,直观地展示了这10个职位的平均薪资和学历、工作经…

02 Context的使用

对于 HTTP 服务而言,超时往往是造成服务不可用、甚至系统瘫痪的罪魁祸首。 context 标准库设计思路 为了防止雪崩,context 标准库的解决思路是:在整个树形逻辑链条中,用上下文控制器 Context,实现每个节点的信息传递…

Package ‘oniguruma‘, required by ‘virtual:world‘, not found

一、操作系统环境 OS版本信息:Rocky Linux 9.1 PHP版本:8.0.26 安装的依赖: dnf -y install libXpm-devel libXext-devel gmp gmp-devel libicu* icu* net-snmp-devel libpng-devel libjpeg-devel freetype-devel libxslt-devel sqlite…

真正意义上的数字零售,最为重要的一点就是要回归零售本身

互联网浪潮的退却并未真正将人们的思维带离互联网的牢笼,相反,越来越多的人依然在用互联网式的眼光看待后互联网时代的事物。尽管这样一种做法可以在一定程度上取得一定的效果,但是,如果仅仅只是用互联网思维来揣度这一切&#xf…

基于虚拟机机的代码保护技术

虚拟机保护技术是基于x86汇编系统的可执行代码转换为字节码指令系统的代码,以达到保护原有指令不被轻易逆向和篡改的目的。 字节码(Byte-code)是一种包含执行程序,由一序列 op 代码/数据对组成的 ,是一种中间码。字节是…

《第一行代码》 第五章:详解广播机制

如果你了解网络通信原理应该会知道,在一个 IP 网络范围中最大的IP 地址是被保留作为广播地址来使用的。比如某个网络的 IP 范围是 192.168.0XXX,子网掩码是255.255.255.0那么这个网络的广播地址就是 192.168.0255广播数据包会被发送到同-网络上的所有端口…

Spring Security OAuth2四种授权模式总结(七)

写在前面:各位看到此博客的小伙伴,如有不对的地方请及时通过私信我或者评论此博客的方式指出,以免误人子弟。多谢!如果我的博客对你有帮助,欢迎进行评论✏️✏️、点赞👍👍、收藏⭐️⭐️&#…

【MySQL】MySQL表的增删改查(CRUD)

✨个人主页:bit me👇 ✨当前专栏:MySQL数据库👇 ✨算法专栏:算法基础👇 ✨每日一语:生命久如暗室,不碍朝歌暮诗 目 录🔓一. CRUD🔒二. 新增(Creat…

将array中元素四舍五入取整的np.rint()方法

【小白从小学Python、C、Java】 【计算机等级考试500强双证书】 【Python-数据分析】 将array中元素四舍五入取整 np.rint()方法 选择题 关于以下python代码说法错误的一项是? import numpy as np a np.array([-1.7, 1.5, -0.2, 0.3]) print("【显示】a\n",a) pr…

BI是报表?BI是可视化?BI到底是什么?

很多企业认为只要买一个前端商业智能BI分析工具就可以解决企业级的商业智能BI所有问题,这个看法实际上也不可行的。可能在最开始分析场景相对简单,对接数据的复杂度不是很高的情况下这类商业智能BI分析工具没有问题。但是在企业的商业智能BI项目建设有一…

数字化系统使用率低的原因剖析

当“数字化变革”成为热门话题,当“数字化转型”作为主题频频出现在一个个大型会议中,我们知道数字化时代的确到来了。但是,根据Gartner的报告我们看到一个矛盾的现象——85%的企业数字化建设与应用并不理想、但对数字化系统的需求多年来持续…

软件测试项目实战(附全套实战项目教程+视频+源码)

开通博客以来,我更新了很多实战项目,但一部分小伙伴在搭建环境时遇到了问题。 于是,我收集了一波高频问题,汇成本篇,供大家参考,避免重复踩坑。 如果你还遇到过其他坑和未解决的问题,可在评论区…