【NLP练习】中文文本分类-Pytorch实现

news2025/1/6 19:57:30

中文文本分类-Pytorch实现

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

一、准备工作

1. 任务说明

本次使用Pytorch实现中文文本分类。主要代码与文本分类代码基本一致,不同的是本次任务使用了本地的中文数据,数据示例如下:
在这里插入图片描述

2.加载数据

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms,datasets
import os,PIL,pathlib,warnings

warnings.filterwarnings("ignore")   #忽略警告信息

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

输出:

device(type='cpu')
import pandas as pd
 #加载自定义中文数据
train_data = pd.read_csv('./train.csv',sep='\t',header = None)
train_data.head()

输出:
在这里插入图片描述

#构造数据集迭代器
def coustom_data_iter(texts,labels):
 for x,y in zip(texts,labels):
  yield x,y
 
train_iter =coustom_data_iter(train_data[0].values[:],train_data[1].values[:])

二、数据预处理

#构建词典
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import jieba

#中文分词方法
tokenizer = jieba.lcut
def yield_tokens(data_iter):
    for text,_ in data_iter:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(train_iter),
                                 specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])   #设置默认索引,如果找不到单词,则会选择默认索引
vocab(['我','想','看','和平','精英','上','战神','必备','技巧','的','游戏','视频'])

输出:

[2, 10, 13, 973, 1079, 146, 7724, 7574, 7793, 1, 186, 28]
label_name = list(set(train_data[1].values[:]))
print(label_name)

输出:

['FilmTele-Play', 'Alarm-Update', 'Weather-Query', 'Audio-Play', 'Radio-Listen', 'Travel-Query', 'Music-Play', 'Video-Play', 'HomeAppliance-Control', 'Calendar-Query', 'TVProgram-Play', 'Other']
text_pipeline = lambda x : vocab(tokenizer(x))
label_pipeline = lambda x : label_name.index(x)

print(text_pipeline('我想看和平精英上战神必备技巧的游戏视频'))
print(label_pipeline('Video-Play'))

输出:

[2, 10, 13, 973, 1079, 146, 7724, 7574, 7793, 1, 186, 28]
7

lambda表达式的语法为:lambda arguments: expression
其中arguments是函数的参数,可以有多个参数,用逗号分隔。expression是一个表达式,它定义了函数的返回值。

  • text_pipeline函数: 将原始文本数据转换为整数列表,使用了之前构建的vocab词表和tokenizer分词器函数。具体步骤:
  1. 接受一个字符串x作为输入
  2. 使用tokenizer将其分词
  3. 将每个词在vocab词表中的索引放入一个列表返回
  • label_pipeline函数: 将原始标签数据转换为整数,它接受一个字符串x作为输入,并使用 label_index.index(x) 方法获取x在label_name列表中的索引作为输出。

2.生成数据批次和迭代器

#生成数据批次和迭代器
from torch.utils.data import DataLoader

def collate_batch(batch):
    label_list, text_list, offsets = [],[],[0]         
    for(_text, _label) in batch:
        #标签列表
        label_list.append(label_pipeline(_label))
        #文本列表
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        #偏移量
        offsets.append(processed_text.size(0))

    label_list = torch.tensor(label_list,dtype=torch.int64)
    text_list = torch.cat(text_list)
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)       #返回维度dim中输入元素的累计和
    return text_list.to(device), label_list.to(device), offsets.to(device)

#数据加载器
dataloader = DataLoader(
    train_iter,
    batch_size = 8,
    shuffle = False,
    collate_fn = collate_batch
)

三、模型构建

1. 搭建模型

#搭建模型
from torch import nn

class TextClassificationModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super(TextClassificationModel,self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size,      #词典大小
                                        embed_dim,        # 嵌入的维度
                                        sparse=False)     #
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

2. 初始化模型

#初始化模型
#定义实例
num_class = len(label_name)
vocab_size = len(vocab)
em_size = 64
model = TextClassificationModel(vocab_size, em_size, num_class).to(device)

3. 定义训练与评估函数

#定义训练与评估函数
import time

def train(dataloader):
    model.train()          #切换为训练模式
    total_acc, train_loss, total_count = 0,0,0
    log_interval = 50
    start_time = time.time()
    for idx, (text,label, offsets) in enumerate(dataloader):
        predicted_label = model(text, offsets)
        optimizer.zero_grad()                             #grad属性归零
        loss = criterion(predicted_label, label)          #计算网络输出和真实值之间的差距,label为真
        loss.backward()                                   #反向传播
        torch.nn.utils.clip_grad_norm_(model.parameters(),0.1)  #梯度裁剪
        optimizer.step()                                  #每一步自动更新
        
        #记录acc与loss
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        train_loss += loss.item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('|epoch{:d}|{:4d}/{:4d} batches|train_acc{:4.3f} train_loss{:4.5f}'.format(
                epoch,
                idx,
                len(dataloader),
                total_acc/total_count,
                train_loss/total_count))
            total_acc,train_loss,total_count = 0,0,0
            staet_time = time.time()

def evaluate(dataloader):
    model.eval()      #切换为测试模式
    total_acc,train_loss,total_count = 0,0,0
    with torch.no_grad():
        for idx,(text,label,offsets) in enumerate(dataloader):
            predicted_label = model(text, offsets)
            loss = criterion(predicted_label,label)   #计算loss值
            #记录测试数据
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            train_loss += loss.item()
            total_count += label.size(0)
    
    return total_acc/total_count, train_loss/total_count

四、训练模型

1. 拆分数据集并运行模型

#拆分数据集并运行模型
from torch.utils.data.dataset   import random_split
from torchtext.data.functional  import to_map_style_dataset

# 超参数设定
EPOCHS      = 10   #epoch
LR          = 5    #learningRate
BATCH_SIZE  = 64   #batch size for training

#设置损失函数、选择优化器、设置学习率调整函数
criterion   = torch.nn.CrossEntropyLoss()
optimizer   = torch.optim.SGD(model.parameters(), lr = LR)
scheduler   = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma = 0.1)
total_accu  = None

# 构建数据集
train_iter = custom_data_iter(train_data[0].values[:],train_data[1].values[:])
train_dataset   = to_map_style_dataset(train_iter)
split_train_, split_valid_ = random_split(train_dataset,
                                         [int(len(train_dataset)*0.8),int(len(train_dataset)*0.2)])

                                           
train_dataloader    = DataLoader(split_train_, batch_size = BATCH_SIZE, shuffle = True, collate_fn = collate_batch)
valid_dataloader    = DataLoader(split_valid_, batch_size = BATCH_SIZE, shuffle = True, collate_fn = collate_batch)

for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader)
    val_acc, val_loss = evaluate(valid_dataloader)
    #获取当前的学习率
    lr = optimizer.state_dict()['param_groups'][0]['lr']
    if total_accu is not None and total_accu > val_acc:
        scheduler.step()
    else:
        total_accu = val_acc
    print('-' * 69)
    print('| epoch {:d} | time:{:4.2f}s | valid_acc {:4.3f} valid_loss {:4.3f}'.format(
        epoch,
        time.time() - epoch_start_time,
        val_acc,
        val_loss))
    print('-' * 69)

输出:

['还有双鸭山到淮阴的汽车票吗13号的' '从这里怎么回家' '随便播放一首专辑阁楼里的佛里的歌' ...
 '黎耀祥陈豪邓萃雯畲诗曼陈法拉敖嘉年杨怡马浚伟等到场出席' '百事盖世群星星光演唱会有谁' '下周一视频会议的闹钟帮我开开']
|epoch1|  50/ 152 batches|train_acc0.953 train_loss0.00282
|epoch1| 100/ 152 batches|train_acc0.953 train_loss0.00271
|epoch1| 150/ 152 batches|train_acc0.952 train_loss0.00292
---------------------------------------------------------------------
| epoch 1 | time:5.50s | valid_acc 0.949 valid_loss 0.003
---------------------------------------------------------------------
|epoch2|  50/ 152 batches|train_acc0.961 train_loss0.00231
|epoch2| 100/ 152 batches|train_acc0.967 train_loss0.00204
|epoch2| 150/ 152 batches|train_acc0.963 train_loss0.00228
---------------------------------------------------------------------
| epoch 2 | time:5.06s | valid_acc 0.949 valid_loss 0.003
---------------------------------------------------------------------
|epoch3|  50/ 152 batches|train_acc0.975 train_loss0.00173
|epoch3| 100/ 152 batches|train_acc0.973 train_loss0.00177
|epoch3| 150/ 152 batches|train_acc0.972 train_loss0.00166
---------------------------------------------------------------------
| epoch 3 | time:5.07s | valid_acc 0.948 valid_loss 0.003
---------------------------------------------------------------------
|epoch4|  50/ 152 batches|train_acc0.984 train_loss0.00137
|epoch4| 100/ 152 batches|train_acc0.987 train_loss0.00123
|epoch4| 150/ 152 batches|train_acc0.983 train_loss0.00119
---------------------------------------------------------------------
| epoch 4 | time:5.07s | valid_acc 0.950 valid_loss 0.003
---------------------------------------------------------------------
|epoch5|  50/ 152 batches|train_acc0.985 train_loss0.00125
|epoch5| 100/ 152 batches|train_acc0.987 train_loss0.00119
|epoch5| 150/ 152 batches|train_acc0.986 train_loss0.00120
---------------------------------------------------------------------
| epoch 5 | time:5.03s | valid_acc 0.949 valid_loss 0.003
---------------------------------------------------------------------
|epoch6|  50/ 152 batches|train_acc0.985 train_loss0.00118
|epoch6| 100/ 152 batches|train_acc0.989 train_loss0.00114
|epoch6| 150/ 152 batches|train_acc0.985 train_loss0.00120
---------------------------------------------------------------------
| epoch 6 | time:5.40s | valid_acc 0.949 valid_loss 0.003
---------------------------------------------------------------------
|epoch7|  50/ 152 batches|train_acc0.984 train_loss0.00119
|epoch7| 100/ 152 batches|train_acc0.986 train_loss0.00119
|epoch7| 150/ 152 batches|train_acc0.989 train_loss0.00112
---------------------------------------------------------------------
| epoch 7 | time:5.71s | valid_acc 0.949 valid_loss 0.003
---------------------------------------------------------------------
|epoch8|  50/ 152 batches|train_acc0.985 train_loss0.00115
|epoch8| 100/ 152 batches|train_acc0.986 train_loss0.00128
|epoch8| 150/ 152 batches|train_acc0.989 train_loss0.00107
---------------------------------------------------------------------
| epoch 8 | time:5.22s | valid_acc 0.949 valid_loss 0.003
---------------------------------------------------------------------
|epoch9|  50/ 152 batches|train_acc0.988 train_loss0.00114
|epoch9| 100/ 152 batches|train_acc0.983 train_loss0.00127
|epoch9| 150/ 152 batches|train_acc0.989 train_loss0.00109
---------------------------------------------------------------------
| epoch 9 | time:5.28s | valid_acc 0.949 valid_loss 0.003
---------------------------------------------------------------------
|epoch10|  50/ 152 batches|train_acc0.986 train_loss0.00115
|epoch10| 100/ 152 batches|train_acc0.987 train_loss0.00117
|epoch10| 150/ 152 batches|train_acc0.986 train_loss0.00119
---------------------------------------------------------------------
| epoch 10 | time:5.22s | valid_acc 0.949 valid_loss 0.003
---------------------------------------------------------------------
test_acc,test_loss = evaluate(valid_dataloader)
print('模型准确率为:{:5.4f}'.format(test_acc))

输出:

模型准确率为:0.9492

2. 测试指定数据

#测试指定的数据
def predict(text, text_pipeline):
    with torch.no_grad():
        text = torch.tensor(text_pipeline(text))
        output = model(text, torch.tensor([0]))
        return output.argmax(1).item()

ex_text_str = "还有双鸭山到淮阴的汽车票吗13号的"
model = model.to("cpu")

print("该文本的类别是: %s" %label_name[predict(ex_text_str,text_pipeline)])

输出:

该文本的类别是: Travel-Query

五、总结

训练神经网络时,可使用梯度裁剪的方法来防止梯度爆炸,使得模型训练更加稳定

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1570852.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[中级]软考_软件设计_计算机组成与体系结构_07_存储系统

存储系统 层次划存储概念图局促性原理分类存储器位置存取方式按内容存储按地址存储 工作方式拓展 往年真题 高速缓存(cache)概念案例解析&#xff1a;求取平均时间 Cache与主存的地址映射映像往年真题 主存编制计算编址大小的求取编址与计算存储单元编址内容总容量求取例题解析…

c# wpf template itemtemplate+dataGrid

1.概要 2.代码 <Window x:Class"WpfApp2.Window8"xmlns"http://schemas.microsoft.com/winfx/2006/xaml/presentation"xmlns:x"http://schemas.microsoft.com/winfx/2006/xaml"xmlns:d"http://schemas.microsoft.com/expression/blend…

[C#]OpenCvSharp使用帧差法或者三帧差法检测移动物体

关于C版本帧差法可以参考博客 [C]OpenCV基于帧差法的运动检测-CSDN博客https://blog.csdn.net/FL1768317420/article/details/137397811?spm1001.2014.3001.5501 我们将参考C版本转成opencvsharp版本。 帧差法&#xff0c;也叫做帧间差分法&#xff0c;这里引用百度百科上的…

【力扣每日一题】1026. 节点与其祖先之间的最大差值

LC 1026. 节点与其祖先之间的最大差值 题目描述 给定二叉树的根节点 root&#xff0c;找出存在于 不同 节点 A 和 B 之间的最大值 V&#xff0c;其中 V |A.val - B.val|&#xff0c;且 A 是 B 的祖先。 &#xff08;如果 A 的任何子节点之一为 B&#xff0c;或者 A 的任何子…

https证书申请方式

网站HTTPS证书&#xff0c;也称为SSL证书或TLS证书&#xff0c;是一种数字证书&#xff0c;用于在用户浏览器与网站服务器之间建立安全的加密连接。当网站安装了HTTPS证书后&#xff0c;用户访问该网站时&#xff0c;浏览器地址栏会显示为"https://"开头&#xff0c;…

CSS层叠样式表学习(文本属性)

&#xff08;大家好&#xff0c;今天我们将继续来学习CSS文本属性的相关知识&#xff0c;大家可以在评论区进行互动答疑哦~加油&#xff01;&#x1f495;&#xff09; 目录 四、CSS文本属性 4.1 文本颜色 4.2 对齐文本 4.3 装饰文本 4.4 文本缩进 4.5 行间距 4.6 文本…

简单的购物商城

SSM整合后的一个及其简单的商城&#xff0c;首页数据是模拟的&#xff0c;主要测试购物车模块 启动 创建数据库&#xff1a;shopping导入建表脚本&#xff1a;shopping.sql修改db.properties部署和启动项目&#xff08;项目的path为项目名&#xff09;访问 http://localhost:…

Python语言在地球科学领域中的应用

Python是功能强大、免费、开源&#xff0c;实现面向对象的编程语言&#xff0c;Python能够运行在Linux、Windows、Macintosh、AIX操作系统上及不同平台&#xff08;x86和arm&#xff09;&#xff0c;Python简洁的语法和对动态输入的支持&#xff0c;再加上解释性语言的本质&…

猫头虎技术分享 || 断网了,还能ping127.0.0.1吗?

博主猫头虎的技术世界 &#x1f31f; 欢迎来到猫头虎的博客 — 探索技术的无限可能&#xff01; 专栏链接&#xff1a; &#x1f517; 精选专栏&#xff1a; 《面试题大全》 — 面试准备的宝典&#xff01;《IDEA开发秘籍》 — 提升你的IDEA技能&#xff01;《100天精通鸿蒙》 …

Shell GPT:直接安装使用的chatgpt应用软件

ShellGPT是一款基于预训练生成式Transformer模型&#xff08;如GPT系列&#xff09;构建的智能Shell工具。它将先进的自然语言处理能力集成到Shell环境中&#xff0c;使用户能够使用接近日常对话的语言来操作和控制操作系统。 官网&#xff1a;GitHub - akl7777777/ShellGPT: *…

liteIDE自定义主题推荐

代码编辑器配色 \liteidex38.3-win64-qt5.15.2\liteide\share\liteide\liteeditor\color <?xml version"1.0" encoding"UTF-8"?> <style-scheme version"1.0" name"Sublime Text 2"><style name"Text" f…

WebGL BabylonJS GUI 如何创建连接模型的按钮

如图所示&#xff1a; 方法&#xff1a; createGUI(mesh: BABYLON.Mesh, title: string, index: number) {const advancedTexture AdvancedDynamicTexture.CreateFullscreenUI(UI)const rect new Rectangle()rect.width 100pxrect.height 40pxrect.thickness 0advancedT…

MyBatis 使用入门

1. 什么是MyBatis MyBatis是一款持久层框架&#xff0c;用于简化JDBC的开发&#xff08;持久层指的就是持久化操作的层&#xff0c;通常指数据访问层&#xff08;dao&#xff09;&#xff0c;即用于操作数据库&#xff09;&#xff0c;简单来说MyBatis 是更简单完成程序和数据…

C++入门4.引用

目录 1.引用概念&#xff1a; 2.引用特性&#xff1a; 3.常引用&#xff1a; 4.使用场景&#xff1a; 引用和指针的区别&#xff1a; 1.引用概念&#xff1a; 引用不是新定义一个变量&#xff0c;而是给已存在变量取了一个别名&#xff0c;编译器不会为引用变量开辟内存空…

【C++】模拟实现红黑树(插入)

目录 红黑树的概念 红黑树的性质 红黑树的调整情况 红黑树的模拟实现 枚举类型的定义 红黑树节点的定义 插入函数的实现 旋转函数的实现 左旋 右旋 自检函数的实现 红黑树类 红黑树的概念 红黑树&#xff0c;是一种二叉搜索树&#xff0c;但在每个结点上增加一个存储…

详解protected访问限定符

1.同一个包中的同一类 package demo1;public class Test1 {protected int a 10;protected void b() {System.out.println("这是protected修饰的成员方法");}public static void main(String[] args) {Test1 test new Test1();System.out.println(test.a);test.b()…

燃气管网安全运行监测系统功能介绍

燃气管网&#xff0c;作为城市基础设施的重要组成部分&#xff0c;其安全运行直接关系到居民的生命财产安全和城市的稳定发展。然而&#xff0c;随着城市规模的不断扩大和燃气使用量的增加&#xff0c;燃气管网的安全运行面临着越来越大的挑战。为了应对这些挑战&#xff0c;燃…

华媒舍:3个科学指导,协助油管大V写下爆款文章

油管&#xff08;YouTube&#xff09;作为一个重要的视频分享平台&#xff0c;吸引了很多的观众和原创者。作为一位油管大V&#xff0c;你可能会一直在努力提升自己的文章质量以吸引更多的观众和订阅者。下面我们就为您提供三个科学指导&#xff0c;帮助自己写下更具有爆品发展…

可视化图表组件:仪表盘,监控数据关键指标信息,海量示例。

仪表盘组件&#xff08;Dashboard Component&#xff09;是一种常见的可视化设计组件&#xff0c;用于展示和监控数据的关键指标和信息。它通常以仪表盘的形式呈现&#xff0c;类似于汽车仪表盘&#xff0c;可以通过各种图表、指示器和控件来展示数据&#xff0c;并提供交互和操…

Linux安装JDK17等通用教程

一、查看Linux系统是否有自带的jdk 1、查看当前是否有jdk版本&#xff0c;命令如下&#xff1a; java -version2、检测jdk的安装包&#xff0c;命令如下&#xff1a; rpm -qa | grep java3、接着进行一个个删除包&#xff0c;命令如下&#xff1a; rpm -e --nodeps 包名4、…