第N6周:使用Word2vec实现文本分类

news2024/11/27 0:26:18
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms,datasets
import os,PIL,pathlib,warnings
#忽略警告信息
warnings.filterwarnings("ignore")
# win10系统
device = torch.device("cuda"if torch.cuda.is_available()else"cpu")
device

import pandas as pd
# 加载自定义中文数据
train_data= pd.read_csv('./data/train2.csv',sep='\t',header=None)
train_data.head()

# 构造数据集迭代器
def coustom_data_iter(texts,labels):
    for x,y in zip(texts,labels):
        yield x,y
x = train_data[0].values[:]
#多类标签的one-hot展开
y = train_data[1].values[:]


from gensim.models.word2vec import Word2Vec
import numpy as np
#训练word2Vec浅层神经网络模型
w2v=Word2Vec(vector_size=100#是指特征向量的维度,默认为100。
              ,min_count=3)#可以对字典做截断。词频少于min_count次数的单词会被丢弃掉,默认为5

w2v.build_vocab(x)
w2v.train(x,total_examples=w2v.corpus_count,epochs=20)


# 将文本转化为向量
def average_vec(text):
    vec =np.zeros(100).reshape((1,100))
    for word in text:
        try:
            vec +=w2v.wv[word].reshape((1,100))
        except KeyError:
                continue
    return vec
#将词向量保存为Ndarray
x_vec= np.concatenate([average_vec(z)for z in x])
#保存Word2Vec模型及词向量
w2v.save('data/w2v_model.pk1')


train_iter= coustom_data_iter(x_vec,y)
len(x),len(x_vec)

label_name =list(set(train_data[1].values[:]))
print(label_name)


text_pipeline =lambda x:average_vec(x)
label_pipeline =lambda x:label_name.index(x)

text_pipeline("你在干嘛")
label_pipeline("Travel-Query")


from torch.utils.data import DataLoader
def collate_batch(batch):
    label_list,text_list=[],[]
    for(_text,_label)in batch:
        # 标签列表
        label_list.append(label_pipeline(_label))
        # 文本列表
        processed_text = torch.tensor(text_pipeline(_text),dtype=torch.float32)
        text_list.append(processed_text)
        label_list = torch.tensor(label_list,dtype=torch.int64)
        text_list = torch.cat(text_list)
    return text_list.to(device),label_list.to(device)
# 数据加载器,调用示例
dataloader = DataLoader(train_iter,batch_size=8,
shuffle =False,
collate_fn=collate_batch)



from torch import nn
class TextclassificationModel(nn.Module):

    def __init__(self,num_class):
        super(TextclassificationModel,self).__init__()
        self.fc = nn.Linear(100,num_class)
    def forward(self,text):
        return self.fc(text)


num_class =len(label_name)
vocab_size =100000
em_size=12
model= TextclassificationModel(num_class).to(device)




import time
def train(dataloader):
    model.train()#切换为训练模式
    total_acc,train_loss,total_count =0,0,0
    log_interval=50
    start_time= time.time()

    for idx,(text,label)in enumerate(dataloader):
        predicted_label= model(text)
        # grad属性归零
        optimizer.zero_grad()
        loss=criterion(predicted_label,label)#计算网络输出和真实值之间的差距,label
        loss.backward()
        #反向传播
        torch.nn.utils.clip_grad_norm(model.parameters(),0.1)#梯度裁剪
        optimizer.step()#每一步自动更新
        #记录acc与loss
        total_acc+=(predicted_label.argmax(1)==label).sum().item()
        train_loss += loss.item()
        total_count += label.size(0)
        if idx % log_interval==0 and idx>0:
            elapsed =time.time()-start_time
            print('Iepoch {:1d}I{:4d}/{:4d} batches'
            '|train_acc {:4.3f} train_loss {:4.5f}'.format(epoch,idx,len(dataloader),total_acc/total_count,train_loss/total_count))
            total_acc,train_loss,total_count =0,0,0

            start_time = time.time()
def evaluate(dataloader):
    model.eval()#切换为测试模式
    total_acc,train_loss,total_count =0,0,0
    with torch.no_grad():
        for idx,(text,label)in enumerate(dataloader):
            predicted_label= model(text)
            loss = criterion(predicted_label,label)# 计算loss值
            # 记录测试数据
            total_acc+=(predicted_label.argmax(1)== label).sum().item()
            train_loss += loss.item()
            total_count += label.size(0)
    return total_acc/total_count,train_loss/total_count




from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
# 超参数
EPOCHS=10#epoch
LR=5 #学习率
BATCH_SIZE=64 # batch size for training
criterion = torch.nn.CrossEntropyLoss()
optimizer= torch.optim.SGD(model.parameters(),lr=LR)
scheduler=torch.optim.lr_scheduler.StepLR(optimizer,1.0,gamma=0.1)
total_accu = None
# 构建数据集
train_iter= coustom_data_iter(train_data[0].values[:],train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)

split_train_,split_valid_= random_split(train_dataset,[int(len(train_dataset)*0.8),int(len(train_dataset)*0.2)])
train_dataloader =DataLoader(split_train_,batch_size=BATCH_SIZE,
shuffle=True,collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_,batch_size=BATCH_SIZE,
shuffle=True,collate_fn=collate_batch)
for epoch in range(1,EPOCHS+1):
    epoch_start_time = time.time()
    train(train_dataloader)
    val_acc,val_loss = evaluate(valid_dataloader)
    # 获取当前的学习率
    lr =optimizer.state_dict()['param_groups'][0]['1r']
    if total_accu is not None and total_accu>val_acc:
      scheduler.step()
    else:
     total_accu = val_acc
    print('-'*69)
    print('|epoch {:1d}|time:{:4.2f}s |'
    'valid_acc {:4.3f} valid_loss {:4.3f}I1r {:4.6f}'.format(epoch,
    time.time()-epoch_start_time,
    val_acc,val_loss,lr))

    print('-'*69)


# test_acc,test_loss =evaluate(valid_dataloader)
# print('模型准确率为:{:5.4f}'.format(test_acc))
#
#
# def predict(text,text_pipeline):
#     with torch.no_grad():
#         text = torch.tensor(text_pipeline(text),dtype=torch.float32)
#         print(text.shape)
#         output = model(text)
#         return output.argmax(1).item()
# # ex_text_str="随便播放一首专辑阁楼里的佛里的歌"
# ex_text_str="还有双鸭山到淮阴的汽车票吗13号的"
# model=model.to("cpu")
# print("该文本的类别是:%s"%label_name[predict(ex_text_str,text_pipeline)])

以上是文本识别基本代码

输出:

[[-0.85472693  0.96605204  1.5058695  -0.06065784 -2.10079319 -0.12021151
   1.41170089  2.00004494  0.90861696 -0.62710127 -0.62408304 -3.80595499
   1.02797993 -0.45584389  0.54715634  1.70490362  2.33389823 -1.99607518
   4.34822938 -0.76296186  2.73265275 -1.15046433  0.82106878 -0.32701646
  -0.50515595 -0.37742117 -2.02331601 -1.365334    1.48786476 -1.6394971
   1.59438308  2.23569647 -0.00500725 -0.65070192  0.07377997  0.01777986
  -1.35580809  3.82080549 -2.19764423  1.06595343  0.99296588  0.58972518
  -0.33535255  2.15471306 -0.52244038  1.00874437  1.28869729 -0.72208139
  -2.81094289  2.2614549   0.20799019 -2.36187895 -0.94019454  0.49448857
  -0.68613767 -0.79071895  0.47535057 -0.78339124 -0.71336574 -0.27931567
   1.0514895  -1.76352624  1.93158554 -0.85853558 -0.65540617  1.3612217
  -1.39405773  1.18187538  1.31730198 -0.02322496  0.14652854  0.22249881
   2.01789951 -0.40144247 -0.39880068 -0.16220299 -2.85221207 -0.27722868
   2.48236791 -0.51239379 -1.47679498 -0.28452797 -2.64497767  2.12093259
  -1.2326943  -1.89571355  2.3295732  -0.53244872 -0.67313893 -0.80814604
   0.86987564 -1.31373079  1.33797717  1.02223087  0.5817025  -0.83535647
   0.97088164  2.09045361 -2.57758138  0.07126901]]
6

输出结果并非为0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1558527.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

深入探索Yarn:安装与使用指南

Yarn 是一个由 Facebook 开发的 JavaScript 包管理器,旨在提供更快、更可靠的包管理体验。它与 npm 类似,但在某些方面更加高效和可靠。本文将介绍如何安装 Yarn,并展示如何使用它来管理 JavaScript 项目的依赖。 1. 安装 Yarn Yarn 可以通…

vs2022 关于Python项目无法识别中文的解决方法

这是针对于vs2022安装和使用教程(详细)-CSDN博客 Python项目无法识别中文的解决方法的文章 一、问题 1.输入代码 print("你好Hello world!") 2.启动,发现代码里有中文报错 二、解决方法 1.选择菜单栏里的工具->…

阿里云服务器ECS经济型e和u1实例规格如何选择?

阿里云服务器u1和e实例有什么区别?ECS通用算力型u1实例是企业级独享型云服务器,ECS经济型e实例是共享型云服务器,所以相比较e实例,云服务器u1性能更好一些。e实例为共享型云服务器,共享型实例采用非绑定CPU调度模式&am…

线程的等待通知机制

线程的等待通知机制 一:情景再现:二:等待通知机制:2.1 wait()方法2.2 notify()方法2.22:唤醒了t2线程,t1线程仍处于阻塞等待状态2.23 唤醒了t1线程,t2线程仍处于阻塞等待状态2.24:notifyAll() 一:情景再现: 假设有3个滑稽,1号滑稽在ATM中取钱,2,3号滑稽只能在门口阻塞等待,1号…

VuePress基于 Vite 和 Vue 构建优秀框架

VitePress 是一个静态站点生成器 (SSG),专为构建快速、以内容为中心的站点而设计。简而言之,VitePress 获取用 Markdown 编写的内容,对其应用主题,并生成可以轻松部署到任何地方的静态 HTML 页面。 VitePress 附带一个用于技术文档…

Vmware下减小Ubuntu系统占用系统盘大小

1、虚拟机设置下占用空间 如图,给虚拟机分配了120GB,已经占用116.9GB,开机会提示空间不足。 2、实际使用空间 ubuntu系统下使用“df -h”命令查看实际使用空间大小50GB左右 造成这个原因是,虚拟机的bug:在虚拟机的ub…

【递归】有序分数(SBT)

给定一个整数 N,请你求出所有分母小于或等于 N,大小在 [0,1][0,1] 范围内的最简分数,并按从小到大顺序依次输出。 例如,当 N5时,所有满足条件的分数按顺序依次为: 0/1,1/5,1/4,1/3,2/5,1/2,3/5,2/3,3/4,4…

二叉树寻找祖先问题-算法通关村

二叉树寻找祖先问题-算法通关村 1 最近公共祖先问题 LeetCode236:给定一个二叉树,找到该树中两个指定节点的最近公共祖先。 最近公共祖先的定义为:“对于有根树T 的两个节点 p、q,最近公共祖先表示为一个节点 x,满足是…

Docket常见的软件部署1

1 安装MySQL # 查看MySQL镜像 docker search mysql # 拉起镜像 docker pull mysql:5.7 # 创建MySQL数据映射卷,防止数据不丢失 mkdir -p /hmoe/tem/docker/mysql/data/ # 启动镜像 docker run -d --name mysql -e MYSQL_ROOT_PASSWORD123456 -p 3306:3306 -v /home…

7_springboot_shiro_jwt_多端认证鉴权_自定义AuthenticationToken

1. 目标 ​ 本小节会先对Shiro的核心流程进行一次回顾,并进行梳理。然后会介绍如果应用是以API接口的方式提供给它方进行调用,那么在这种情况下如何使用Shiro框架来完成接口调用的认证和授权。 2. 核心架构 引用官方的架构图: 2.1 Subje…

蓝桥杯第十五届抱佛脚(八)并查集

蓝桥杯第十五届抱佛脚(八)并查集 基本概念 并查集是一种数据结构,用于管理一系列不交集的元素集合,并支持两种操作: 查找(Find): 查找操作用于确定某个元素属于哪个集合&#xf…

Topaz Photo AI for Mac v2.4.2 智能AI降噪软件

Topaz Photo AI是一款适用于Mac的图像处理软件,使用人工智能技术对照片进行编辑和优化。该软件提供了多种强大的功能,包括降噪、锐化、消除噪点、提高分辨率等,可以帮助用户改善图像质量,并实现自定义的效果。 软件下载&#xff1…

前端-html-02

1.列表 标签名功能和语义属性单标签还是双标签ul无序列表包裹元素双标签 ol 有序列表包裹元素双标签li列表项双标签dl定义列表包裹元素双标签dt定义列表项标题双标签dd定义列表项描述双标签 li必须由Ul或者ol包裹 <!DOCTYPE html> <html><head><…

Web APIs知识点讲解(阶段七)

正则表达式 1.能够利用正则表达式校验输入信息的合法性2. 具备利用正则表达式验证小兔鲜注册页面表单的能力 一.正则表达式 1.正则表达式 正则表达式&#xff08;Regular Expression&#xff09;是用于匹配字符串中字符组合的模式。在 JavaScript中&#xff0c;正则表达式也…

我们正在被 DDoS 攻击,但是我们啥也不干,随便攻击...

最近&#xff0c;一场激烈的攻防大战在网络世界悄然上演。 主角不是什么国家安全局或者黑客组织&#xff0c;而是一家名不见经传的创业公司——TablePlus。 DDoS 攻击者们摩拳擦掌&#xff0c;跃跃欲试。他们从四面八方蜂拥而至&#xff0c;誓要用数亿次请求把 TablePlus 的服…

Redis 常见数据结构及命令

目录 一.Redis常见的数据结构 二.Redis数据结构对应的命令 1.String类型 2.Hash类型 3.List类型 4.Set类型 5.Sorted Set类型 一.Redis常见的数据结构 Redis支持多种数据结构&#xff0c;包括字符串&#xff08;string&#xff09;、哈希&#xff08;hash&#xff09;、…

STM32的芯片无法在线调试的情况分析

问题描述 本博客的目的在于帮助网友尽快地解决问题&#xff0c; 避免浪费时间&#xff0c; 查漏补缺。 在stm32的开发过程中&#xff0c;有时会遇到"STM No Target connected"的错误提示&#xff0c;这说明MDK开发环境无法与目标设备进行通信&#xff0c;导致无法烧…

【JavaSE】类和对象详解(上)

欢迎关注个人主页&#xff1a;逸狼 创造不易&#xff0c;可以点点赞吗~ 如有错误&#xff0c;欢迎指出~ 目录 类和对象 类的组成 对类的理解 成员变量的访问和类方法的调用 this 抛出一个问题 this的作用 初始化成员变量 未初始化的成员变量 代码举例 就地初始化 构…

Autodesk Maya 2025 mac玛雅三维动画特效软件

Autodesk Maya 2025 for Mac是一款功能强大、操作简便的三维动画软件&#xff0c;适用于电影、电视、游戏、建筑、工业设计、虚拟现实和动画等领域。无论是专业设计师还是初学者&#xff0c;都可以通过Maya 2025实现自己的创意和想法&#xff0c;创作出高质量的三维作品。 软件…

浅谈Spring体系的理解

浅谈Spring知识体系 Spring Framework架构图Spring家族技术生态全景图XMind汇总 本文不涉及细节&#xff0c;主要回答两个问题&#xff1a; Spring家族技术生态全景图有哪些Spring Framework架构下每个模块有哪些东西&#xff0c;以及部分模块之间的关联关系 Spring Framework架…