第N2周:中文文本分类-Pytorch实现

news2024/12/28 17:52:57

目录

  • 一、前言
  • 二、准备工作
  • 三、数据预处理
    • 1.加载数据
    • 2.构建词典
    • 3.生成数据批次和迭代器
  • 三、模型构建
    • 1. 搭建模型
    • 2. 初始化模型
    • 3. 定义训练与评估函数
  • 四、训练模型
    • 1. 拆分数据集并运行模型

一、前言

🍨 本文为🔗365天深度学习训练营 中的学习记录博客
🍖 原作者:K同学啊|接辅导、项目定制

● 难度:夯实基础⭐⭐
● 语言:Python3、Pytorch3
● 时间:4月23日-4月28日
🍺要求:
1、熟悉NLP的基础知识

二、准备工作

环境搭建
Python 3.8
pytorch == 1.8.1
torchtext == 0.9.1

三、数据预处理

1.加载数据

在这里插入图片描述

import torch
import torch.nn as nn
import os,PIL,pathlib,warnings

warnings.filterwarnings("ignore")             #忽略警告信息

# win10系统
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
import pandas as pd

# 加载自定义中文数据
train_data = pd.read_csv('./data/train.csv', sep='\t', header=None)
train_data.head()
# 构造数据集迭代器
def coustom_data_iter(texts, labels):
    for x, y in zip(texts, labels):
        yield x, y
        
train_iter = coustom_data_iter(train_data[0].values[:], train_data[1].values[:])

2.构建词典

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
# conda install jieba -y
import jieba

# 中文分词方法
tokenizer = jieba.lcut

def yield_tokens(data_iter):
    for text,_ in data_iter:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"]) # 设置默认索引,如果找不到单词,则会选择默认索引
vocab(['我','想','看','和平','精英','上','战神','必备','技巧','的','游戏','视频'])
label_name = list(set(train_data[1].values[:]))
print(label_name)
text_pipeline  = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: label_name.index(x)

print(text_pipeline('我想看和平精英上战神必备技巧的游戏视频'))
print(label_pipeline('Video-Play'))

3.生成数据批次和迭代器

from torch.utils.data import DataLoader

def collate_batch(batch):
    label_list, text_list, offsets = [], [], [0]
    
    for (_text,_label) in batch:
        # 标签列表
        label_list.append(label_pipeline(_label))
        
        # 文本列表
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        
        # 偏移量,即语句的总词汇量
        offsets.append(processed_text.size(0))
        
    label_list = torch.tensor(label_list, dtype=torch.int64)
    text_list  = torch.cat(text_list)
    offsets    = torch.tensor(offsets[:-1]).cumsum(dim=0) #返回维度dim中输入元素的累计和
    
    return text_list.to(device),label_list.to(device), offsets.to(device)

# 数据加载器,调用示例
dataloader = DataLoader(train_iter,
                        batch_size=8,
                        shuffle   =False,
                        collate_fn=collate_batch)

三、模型构建

1. 搭建模型

from torch import nn

class TextClassificationModel(nn.Module):

    def __init__(self, vocab_size, embed_dim, num_class):
        super(TextClassificationModel, self).__init__()
        
        self.embedding = nn.EmbeddingBag(vocab_size,   # 词典大小
                                         embed_dim,    # 嵌入的维度
                                         sparse=False) # 
        
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange) # 初始化权重
        self.fc.weight.data.uniform_(-initrange, initrange)        
        self.fc.bias.data.zero_()                                  # 偏置值归零

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

2. 初始化模型

num_class  = len(label_name)
vocab_size = len(vocab)
em_size    = 64
model      = TextClassificationModel(vocab_size, em_size, num_class).to(device)

3. 定义训练与评估函数

import time

def train(dataloader):
    model.train()  # 切换为训练模式
    total_acc, train_loss, total_count = 0, 0, 0
    log_interval = 50
    start_time   = time.time()

    for idx, (text,label,offsets) in enumerate(dataloader):
        
        predicted_label = model(text, offsets)
        
        optimizer.zero_grad()                    # grad属性归零
        loss = criterion(predicted_label, label) # 计算网络输出和真实值之间的差距,label为真实值
        loss.backward()                          # 反向传播
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1) # 梯度裁剪
        optimizer.step()  # 每一步自动更新
        
        # 记录acc与loss
        total_acc   += (predicted_label.argmax(1) == label).sum().item()
        train_loss  += loss.item()
        total_count += label.size(0)
        
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('| epoch {:1d} | {:4d}/{:4d} batches '
                  '| train_acc {:4.3f} train_loss {:4.5f}'.format(epoch, idx, len(dataloader),
                                              total_acc/total_count, train_loss/total_count))
            total_acc, train_loss, total_count = 0, 0, 0
            start_time = time.time()

def evaluate(dataloader):
    model.eval()  # 切换为测试模式
    total_acc, train_loss, total_count = 0, 0, 0

    with torch.no_grad():
        for idx, (text,label,offsets) in enumerate(dataloader):
            predicted_label = model(text, offsets)
            
            loss = criterion(predicted_label, label)  # 计算loss值
            # 记录测试数据
            total_acc   += (predicted_label.argmax(1) == label).sum().item()
            train_loss  += loss.item()
            total_count += label.size(0)
            
    return total_acc/total_count, train_loss/total_count

四、训练模型

1. 拆分数据集并运行模型

from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
# 超参数
EPOCHS     = 10 # epoch
LR         = 5  # 学习率
BATCH_SIZE = 64 # batch size for training

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None

# 构建数据集
train_iter = coustom_data_iter(train_data[0].values[:], train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)

split_train_, split_valid_ = random_split(train_dataset,
                                          [int(len(train_dataset)*0.8),int(len(train_dataset)*0.2)])

train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)

valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)

for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader)
    val_acc, val_loss = evaluate(valid_dataloader)
    
    # 获取当前的学习率
    lr = optimizer.state_dict()['param_groups'][0]['lr']
    
    if total_accu is not None and total_accu > val_acc:
        scheduler.step()
    else:
        total_accu = val_acc
    print('-' * 69)
    print('| epoch {:1d} | time: {:4.2f}s | '
          'valid_acc {:4.3f} valid_loss {:4.3f} | lr {:4.6f}'.format(epoch,
                                           time.time() - epoch_start_time,
                                           val_acc,val_loss,lr))

    print('-' * 69)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/472555.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

运算符重载----赋值运算符重载

运算符重载 本质是函数调用&#xff0c;内置类型编译器直接比&#xff0c;自定义就去找对应类内重载的函数 如果定义在类外&#xff0c;需要访问私有的成员函数&#xff0c;只能将成员函数权限变为Public或者友元&#xff08;非必须不用&#xff09; &#xff0c;所以一般重载…

Linux程序设计之字节序转换

1.在网络通信中&#xff0c;数据的存储方式十分重要&#xff0c;因为它影响到数据的准确性。如今&#xff0c;电脑和网络上数据的存储方式有两种&#xff1a;大端模式和小端模式。大端模式&#xff1a;数据的高位存储在内存的低位&#xff0c;数据的低位存储在内存的高位。小端…

【keil5开发ARM工程时使用STLink调试的技巧分享】

ARM工程开发小技巧系列文章 St link V2驱动安装方法 文章目录 ARM工程开发小技巧系列文章前言一、准备工作1. 硬件连接2. 安装stlink的驱动3. Keil 5配置 二、调试示例1.进入调试状态2. 调试演示2.1 复位&#xff0c;使程序复位到初始位置2.2 单步调试2.3 逐步调试2.4 跳出调…

Quartus中的逻辑锁定与增量编译

逻辑锁定功能可以将FPGA中的代码模块在固定区域实现&#xff0c;优化时序性能&#xff0c;提升设计可靠性。增量编译功能&#xff0c;可以使设计更快速时序收敛&#xff0c;加快编译速度。 LogicLock 使用Chip Planner创建逻辑锁定区域 打开Chip Planner&#xff0c;点击Vie…

Kubernetes Controller原理讲解

Controller原理 在 K8s 中&#xff0c;用户通过声明式 API 定义资源的“预期状态”&#xff0c;Controller 则负责监视资源的实际状态&#xff0c;当资源的实际状态和“预期状态”不一致时&#xff0c;Controller 则对系统进行必要的更改&#xff0c;以确保两者一致&#xff0…

人群计数数据集汇总和详细介绍,全网最全,crowd counting datasets

Crowd Counting数据集汇总 视频监控video surveillance https://github.com/gjy3035/Awesome-Crowd-Counting/blob/master/src/Datasets.md进展 | 密集人群分布检测与计数 :https://www.sohu.com/a/338406719_823210 Free-view 2022_Pedestrian Attribute Recognition htt…

vue+element Ui 树型组件tree懒加载+搜索框远程请求数据为平铺类型

本人之前一直是耕耘后台研发&#xff0c;最近接了个小需求需要接触到vue&#xff0c;记录一下我遇到的一些前端解决时间长的问题 需求&#xff1a; 1&#xff1a;每次动态请求接口获取下一节点数据 2&#xff1a;接口返回的数据是list&#xff0c;不带子节点&#xff0c;用pid来…

Scala中使用Typesafe Config 库

Typesafe Config 库 在 Scala 中加载配置文件有很多种方法&#xff0c;其中一种常用的方法是使用 Typesafe Config 库。该库提供了一种简单易用的方式来读取和解析配置文件。 以下是在启动 main 方法后加载配置文件的示例代码&#xff1a; 引入 Typesafe Config 库 import c…

【MySQL】函数

一、概述 MySQL中提供了大量函数来简化用户对数据库的操作&#xff0c;比如字符串的处理、日期的运算、数值的运算等等。使用函数可以大大提高SELECT语句操作数据库的能力&#xff0c;同时也给数据的转换和处理提供了方便。 &#xff08;在sql中使用函数&#xff09;函数只是对…

shadowsocks服务端和客户端搭建

shadowsocks服务端和客户端搭建 一、服务端搭建 买个境外云服务器&#xff0c;搭建shadowsocks服务端。 需要python3环境。 1.下载shadowsocks服务端python包&#xff0c;并启动。下载地址 # 1.下载 [rootiZrj982e4r5hkd053zsnmqZ ~]# wget https://pypi.python.org/packa…

2023隐私计算与人工智能峰会成功举办!数据宝演讲实录(上篇)分享

2023年4月8日&#xff0c;2023隐私计算与人工智能峰会在深圳举办&#xff0c;大会由华东江苏大数据交易中心和热点资讯联合主办&#xff0c;会上&#xff0c;数据宝董事詹臻女士做开幕式致辞。 数据宝与开放群岛&#xff08;Open Islands&#xff09;进行战略签约&#xff0c;…

LoRA: 大语言模型个性化的最佳实践

出品人&#xff1a; Towhee 技术团队 大型语言模型&#xff08;LLM&#xff09;在今年获得了极大的关注。在以往&#xff0c;预训练微调&#xff08;finetuning&#xff09;成为了让模型适配于特定数据的最佳范式。然而随着大型模型的出现&#xff0c;这种完全微调&#xff08;…

【运动规划算法项目实战】如何实现机器人多目标点导航

文章目录 前言一、 什么是actionlib?二、实现流程三、总结前言 在ROS机器人应用中,实现机器人多目标点导航是非常常见的需求。本文将介绍如何使用ROS和actionlib来实现机器人的多目标点导航,目标点信息将被记录在YAML文件中。 我们可以通过使用MoveBaseAction来实现机器人…

高并发场景下JVM调优实践

一、背景 2021年2月&#xff0c;收到反馈&#xff0c;视频APP某核心接口高峰期响应慢&#xff0c;影响用户体验。 通过监控发现&#xff0c;接口响应慢主要是P99耗时高引起的&#xff0c;怀疑与该服务的GC有关&#xff0c;该服务典型的一个实例GC表现如下图&#xff1a; 可以…

【WinForm】定时器的使用方法除了定时还有延迟执行可用

在使用VS开发工具创建的WinForm项目中&#xff0c;有一个定时器组件&#xff0c;拖出来放上&#xff0c;它只是一个定时处理的作用&#xff0c;不会显示在窗体中。 开发中如果需要定时处理&#xff0c;就使用Timer组件即可&#xff0c; 在它的属性事件一栏里&#xff0c;有一…

offer选择:创业公司 VS 大厂外包

面试拿到两个offer&#xff0c;一个是规模只有几十人的初创小公司&#xff0c;另一个是大厂外包岗位。都是功能测试&#xff0c;两者薪水待遇也差不多&#xff0c;该如何选择?更有利于之后的职业发展...... 这是一个比较典型的问题&#xff0c;对于要转行的同学或者是刚入行没…

【Call for papers】2023年CCF人工智能会议信息汇总(持续更新)

本博文是根据2022年CCF会议推荐的人工智能领域相关会议目录撰写。 注&#xff1a; 由于一些会议的投稿时间还没公开&#xff0c;因此根据往年投稿时间在表格中使用 ~ 符号表示大概的投稿时间&#xff08;一旦会议日期更新&#xff0c;我们也将同步更新博文。若更新不及时请小伙…

C++常用23种设计模式总结(一)------单例模式

什么是单例模式 单例模式是一种设计模式&#xff0c;它保证一个类只有一个实例&#xff0c;并提供一个全局访问点来访问该实例。这个模式通常用于控制资源的访问&#xff0c;例如数据库连接、线程池等。单例模式通过限制实例化操作并提供访问方法&#xff0c;确保在整个应用程序…

【Unity3D小功能】Unity3D中实现模型的旋转、缩放效果(控制摄像机)

推荐阅读 CSDN主页GitHub开源地址Unity3D插件分享简书地址我的个人博客 大家好&#xff0c;我是佛系工程师☆恬静的小魔龙☆&#xff0c;不定时更新Unity开发技巧&#xff0c;觉得有用记得一键三连哦。 一、前言 其实之前已经写了关于如何控制模型的旋转、移动、缩放效果&…

Android开发:使用sqlite数据库实现记单词APP

一、功能与要求 实现功能&#xff1a;设计与开发记单词系统的四个界面&#xff0c;分别是用户登录、用户注册、单词操作以及忘记密码。 指标要求&#xff1a;通过用户登录、用户注册、单词操作、忘记密码掌握界面设计的基础&#xff0c;其中包括界面布局、常用控件、事件处理等…