第N4周:中文文本分类-Pytorch实现

news2025/2/22 20:09:15
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制
  • 🚀 文章来源:K同学的学习圈子

目录

一、准备工作

1.任务说明

 文本分类流程图:

 2.加载数据

​编辑 二、数据的预处理

1.构建词典

2.生成数据批次和迭代器

三、模型构建

四、训练模型

五、小结


一、准备工作

1.任务说明

本次将使用PyTorch实现中文文本分类。主要代码与N1周基本一致,不同的是本次任务中使用了本地的中文数据,数据示例如下:

本周任务:

1.学习如何进行中文本文预处理

2.根据文本内容(第1列)预测文本标签(第2列)

进阶任务:

1.尝试根据第一周的内容独立实现,尽可能的不看本文的代码

2.构建更复杂的网络模型,将准确率提升至91% 

 文本分类流程图:

 2.加载数据

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms,datasets
import os,PIL,pathlib,warnings

warnings.filterwarnings("ignore")   #忽略警告信息

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

import pandas as pd
 #加载自定义中文数据
train_data = pd.read_csv('./train.csv',sep='\t',header = None)
#构造数据集迭代器
def coustom_data_iter(texts,labels):
 for x,y in zip(texts,labels):
  yield x,y
 
train_iter =coustom_data_iter(train_data[0].values[:],train_data[1].values[:])

输出: 

 二、数据的预处理

1.构建词典

#构建词典
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import jieba

#中文分词方法
tokenizer = jieba.lcut
def yield_tokens(data_iter):
    for text,_ in data_iter:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(train_iter),
                                 specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])   #设置默认索引,如果找不到单词,则会选择默认索引
vocab(['我','想','看','和平','精英','上','战神','必备','技巧','的','游戏','视频'])
label_name = list(set(train_data[1].values[:]))
print(label_name)
text_pipeline = lambda x : vocab(tokenizer(x))
label_pipeline = lambda x : label_name.index(x)

print(text_pipeline('我想看和平精英上战神必备技巧的游戏视频'))
print(label_pipeline('Video-Play'))

输出:

2.生成数据批次和迭代器

#生成数据批次和迭代器
from torch.utils.data import DataLoader

def collate_batch(batch):
    label_list, text_list, offsets = [],[],[0]         
    for(_text, _label) in batch:
        #标签列表
        label_list.append(label_pipeline(_label))
        #文本列表
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        #偏移量
        offsets.append(processed_text.size(0))

    label_list = torch.tensor(label_list,dtype=torch.int64)
    text_list = torch.cat(text_list)
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)       #返回维度dim中输入元素的累计和
    return text_list.to(device), label_list.to(device), offsets.to(device)

#数据加载器
dataloader = DataLoader(
    train_iter,
    batch_size = 8,
    shuffle = False,
    collate_fn = collate_batch
)

三、模型构建

#搭建模型
from torch import nn

class TextClassificationModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super(TextClassificationModel,self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size,      #词典大小
                                        embed_dim,        # 嵌入的维度
                                        sparse=False)     #
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)
#初始化模型
#定义实例
num_class = len(label_name)
vocab_size = len(vocab)
em_size = 64
model = TextClassificationModel(vocab_size, em_size, num_class).to(device)
#定义训练与评估函数
import time

def train(dataloader):
    model.train()          #切换为训练模式
    total_acc, train_loss, total_count = 0,0,0
    log_interval = 50
    start_time = time.time()
    for idx, (text,label, offsets) in enumerate(dataloader):
        predicted_label = model(text, offsets)
        optimizer.zero_grad()                             #grad属性归零
        loss = criterion(predicted_label, label)          #计算网络输出和真实值之间的差距,label为真
        loss.backward()                                   #反向传播
        torch.nn.utils.clip_grad_norm_(model.parameters(),0.1)  #梯度裁剪
        optimizer.step()                                  #每一步自动更新
        
        #记录acc与loss
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        train_loss += loss.item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('|epoch{:d}|{:4d}/{:4d} batches|train_acc{:4.3f} train_loss{:4.5f}'.format(
                epoch,
                idx,
                len(dataloader),
                total_acc/total_count,
                train_loss/total_count))
            total_acc,train_loss,total_count = 0,0,0
            staet_time = time.time()

def evaluate(dataloader):
    model.eval()      #切换为测试模式
    total_acc,train_loss,total_count = 0,0,0
    with torch.no_grad():
        for idx,(text,label,offsets) in enumerate(dataloader):
            predicted_label = model(text, offsets)
            loss = criterion(predicted_label,label)   #计算loss值
            #记录测试数据
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            train_loss += loss.item()
            total_count += label.size(0)
    
    return total_acc/total_count, train_loss/total_count

四、训练模型

#拆分数据集并运行模型
from torch.utils.data.dataset   import random_split
from torchtext.data.functional  import to_map_style_dataset

# 超参数设定
EPOCHS      = 10   #epoch
LR          = 5    #learningRate
BATCH_SIZE  = 64   #batch size for training

#设置损失函数、选择优化器、设置学习率调整函数
criterion   = torch.nn.CrossEntropyLoss()
optimizer   = torch.optim.SGD(model.parameters(), lr = LR)
scheduler   = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma = 0.1)
total_accu  = None

# 构建数据集
train_iter = custom_data_iter(train_data[0].values[:],train_data[1].values[:])
train_dataset   = to_map_style_dataset(train_iter)
split_train_, split_valid_ = random_split(train_dataset,
                                         [int(len(train_dataset)*0.8),int(len(train_dataset)*0.2)])

                                           
train_dataloader    = DataLoader(split_train_, batch_size = BATCH_SIZE, shuffle = True, collate_fn = collate_batch)
valid_dataloader    = DataLoader(split_valid_, batch_size = BATCH_SIZE, shuffle = True, collate_fn = collate_batch)

for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader)
    val_acc, val_loss = evaluate(valid_dataloader)
    #获取当前的学习率
    lr = optimizer.state_dict()['param_groups'][0]['lr']
    if total_accu is not None and total_accu > val_acc:
        scheduler.step()
    else:
        total_accu = val_acc
    print('-' * 69)
    print('| epoch {:d} | time:{:4.2f}s | valid_acc {:4.3f} valid_loss {:4.3f}'.format(
        epoch,
        time.time() - epoch_start_time,
        val_acc,
        val_loss))
    print('-' * 69)
test_acc,test_loss = evaluate(valid_dataloader)
print('模型准确率为:{:5.4f}'.format(test_acc))
#测试指定的数据
def predict(text, text_pipeline):
    with torch.no_grad():
        text = torch.tensor(text_pipeline(text))
        output = model(text, torch.tensor([0]))
        return output.argmax(1).item()

ex_text_str = "还有双鸭山到淮阴的汽车票吗13号的"
model = model.to("cpu")

print("该文本的类别是: %s" %label_name[predict(ex_text_str,text_pipeline)])

输出:

|epoch1|  50/ 152 batches|train_acc0.431 train_loss0.03045
|epoch1| 100/ 152 batches|train_acc0.700 train_loss0.01936
|epoch1| 150/ 152 batches|train_acc0.768 train_loss0.01370
---------------------------------------------------------------------
| epoch 1 | time:1.58s | valid_acc 0.789 valid_loss 0.012
---------------------------------------------------------------------
|epoch2|  50/ 152 batches|train_acc0.818 train_loss0.01030
|epoch2| 100/ 152 batches|train_acc0.831 train_loss0.00932
|epoch2| 150/ 152 batches|train_acc0.850 train_loss0.00811
---------------------------------------------------------------------
| epoch 2 | time:1.47s | valid_acc 0.837 valid_loss 0.008
---------------------------------------------------------------------
|epoch3|  50/ 152 batches|train_acc0.870 train_loss0.00688
|epoch3| 100/ 152 batches|train_acc0.887 train_loss0.00658
|epoch3| 150/ 152 batches|train_acc0.893 train_loss0.00575
---------------------------------------------------------------------
| epoch 3 | time:1.46s | valid_acc 0.866 valid_loss 0.007
---------------------------------------------------------------------
|epoch4|  50/ 152 batches|train_acc0.906 train_loss0.00507
|epoch4| 100/ 152 batches|train_acc0.918 train_loss0.00468
|epoch4| 150/ 152 batches|train_acc0.915 train_loss0.00478
---------------------------------------------------------------------
| epoch 4 | time:1.47s | valid_acc 0.886 valid_loss 0.006
---------------------------------------------------------------------
|epoch5|  50/ 152 batches|train_acc0.938 train_loss0.00378
|epoch5| 100/ 152 batches|train_acc0.935 train_loss0.00379
|epoch5| 150/ 152 batches|train_acc0.932 train_loss0.00376
---------------------------------------------------------------------
| epoch 5 | time:1.51s | valid_acc 0.890 valid_loss 0.006
---------------------------------------------------------------------
|epoch6|  50/ 152 batches|train_acc0.951 train_loss0.00310
|epoch6| 100/ 152 batches|train_acc0.952 train_loss0.00287
|epoch6| 150/ 152 batches|train_acc0.950 train_loss0.00289
---------------------------------------------------------------------
| epoch 6 | time:1.50s | valid_acc 0.894 valid_loss 0.006
---------------------------------------------------------------------
|epoch7|  50/ 152 batches|train_acc0.963 train_loss0.00233
|epoch7| 100/ 152 batches|train_acc0.963 train_loss0.00244
|epoch7| 150/ 152 batches|train_acc0.965 train_loss0.00222
---------------------------------------------------------------------
| epoch 7 | time:1.49s | valid_acc 0.898 valid_loss 0.005
---------------------------------------------------------------------
|epoch8|  50/ 152 batches|train_acc0.975 train_loss0.00183
|epoch8| 100/ 152 batches|train_acc0.976 train_loss0.00176
|epoch8| 150/ 152 batches|train_acc0.971 train_loss0.00188
---------------------------------------------------------------------
| epoch 8 | time:1.67s | valid_acc 0.900 valid_loss 0.005
---------------------------------------------------------------------
|epoch9|  50/ 152 batches|train_acc0.982 train_loss0.00145
|epoch9| 100/ 152 batches|train_acc0.982 train_loss0.00139
|epoch9| 150/ 152 batches|train_acc0.980 train_loss0.00141
---------------------------------------------------------------------
| epoch 9 | time:2.05s | valid_acc 0.901 valid_loss 0.006
---------------------------------------------------------------------
|epoch10|  50/ 152 batches|train_acc0.990 train_loss0.00108
|epoch10| 100/ 152 batches|train_acc0.984 train_loss0.00119
|epoch10| 150/ 152 batches|train_acc0.986 train_loss0.00105
---------------------------------------------------------------------
| epoch 10 | time:1.98s | valid_acc 0.900 valid_loss 0.005
---------------------------------------------------------------------
模型准确率为:0.8996
该文本的类别是: Travel-Query

五、小结

  • 数据加载
    • 定义一个生成器函数,将文本和标签成对迭代。这是为了后续的数据处理和加载做准备。
  • 分词与词汇表
    • 使用jieba进行中文分词,jieba.lcut可以将中文文本切割成单个词语列表。
    • 使用torchtextbuild_vocab_from_iterator从分词后的文本中构建词汇表,并设置默认索引为<unk>,表示未知词汇。这对处理未见过的词汇非常重要。
  • 数据管道:创建文本和标签处理管道。
    • 创建两个处理管道:

    • text_pipeline:将文本转换为词汇表中的索引。
    • label_pipeline:将标签转换为索引。
  • 模型构建:定义带嵌入层和全连接层的文本分类模型。
    • 定义一个文本分类模型TextClassificationModel,包括一个嵌入层nn.EmbeddingBag和一个全连接层nn.Linearnn.EmbeddingBag在处理变长序列时性能较好,因为它不需要明确的填充操作。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1824978.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

第2章 Rust初体验7/8:错误处理时不关心具体错误类型的下划线:提高代码可读性:猜骰子冷热游戏

讲动人的故事,写懂人的代码 2.6.6 用as进行类型转换:显式而简洁的语法 贾克强:“大家在查看Rust代码时,可能会注意到这一句。在这里,如果我们不使用as i32,编译器会报错,因为它在u32中找不到abs()方法。这是因为prev和sum_of_two_dice都是u32类型,u32类型并不支持abs(…

【LLM】吴恩达『微调大模型』课程完全笔记

Finetuning Large Language Models 版权说明&#xff1a; 『Finetuning Large Language Models』是DeepLearning.AI出品的免费课程&#xff0c;版权属于DeepLearning.AI(https://www.deeplearning.ai/)。 本文是对该课程内容的翻译整理&#xff0c;只作为教育用途&#xff0c;不…

4-字符串-11-反转字符串-LeetCode344

4-字符串-11-反转字符串-LeetCode344 LeetCode: 题目序号344 更多内容欢迎关注我&#xff08;持续更新中&#xff0c;欢迎Star✨&#xff09; Github&#xff1a;CodeZeng1998/Java-Developer-Work-Note 技术公众号&#xff1a;CodeZeng1998&#xff08;纯纯技术文&#xff0…

【牛客面试必刷TOP101】Day32.BM68 矩阵的最小路径和和BM69 把数字翻译成字符串

文章目录 前言一、BM68 矩阵的最小路径和题目描述题目解析二、BM69 把数字翻译成字符串题目描述题目解析总结 前言 一、BM68 矩阵的最小路径和 题目描述 描述&#xff1a; 给定两个字符串str1和str2&#xff0c;输出两个字符串的最长公共子序列。如果最长公共子序列为空&#x…

Ubuntu系统环境配置

Ubuntu安装terminator以及美化 安装 sudo apt-get install terminator美化 修改或者创建.config/terminator/config文件&#xff0c;添加如下配置 [global_config]suppress_multiple_term_dialog Truetitle_font Sans 11title_use_system_font False [keybindings] [lay…

supOS助力中核陕西铀浓缩有限公司迈向智能化、数字化、绿色化

中核陕西铀浓缩有限公司是中国核工业集团有限公司所属大型生产骨干企业&#xff0c;建于1969年10月&#xff0c;是我国第一座采用离心法生产浓缩铀的工厂。 蓝卓基于supOS工业操作系统&#xff0c;以“三化一平台一基础”的顶层架构&#xff0c;面向陕铀公司工艺优化、设备管理…

【Photoshop】PS修改文字内容

Photoshop(PS)修改图片上文字内容&#xff0c;网上教材不少&#xff0c;本人整理实践过的方法&#xff0c;分享给各位。本人实践方法&#xff1a; 内容识别填充&#xff1a;适用于背景色复杂的图片内容修补工具&#xff1a;适用于背景色为纯色的图片 方式一&#xff1a;内容识…

el-table表头文字换行或者修改字体颜色样式

例如 <el-table:data"tableData":header-cell-style"headClass" style"width: 100%;" border ><el-table-columnprop"address"label"生产工序"align"center"></el-table-column> //重点看这里…

uniapp 开发版小程序之间跳转

uni.navigateToMiniProgram({appId: urL,path: patH,envVersion: release,//我使用develop会给我返回&#xff1a;开发版小程序已过期&#xff0c;请在开发者工具重新扫码确定success(res) {console.log(res);// 打开成功uni.showToast({title: 跳转成功})},fail(err) {console…

希亦、追觅、云鲸洗地机:究竟有何不同?选择哪款更合适

最近收到很多私信里&#xff0c;要求洗地机测评的呼声特别高&#xff0c;作为宠粉的测评博主&#xff0c;当然是马上安排起来&#xff0c;满足大家对想看洗地机的愿望。这次洗地机测评&#xff0c;我挑选了三款热门的品牌型号&#xff0c;并从多个维度对它们进行使用测评&#…

【STM32】CubeIDE下载安装使用全记录

文章目录 0 前言1 下载安装2 基本使用2.0 编译下载2.1 字体和代码高亮设置2.2 快速格式化代码2.3 快速定位函数/变量的声明和定义2.4 设置代码折叠2.5 生成hex文件 3 设置代码自动提示4 设置中文界面5 遇到的问题和解决办法 0 前言 作为ST官方主推的集成开发环境&#xff08;ID…

深入浅出 Go 语言的 GPM 模型(Go1.21)

引言 在现代软件开发中&#xff0c;有效地利用并发是提高应用性能和响应速度的关键。随着多核处理器的普及&#xff0c;编程语言和框架如何高效、简便地支持并发编程&#xff0c;成为了软件工程师们评估和选择工具时的一个重要考量。在这方面&#xff0c;Go 语言凭借其创新的并…

单调栈——AcWing.830单调栈

单调栈 定义 单调栈是一种特殊的数据结构&#xff0c;栈内元素保持某种单调性&#xff08;通常是单调递增或单调递减&#xff09;。 运用情况 求解下一个更大元素或下一个更小元素。计算每个元素左边或右边第一个比它大或小的元素。 注意事项 要明确单调栈是递增还是递减…

七、IP路由原理和路由引入

目录 一、IP路由原理 二、路由引入 2.1、双点双向路由引入 2.2、路由回灌 三、路由策略与路由控制 路由匹配工具&#xff08;规则&#xff09; ACL IP前缀列表 路由控制工具&#xff08;控制&#xff09; 策略工具1 策略工具2 搭配组合 组…

CCAA质量管理【学习笔记】​ 备考知识点笔记(一)

第一部分 质量管理体系相关标准 《质量管理体系基础考试大纲》中规定的考试内容&#xff1a; 3.1质量管理体系标准 a) 了解 ISO 9000 系列标准发展概况&#xff1b; b) 理 解 GB/T19000 标准中涉及的基本概念和质量管理原则&#xff1b; c) 理 解GB/T19000 标准中的部分…

mybatis中SQL语句运用总结

union 连接查询 连接两个表后会过滤掉重复的值 <resultMap id"BaseResultMap" type"com.sprucetec.pay.etl.model.BillDetail"><id column"id" jdbcType"INTEGER" property"id"/><result column"pay_…

Docker中部署Jenkins+Pipline流水线基础语法入门

场景 DockerCompose中部署Jenkins&#xff08;Docker Desktop在windows上数据卷映射&#xff09;&#xff1a; DockerCompose中部署Jenkins&#xff08;Docker Desktop在windows上数据卷映射&#xff09;-CSDN博客 DockerComposeJenkinsPipeline流水线打包SpringBoot项目(解…

深度剖析淘宝扭蛋机源码:打造趣味性电商活动的秘诀

在当今电商市场中&#xff0c;如何吸引用户的注意力、提升用户的参与度成为了各大电商平台竞相追求的目标。淘宝扭蛋机作为一种新型的电商活动形式&#xff0c;以其趣味性和互动性深受用户喜爱。本文将深度剖析淘宝扭蛋机源码&#xff0c;探讨其如何打造趣味性与互动性并存的电…

数仓建模—OLTP 和 OLAP

数仓建模—OLTP 和 OLAP 前面我们在数仓建模—数仓初识 中提到了OLTP 和 OLAP 两个概念 OLAP 是 On-Line Analytical Processing(联机分析处理)的缩写。广义的 OLAP 泛指数据查询分析,像报表、即席查询、多维分析都属于 OLAP 的范畴。 OLTP 和 OLAP 最大区别在于前者会产…

C++入门 vector介绍及使用

目录 vector的介绍及使用 vector常用接口的介绍及使用 vector的定义 vector iterator 的使用 vector 空间增长问题 vector 增删查改 push_back/pop_back insert & erase & find operator[ ]的遍历 vector的介绍及使用 vector的文档介绍 vector是表示可变大…