深度学习每周学习总结N4:中文文本分类-Pytorch实现(基本分类(熟悉流程)、textCNN分类(通用模型)、Bert分类(模型进阶))

news2024/9/21 2:41:48
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

目录

    • 0. 总结:
    • 1. 基础模型
      • a. 数据加载
      • b. 数据预处理
      • c. 模型搭建与初始化
      • d. 训练函数
      • e. 评估函数
      • f.拆分数据集运行模型
      • g. 结果可视化
      • h. 测试指定数据
    • 2. TextCNN(通用模型-待拓展)
    • 3. Bert(高级模型-待拓展)

0. 总结:

之前有学习过文本预处理的环节,对文本处理的主要方式有以下三种:

1:词袋模型(one-hot编码)

2:TF-IDF

3:Word2Vec(词向量(Word Embedding) 以及Word2vec(Word Embedding 的方法之一))

详细介绍及中英文分词详见pytorch文本分类(一):文本预处理

上上期主要介绍Embedding,及EmbeddingBag 使用示例(对词索引向量转化为词嵌入向量) ,上期主要介绍:应用三种模型的英文分类

本期将主要介绍中文基本分类(熟悉流程)、拓展:textCNN分类(通用模型)、拓展:Bert分类(模型进阶)

1. 基础模型

a. 数据加载

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms,datasets

import numpy as np
import pandas as pd

import os,PIL,pathlib,warnings
import matplotlib.pyplot as plt
import warnings

warnings.filterwarnings("ignore") # 忽略警告信息
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False   # 用来正常显示负号
plt.rcParams['figure.dpi'] = 100  # 分辨率
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
device(type='cuda')
# 加载自定义中文数据集
train_data = pd.read_csv('./data/N4-train.csv',sep = '\t',header = None)
train_data.head()
01
0还有双鸭山到淮阴的汽车票吗13号的Travel-Query
1从这里怎么回家Travel-Query
2随便播放一首专辑阁楼里的佛里的歌Music-Play
3给看一下墓王之王嘛FilmTele-Play
4我想看挑战两把s686打突变团竞的游戏视频Video-Play
train_data[0]
0                  还有双鸭山到淮阴的汽车票吗13号的
1                            从这里怎么回家
2                   随便播放一首专辑阁楼里的佛里的歌
3                          给看一下墓王之王嘛
4              我想看挑战两把s686打突变团竞的游戏视频
                    ...             
12095          一千六百五十三加三千一百六十五点六五等于几
12096                      稍小点客厅空调风速
12097    黎耀祥陈豪邓萃雯畲诗曼陈法拉敖嘉年杨怡马浚伟等到场出席
12098                  百事盖世群星星光演唱会有谁
12099                 下周一视频会议的闹钟帮我开开
Name: 0, Length: 12100, dtype: object
# 构建数据迭代器
def custom_data_iter(texts,labels):
    for x,y in zip(texts,labels):
        yield x,y

train_iter = custom_data_iter(train_data[0].values[:],train_data[1].values[:])

b. 数据预处理

# 构建词典
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import jieba

# 中文分词方法
tokenizer = jieba.lcut # jieba.cut返回的是一个生成器,而jieba.lcut返回的是一个列表

def yield_tokens(data_iter):
    for text,_ in data_iter:
        yield tokenizer(text)
        
vocab = build_vocab_from_iterator(yield_tokens(train_iter),specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])
Building prefix dict from the default dictionary ...
Loading model from cache C:\Users\Cheng\AppData\Local\Temp\jieba.cache
Loading model cost 0.549 seconds.
Prefix dict has been built successfully.
tokenizer('我想看和平精英上战神必备技巧的游戏视频')
['我', '想', '看', '和平', '精英', '上', '战神', '必备', '技巧', '的', '游戏', '视频']
vocab(['我','想','看','和平','精英','上','战神','必备','技巧','的','游戏','视频'])
[2, 10, 13, 973, 1079, 146, 7724, 7574, 7793, 1, 186, 28]
text_pipeline = lambda x:vocab(tokenizer(x))
label_pipeline = lambda x:label_name.index(x)
label_name = list(set(train_data[1].values[:]))
print(label_name)
['HomeAppliance-Control', 'Audio-Play', 'Other', 'Weather-Query', 'Music-Play', 'Travel-Query', 'TVProgram-Play', 'Alarm-Update', 'Video-Play', 'Calendar-Query', 'FilmTele-Play', 'Radio-Listen']

print(text_pipeline('我想看和平精英上战神必备技巧的游戏视频'))
print(label_pipeline('Video-Play'))
[2, 10, 13, 973, 1079, 146, 7724, 7574, 7793, 1, 186, 28]
8
# 生成数据批次和迭代器
from torch.utils.data import DataLoader

def collate_batch(batch):
    label_list,text_list,offsets = [],[],[0]
    
    for (_text,_label) in batch:
        # 标签列表
        label_list.append(label_pipeline(_label))
        # 文本列表
        processed_text = torch.tensor(text_pipeline(_text),dtype = torch.int64)
        text_list.append(processed_text)
        # 偏移量(即语句的总词汇量)
        offsets.append(processed_text.size(0))
        
    label_list = torch.tensor(label_list,dtype = torch.int64)
    text_list = torch.cat(text_list)
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0) # 返回维度dim中输入元素的累计和
    
    return text_list.to(device),label_list.to(device),offsets.to(device)

# 数据加载器,调用示例
dataloader = DataLoader(train_iter,
                       batch_size = 8,
                       shuffle = False,
                       collate_fn = collate_batch)

c. 模型搭建与初始化

from torch import nn

class TextClassificationModel(nn.Module):
    
    def __init__(self,vocab_size,embed_dim,num_class):
        super(TextClassificationModel,self).__init__()
        
        self.embedding = nn.EmbeddingBag(vocab_size,  # 词典大小
                                         embed_dim,   # 嵌入维度
                                         sparse=False
                                        )
        self.fc = nn.Linear(embed_dim,num_class)
        self.init_weights()
        
    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange,initrange) # 初始化权重
        self.fc.weight.data.uniform_(-initrange,initrange)
        self.fc.bias.data.zero_()
        
    def forward(self,text,offsets):
        embedded = self.embedding(text,offsets)
        return self.fc(embedded)
# 初始化模型
num_class = len(label_name)
vocab_size = len(vocab)
em_size = 64
model = TextClassificationModel(vocab_size,em_size,num_class).to(device)

d. 训练函数

import time

def train(dataloader):
    size = len(dataloader.dataset) # 训练集的大小
    num_batches = len(dataloader)  # 批次数目, (size/batch_size,向上取整)
    
    train_acc,train_loss = 0,0 # 初始化训练损失和正确率
    
    for idx,(text,label,offsets) in enumerate(dataloader):
        # 计算预测误差
        predicted_label = model(text,offsets)   # 网络输出
        loss = criterion(predicted_label,label) # 计算网络输出和真实值之间的差距,label为真实值,计算二者差值即为损失
        
        # 反向传播
        optimizer.zero_grad() # grad属性归零
        loss.backward() # 反向传播
        torch.nn.utils.clip_grad_norm_(model.parameters(),0.1)
        optimizer.step() # 每一步自动更新
        
        # 记录acc与loss
        train_acc += (predicted_label.argmax(1) == label).sum().item()
        train_loss  += loss.item()
        
    train_acc /= size
    train_loss /= num_batches
    
    return train_acc,train_loss

e. 评估函数

import time

def evaluate(dataloader):
    test_acc,test_loss,total_count = 0,0,0
    
    with torch.no_grad():
        for idx,(text,label,offsets) in enumerate(dataloader):
            # 计算预测误差
            predicted_label = model(text,offsets)
            loss = criterion(predicted_label,label)
            
            # 记录测试数据
            test_acc += (predicted_label.argmax(1) == label).sum().item()
            test_loss += loss.item()
            total_count += label.size(0)
            
    return test_acc/total_count,test_loss/total_count

f.拆分数据集运行模型

from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
# 超参数
EPOCHS     = 10 # epoch
LR         = 5  # 学习率
BATCH_SIZE = 64 # batch size for training

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None

# 构建数据集
train_iter = custom_data_iter(train_data[0].values[:], train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)

split_train_, split_valid_ = random_split(train_dataset,
                                          [int(len(train_dataset)*0.8),int(len(train_dataset)*0.2)])

train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)

valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)
import copy

train_acc = []
train_loss = []
test_acc = []
test_loss = []

best_acc = None # 设置一个最佳准确率,作为最佳模型的判别指标


for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    
    model.train() # 切换为训练模式
    epoch_train_acc,epoch_train_loss = train(train_dataloader)
    
    model.eval() # 切换为测试模式
    epoch_test_acc,epoch_test_loss = evaluate(valid_dataloader)
    
    if best_acc is not None and best_acc > epoch_test_acc:
        scheduler.step()
    else:
        best_acc = epoch_test_acc
        best_model = copy.deepcopy(model)
        
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    
    # 获取当前的学习率
    lr = optimizer.state_dict()['param_groups'][0]['lr']
    
    template = ('Epoch:{:2d},Train_acc:{:.1f}%,Train_loss:{:.3f},Test_acc:{:.1f}%,Test_loss:{:.3f},Lr:{:.2E}')
    print(template.format(epoch,epoch_train_acc*100,epoch_train_loss,epoch_test_acc*100,epoch_test_loss,lr))
    
print('Done!')
Epoch: 1,Train_acc:63.8%,Train_loss:1.339,Test_acc:79.5%,Test_loss:0.012,Lr:5.00E+00
Epoch: 2,Train_acc:83.2%,Train_loss:0.592,Test_acc:84.8%,Test_loss:0.008,Lr:5.00E+00
Epoch: 3,Train_acc:88.2%,Train_loss:0.413,Test_acc:86.8%,Test_loss:0.007,Lr:5.00E+00
Epoch: 4,Train_acc:91.1%,Train_loss:0.313,Test_acc:87.6%,Test_loss:0.006,Lr:5.00E+00
Epoch: 5,Train_acc:93.3%,Train_loss:0.241,Test_acc:89.8%,Test_loss:0.006,Lr:5.00E+00
Epoch: 6,Train_acc:95.0%,Train_loss:0.189,Test_acc:89.8%,Test_loss:0.006,Lr:5.00E-01
Epoch: 7,Train_acc:96.7%,Train_loss:0.144,Test_acc:89.6%,Test_loss:0.006,Lr:5.00E-02
Epoch: 8,Train_acc:96.8%,Train_loss:0.139,Test_acc:89.6%,Test_loss:0.006,Lr:5.00E-03
Epoch: 9,Train_acc:96.8%,Train_loss:0.139,Test_acc:89.6%,Test_loss:0.006,Lr:5.00E-04
Epoch:10,Train_acc:96.8%,Train_loss:0.139,Test_acc:89.6%,Test_loss:0.005,Lr:5.00E-05
Done!

g. 结果可视化

epochs_range = range(EPOCHS)

plt.figure(figsize=(12,3))
plt.subplot(1,2,1)

plt.plot(epochs_range,train_acc,label='Training Accuracy')
plt.plot(epochs_range,test_acc,label='Test Accuracy')
plt.legend(loc = 'lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1,2,2)
plt.plot(epochs_range,train_loss,label='Train Loss')
plt.plot(epochs_range,test_loss,label='Test Loss')
plt.legend(loc = 'lower right')
plt.title('Training and Validation Loss')
plt.show()


在这里插入图片描述

h. 测试指定数据

注意以下俩种测试方法,必须保证模型和数据在同样的设备上(GPU或CPU)

def predict(text, text_pipeline):
    with torch.no_grad():
        text = torch.tensor(text_pipeline(text))
        output = model(text, torch.tensor([0]))
        return output.argmax(1).item()

# ex_text_str = "随便播放一首专辑阁楼里的佛里的歌"
ex_text_str = "还有双鸭山到淮阴的汽车票吗13号的"

model = model.to("cpu")

print("该文本的类别是:%s" %label_name[predict(ex_text_str, text_pipeline)])
该文本的类别是:Travel-Query

def predict(text, text_pipeline, model, device):
    model.eval()  # 切换为评估模式
    with torch.no_grad():
        # 将文本和偏移量张量都移动到相同的设备
        text = torch.tensor(text_pipeline(text)).to(device)
        offsets = torch.tensor([0]).to(device)
        output = model(text, offsets)
        return output.argmax(1).item()

ex_text_str = "还有双鸭山到淮阴的汽车票吗13号的"

# 确保模型在设备上
model = model.to(device)

print("该文本的类别是:%s" % label_name[predict(ex_text_str, text_pipeline, model, device)])
该文本的类别是:Travel-Query

2. TextCNN(通用模型-待拓展)


3. Bert(高级模型-待拓展)


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1941246.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

git命令学习分享

分布式版本控制系统&#xff0c;本地仓库和远程仓库相互独立。 使用repository仓库进行控制&#xff0c;可以对里面的文件进行跟踪&#xff0c;复原。 git config --global --list&#xff1a;查看git配置列表 cd ** &#xff1a;进入** cd .. &#xff1a;退回上一级 echo…

【人工智能】Transformers之Pipeline(四):零样本音频分类(zero-shot-audio-classification)

​​​​​​​ 目录 一、引言 二、零样本音频分类&#xff08;zero-shot-audio-classification&#xff09; 2.1 概述 2.2 意义 2.3 应用场景 2.4 pipeline参数 2.4.1 pipeline对象实例化参数​​​​​​​ 2.4.2 pipeline对象使用参数 2.4 pipeline实战 2.5 模…

TinyVue:与 Vue 交往八年的组件库

本文由体验技术团队莫春辉老师原创~ 去年因故停办的 VueConf&#xff0c;今年如约在深圳举行。作为东道主 & 上届 VueConf 讲师的我&#xff0c;没有理由不来凑个热闹。大会结束后&#xff0c;我见裕波在朋友圈转发 Jinjiang 的文章《我和 Vue.js 的十年》&#xff0c;我就…

版本控制工具

版本控制工具是用于记录代码文件变化历史、方便查阅特定版本修改情况的系统&#xff0c;一般分为集中式和分布式两种。以下是一些常见的版本控制工具&#xff1a; 集中式版本控制工具 Subversion&#xff08;SVN&#xff09; 简介&#xff1a;Subversion是一种集中式版本控制…

【LeetCode】day15:110 - 平衡二叉树, 257 - 二叉树的所有路径, 404 - 左叶子之和, 222 - 完全二叉树的节点个数

LeetCode 代码随想录跟练 Day15 110.平衡二叉树257.二叉树的所有路径404.左叶子之和222.完全二叉树的节点个数 110.平衡二叉树 题目描述&#xff1a; 给定一个二叉树&#xff0c;判断它是否是 平衡二叉树 平衡二叉树的定义是&#xff0c;对于树中的每个节点&#xff0c;其左右…

文件包含漏洞: 函数,实例[pikachu_file_inclusion_local]

文件包含 文件包含是一种较为常见技术&#xff0c;允许程序员在不同的脚本或程序中重用代码或调用文件 主要作用和用途&#xff1a; 代码重用&#xff1a;通过将通用函数或代码段放入单独的文件中&#xff0c;可以在多个脚本中包含这些文件&#xff0c;避免重复编写相同代码。…

昇思25天学习打卡营第27天 | Diffusion扩散模型

学习心得&#xff1a;探索Diffusion扩散模型 在我最近对生成模型的学习中&#xff0c;尤其是Diffusion模型&#xff0c;我发现这是一种极具潜力的技术&#xff0c;特别是在图像生成领域。Diffusion模型的核心概念是通过一个逐步的去噪过程&#xff0c;将纯噪声数据转换成有意义…

算法——双指针(day4)

15.三数之和 15. 三数之和 - 力扣&#xff08;LeetCode&#xff09; 题目解析&#xff1a; 这道题目说是三数之和&#xff0c;其实这和我们之前做过的两数之和是一个规律的~无非就是我们需要实时改动target的值。先排好序&#xff0c;然后固定一个数取其负值作target&#xf…

单链表<数据结构 C版>

目录 概念 链表的单个结点 链表的打印操作 新结点的申请 尾部插入 头部插入 尾部删除 头部删除 查找 在指定位置之前插入数据 在任意位置之后插入数据 测试运行一下&#xff1a; 删除pos结点 删除pos之后结点 销毁链表 概念 单链表是一种在物理存储结构上非连续、非顺序…

Golang | Leetcode Golang题解之第264题丑数II

题目&#xff1a; 题解&#xff1a; func nthUglyNumber(n int) int {dp : make([]int, n1)dp[1] 1p2, p3, p5 : 1, 1, 1for i : 2; i < n; i {x2, x3, x5 : dp[p2]*2, dp[p3]*3, dp[p5]*5dp[i] min(min(x2, x3), x5)if dp[i] x2 {p2}if dp[i] x3 {p3}if dp[i] x5 {p5…

【PostgreSQL教程】PostgreSQL 选择数据库

博主介绍:✌全网粉丝20W+,CSDN博客专家、Java领域优质创作者,掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域✌ 技术范围:SpringBoot、SpringCloud、Vue、SSM、HTML、Nodejs、Python、MySQL、PostgreSQL、大数据、物联网、机器学习等设计与开发。 感兴趣的可…

图论模型-迪杰斯特拉算法和贝尔曼福特算法★★★★

该博客为个人学习清风建模的学习笔记&#xff0c;部分课程可以在B站&#xff1a;【强烈推荐】清风&#xff1a;数学建模算法、编程和写作培训的视频课程以及Matlab等软件教学_哔哩哔哩_bilibili 目录 ​1图论基础 1.1概念 1.2在线绘图 1.2.1网站 1.2.2MATLAB 1.3无向图的…

基于SpringBoot+Vue的校园疫情防控系统(带1w+文档)

基于SpringBootVue的校园疫情防控系统(带1w文档) 基于SpringBootVue的校园疫情防控系统(带1w文档) 主要对首页、个人中心、学生管理、疫情动态管理、知识信息管理、防疫教育管理、健康打卡管理、请假申请管理、出校登记管理、入校登记管理、核酸报告管理、交流论坛、系统管理的…

MySQL的建表及查询

一。建立表 mysql> create table student(id int(10) not null unique primary key,name varchar(20) not null,sex varchar(4),birth year,department varchar(20),address varchar(50)); mysql> create table score(id int(10) not null unique primary key auto_incr…

精明选择施工项目管理工具的实用建议

国内外主流的10款施工项目进度管理软件对比&#xff1a;PingCode、Worktile、Contractor Foreman、建设工程项目管理平台&#xff08;JSGC&#xff09;、智慧工地综合管理系统、工程项目信息管理系统&#xff08;GCXX&#xff09;、Buildertrend、Procore、Autodesk Constructi…

Edge侧边栏copilot消失

Edge侧边栏copilot消失 当前环境 自己ip问题已解决&#xff0c;edge中已登录账号&#xff0c;地区已设置为美国&#xff0c;语言已设置为英文。具体可以通过空白页右上角的setting验证 解决方案 首先&#xff0c;打开“任务管理器”&#xff0c;在其中找到 Microsoft Edge…

【C语言】动态内存管理(下)(realloc函数)

文章目录 前言1. realloc2. realloc函数在调整空间时的细节2.1 针对情况1&#xff08;realloc后面有足够的内存空间&#xff09;2.2 针对情况2&#xff08;realloc后面没有足够的内存空间&#xff09;2.3 realloc函数使用的注意事项2.4 realloc的使用实例2.5 realloc函数的补充…

ubuntu安装mysql8.0

文章目录 ubuntu版本安装修改密码取消root跳过密码验证 ubuntu版本 22.04 安装 更新软件包列表 sudo apt update安装 MySQL 8.0 服务器 sudo apt install mysql-server在安装过程中&#xff0c;系统可能会提示您设置 root 用户的密码&#xff0c;请务必牢记您设置的密码。…

产线中有MES系统 还有安装SCADA的必要吗?

MES系统即制造执行系统&#xff08;Manufacturing Execution System&#xff09;&#xff0c;是一种面向车间层的管理信息系统&#xff0c;旨在通过信息传递优化从订单下达到产品完成的全过程管理。 MES可以为企业提供包括制造数据管理、计划排程管理、生产调度管理、库存管理、…

网路布线和数值转换

文章目录 信号的分类数字信息的优势双绞线分类双绞线标准与分类 光纤的特点光纤分为单模光纤和多模光纤 光纤接口双绞线的连接规范EIA/TIA-568A和568B 线缆的连接综合布线系统无线电波的传输方式 数制转换十进制转二进制计算机的数值 信号的分类 1.模拟信号 2.数字信号 数字信…