第20周:Pytorch文本分类入门

news2025/2/28 13:59:21

目录

前言

一、前期准备

1.1 环境安装导入包

1.2 加载数据

1.3 构建词典

1.4 生成数据批次和迭代器

二、准备模型

2.1 定义模型

2.2 定义示例

2.3 定义训练函数与评估函数

三、训练模型

3.1 拆分数据集并运行模型

3.2 使用测试数据集评估模型

总结


前言

  • 🍨 本文为[🔗365天深度学习训练营]中的学习记录博客
  • 🍖 原作者:[K同学啊]

说在前面

本周任务:了解文本分类的基本流程、学习常用数据清洗方法、学习如何使用jieba实现英文分词、学习如何构建文本向量

我的环境:Python3.8、Pycharm2020、torch1.12.1+cu113

数据来源:[K同学啊]


一、前期准备

1.1 环境安装导入包

本文是一个使用Pytorch实现的简单文本分类实战案例,在本案例中,我们将使用AG News数据集进行文本分类

需要确保已经安装了torchtext与poralocker库

PS:torchtext库的安装需要与Pytorch、python的版本进行匹配,具体可参考torchtext版本对应

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import os,PIL,pathlib,warnings
from torchtext.datasets import AG_NEWS
from torchtext.vocab import build_vocab_from_iterator
import torchtext.data.utils as utils
from torch.utils.data import DataLoader
from torch import nn
import time

warnings.filterwarnings("ignore")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

1.2 加载数据

本文使用的数据集是AG News(AG's News Topic Classification Dataset)是一个广泛用于文本分;分类任务的数据集,尤其是在新闻领域,该数据集是由AG’s Corpus of News Ariticles收集整理而来,包含了四个主要的类别:世界、体育、商业和科技

torchtext.datasets.AG_NEWS是一个用于加载AG News数据集的torchtext数据集类,具体的参数如下:

  • root:数据集的根目录,默认值是'.data'
  • split:数据集的拆分train、test
  • **kwargs:可选的关键字参数,可传递给torchtext.datasets.TextClassificationDataset类构造的函数

该类加载的数据集是一个列表,其中每个条目都是一个元祖,包含以下两个元素

  • 一条新闻文章的文本内容
  • 新闻文章所属的类别(一个整数,从1到4,分别对应世界、科技、体育和商业)

代码如下:

train_iter = AG_NEWS(split='train')

1.3 构建词典

torchtext.data.ttils.get_tokenizer()是一个用于将文本数据分词的函数,它返回一个分词器(tokenizer)函数,可以将一个字符串转换成一个单词的列表

函数原型:torchtext.data.ttils.get_tokenizer(tokenizer, language = 'en')

tokenizer参数是用于指定使用的分词器名称,可以是一下之一

  • basic_english:用于基本英文文本的分词器
  • moses:用于处理各种语言的分词器,支持多种选项
  • spacy:使用spaCy分词器,需要安装spaCy库
  • toktok:用于各种语言的分词器,速度较快

PS:分词器函数返回的单词列表中不包含任何标点符号或空格

代码如下:

tokenizer = utils.get_tokenizer('basic_english')

def yield_tokens(data_iter):
    for _, text in data_iter:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(train_iter),specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])  #设置默认索引,如果找不到单词,则会选择默认索引

print(vocab(['here', 'is', 'an', 'example']))

text_pipeline = lambda  x: vocab(tokenizer(x))
label_pipeline = lambda x: int(x) - 1

print(text_pipeline('here is an example'))
print(label_pipeline('10'))

输出结果:

[475, 21, 30, 5297]
[475, 21, 30, 5297]
9

1.4 生成数据批次和迭代器

代码如下:

def collate_batch(batch):
    label_list, text_list, offsets = [],[],[0]

    for (_label, _text) in batch:
        # 标签列表
        label_list.append(label_pipeline(_label))

        # 文本列表
        processed_text = torch.tensor(text_pipeline(_text),dtype=torch.int64)
        text_list.append(processed_text)

        #偏移量,即语句的总词汇量
        offsets.append(processed_text.size(0))

    label_list = torch.tensor(label_list, dtype=torch.int64)
    text_list = torch.cat(text_list)
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)

    return label_list.to(device),text_list.to(device), offsets.to(device)

#数据集加载器
dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)

二、准备模型

2.1 定义模型

首先对文本进行嵌入,然后对句子嵌入之后的结果进行均值聚合

代码如下:

#2.1 准备模型
class TextClassificationModel(nn.Module):
    def __init__(self,vocab_size,embed_dim,num_calss):
        super(TextClassificationModel, self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
        self.fc = nn.Linear(embed_dim, num_calss)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

self.embedding.weight.data.uniform_(-initrange, initrange)是在PyTorch框架下用于初始化神经网络的词嵌入层(embedding layer)权重的一种方法,这里使用了均匀分布的随机值来初始化权重,具体来说,其作用如下:

  • self.embedding:这是神经网络的词嵌入层(embedding layer),其嵌入层的作用是将离散的单词表示(通常为整数索引)映射为固定大小的连续向量,这些向量捕捉了单词之间的语义关系,并作为网络的输入
  • self.embedding.weight:这是词嵌入层的权重矩阵,它的形状为(vocab_size,embedding_dim),其中vocab_size是词汇表的大小,embedding_dim是嵌入向量的维度
  • self.embedding.weight.data:这是权重矩阵的数据部分,可以直接操作其底层的张量
  • .uniform_(-initrange, initrange):这是一个原地操作,用于将权重矩阵的值用一个均匀分布进行初始化,均匀分布的范围为[-initrange,initrange],其中initrange是一个正数

这种方式初始化词嵌入层的权重,可以使得模型在训练开始时具有一定的随机性,有助于避免梯度消失或梯度爆炸等问题,在训练过程中,这些权重将通过优化算法不断更新,以捕捉到更好的单词表示

2.2 定义示例

代码如下:

#2.2定义实例
num_class = len(set([label for (label,text) in train_iter]))
vocab_size = len(vocab)
em_size = 64
model = TextClassificationModel(vocab_size, em_size, num_class).to(device)

epochs = 10
lr = 5
batch_szie =64

critertion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,1.0,gamma=0.1)

2.3 定义训练函数与评估函数

#2.3 定义训练函数与评估函数
def train(dataloader):
    model.train()
    total_acc, train_loss, total_count = 0,0,0
    log_interval = 500
    start_time = time.time()

    for idx, (label, text, offsets) in enumerate(dataloader):
        predicted_label = model(text, offsets)
        optimizer.zero_grad()
        loss = critertion(predicted_label, label)
        loss.backward()
        optimizer.step()

        #记录acc和loss
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        train_loss += loss.item()
        total_count += label.size(0)

        if idx % log_interval == 0 and idx>0:
            elapsed = time.time() - start_time
            print('| epoch{:1d} | {:4d}/{:4d} batches'
                  'train_acc {:4.3f} train_loss {:4.5f}'.format(epoch, idx, len(dataloader),
                                                                total_acc/total_count, train_loss/total_count))
            total_acc, train_loss, total_count = 0, 0, 0
            start_time = time.time()

def evaluate(dataloader):
    model.eval()
    total_acc, train_loss, total_count = 0, 0 ,0
    with torch.no_grad():
        for idx, (label, text, offsets) in enumerate(dataloader):
            predicted_label = model(text, offsets)
            loss = critertion(predicted_label, label)
            # 记录acc和loss
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            train_loss += loss.item()
            total_count += label.size(0)
    return total_acc/total_count, train_loss/total_count

三、训练模型

3.1 拆分数据集并运行模型

     torchtext.data.functional.to_map_style_dataset函数的作用是将一个迭代式的数据集(Iterable-style dataset)转换为映射式的数据集(Map-style dataset)。这个转换使得我们可以通过索引更方便访问数据集中的元素

       在PyTorch中,数据集可以分为两种类型:Iterable-style和Map-style,Iterable-style数据集实现了__iter__()方法,可以迭代访问数据集中的元素,但不支持通过索引访问,而Map-style数据集实现了__getitem__()和__len__()方法,可以直接通过索引访问特定元素,并能获取数据集的大小、

      TorchText是Pytorch的一个扩展库,专注于处理文本数据,torchtext.data.functional中的to_map_style_dataset函数可以帮助我们将一个Iterable-style数据集转换为一个易于操作的Map-style数据集,这样就可以通过索引直接访问数据集中的特定样本,从而简化了训练、验证和测试过程中的数据处理。

代码如下:

#3.1 拆分数据集并运行模型
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset

total_accu = None

train_iter, test_iter = AG_NEWS()
train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)
num_train = int(len(train_dataset) * 0.95)
split_train_, split_valid_ = random_split(train_dataset, [num_train, len(train_dataset)-num_train])

train_dataloader = DataLoader(split_train_, batch_size=batch_szie, shuffle=True, collate_fn=collate_batch)
vaild_dataloader = DataLoader(split_valid_, batch_size=batch_szie, shuffle=True, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=batch_szie, shuffle=True, collate_fn=collate_batch)


for epoch in range(1, epochs+1):
    epoch_start_time = time.time()
    train(train_dataloader)
    val_acc, val_loss = evaluate(vaild_dataloader)

    if total_accu is not None and total_accu > val_acc:
        scheduler.step()
    else:
        total_accu = val_acc
    print('-' * 69)
    print('| epoch{:1d} | time: {:4.2f}s |'
          'vaild_acc {:4.3f} vaild_loss {:4.3f}'.format(epoch,
                                                        time.time()-epoch_start_time,
                                                        val_acc, val_loss))
    print('-' * 69)

输出结果:

| epoch1 |  500/1782 batchestrain_acc 0.714 train_loss 0.01141
| epoch1 | 1000/1782 batchestrain_acc 0.864 train_loss 0.00623
| epoch1 | 1500/1782 batchestrain_acc 0.879 train_loss 0.00551
---------------------------------------------------------------------
| epoch1 | time: 7.21s |vaild_acc 0.810 vaild_loss 0.008
---------------------------------------------------------------------
| epoch2 |  500/1782 batchestrain_acc 0.905 train_loss 0.00445
| epoch2 | 1000/1782 batchestrain_acc 0.905 train_loss 0.00440
| epoch2 | 1500/1782 batchestrain_acc 0.906 train_loss 0.00442
---------------------------------------------------------------------
| epoch2 | time: 5.68s |vaild_acc 0.910 vaild_loss 0.004
---------------------------------------------------------------------
| epoch3 |  500/1782 batchestrain_acc 0.915 train_loss 0.00389
| epoch3 | 1000/1782 batchestrain_acc 0.918 train_loss 0.00381
| epoch3 | 1500/1782 batchestrain_acc 0.918 train_loss 0.00377
---------------------------------------------------------------------
| epoch3 | time: 6.04s |vaild_acc 0.910 vaild_loss 0.004
---------------------------------------------------------------------
| epoch4 |  500/1782 batchestrain_acc 0.929 train_loss 0.00331
| epoch4 | 1000/1782 batchestrain_acc 0.921 train_loss 0.00357
| epoch4 | 1500/1782 batchestrain_acc 0.926 train_loss 0.00341
---------------------------------------------------------------------
| epoch4 | time: 6.54s |vaild_acc 0.890 vaild_loss 0.005
---------------------------------------------------------------------
| epoch5 |  500/1782 batchestrain_acc 0.941 train_loss 0.00280
| epoch5 | 1000/1782 batchestrain_acc 0.945 train_loss 0.00266
| epoch5 | 1500/1782 batchestrain_acc 0.944 train_loss 0.00265
---------------------------------------------------------------------
| epoch5 | time: 6.76s |vaild_acc 0.917 vaild_loss 0.004
---------------------------------------------------------------------
| epoch6 |  500/1782 batchestrain_acc 0.948 train_loss 0.00255
| epoch6 | 1000/1782 batchestrain_acc 0.946 train_loss 0.00265
| epoch6 | 1500/1782 batchestrain_acc 0.946 train_loss 0.00260
---------------------------------------------------------------------
| epoch6 | time: 6.80s |vaild_acc 0.920 vaild_loss 0.004
---------------------------------------------------------------------
| epoch7 |  500/1782 batchestrain_acc 0.948 train_loss 0.00254
| epoch7 | 1000/1782 batchestrain_acc 0.945 train_loss 0.00266
| epoch7 | 1500/1782 batchestrain_acc 0.949 train_loss 0.00248
---------------------------------------------------------------------
| epoch7 | time: 6.52s |vaild_acc 0.915 vaild_loss 0.004
---------------------------------------------------------------------
| epoch8 |  500/1782 batchestrain_acc 0.949 train_loss 0.00246
| epoch8 | 1000/1782 batchestrain_acc 0.949 train_loss 0.00246
| epoch8 | 1500/1782 batchestrain_acc 0.949 train_loss 0.00252
---------------------------------------------------------------------
| epoch8 | time: 6.75s |vaild_acc 0.919 vaild_loss 0.004
---------------------------------------------------------------------
| epoch9 |  500/1782 batchestrain_acc 0.948 train_loss 0.00251
| epoch9 | 1000/1782 batchestrain_acc 0.949 train_loss 0.00247
| epoch9 | 1500/1782 batchestrain_acc 0.950 train_loss 0.00245
---------------------------------------------------------------------
| epoch9 | time: 6.87s |vaild_acc 0.919 vaild_loss 0.004
---------------------------------------------------------------------
| epoch10 |  500/1782 batchestrain_acc 0.951 train_loss 0.00244
| epoch10 | 1000/1782 batchestrain_acc 0.947 train_loss 0.00255
| epoch10 | 1500/1782 batchestrain_acc 0.949 train_loss 0.00244
---------------------------------------------------------------------
| epoch10 | time: 6.64s |vaild_acc 0.919 vaild_loss 0.004
---------------------------------------------------------------------

3.2 使用测试数据集评估模型

代码如下:

print('Checking the results of test dataset.')
test_acc, test_loss = evaluate(test_dataloader)
print('test accuracy {:8.3f}'.format(test_acc))

输出结果:


总结

了解文本分类的基本流程、学习常用数据清洗方法、学习如何使用jieba实现英文分词、学习如何构建文本向量

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1988108.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

游戏盾是什么,如何保护网络游戏的安全

在数字化浪潮的推动下&#xff0c;网络游戏已成为人们休闲娱乐不可或缺的一部分。然而&#xff0c;随着游戏行业的蓬勃发展&#xff0c;网络安全问题也日益严峻&#xff0c;黑客攻击频发&#xff0c;给游戏玩家和游戏运营商带来了巨大困扰。为了应对这些挑战&#xff0c;应用加…

机器学习·L2W3-模型评估

模型评估 划分数据集为训练集、验证集、测试集 60%训练集、20%测试集和验证集 x_train,x_,y_train,y_train_test_split(X_train,y_train,test_size0.4) x_cv,x_test,y_cv,y_testtrain_test_split(x_train,y_train,test_size0.5)交叉验证-模型选择 使用交叉验证计算模型的损失…

新来的小姐姐,微软便笺程序打不开了

网管小贾 / sysadm.cc 公司新来了一位小姐姐&#xff0c;听说跟老板沾点关系。 这一天老板出差&#xff0c;午休时大家趁着小姐姐去取外卖&#xff0c;开始了各自的调侃。 部门主管丽姐开了个头&#xff0c;当着众人先抱怨上了。 “你们看看&#xff0c;你们看看&#xff0c;…

国内顶级 AI 的回答令人“贻笑大方”:看来苹果秃头码农们暂时还不会失业吧?

概览 在苹果 App 的日常开发中&#xff0c;利用 Xcode 预览可以帮我们极大的提高界面调试的效率。而且&#xff0c;若能进一步判断出当前 App 是否运行在 Preview 环境中则会更让秃头码农们“笑逐颜开”。 那么到底有没有简单的方法来完成这一任务呢&#xff1f;答案是肯定的…

苹果数据恢复攻略:3大秘籍,助你重建“数据高塔”

在数字时代&#xff0c;苹果设备如iPhone、iPad和Mac已成为我们生活中不可或缺的一部分&#xff0c;存储着大量珍贵的照片、视频、文件和联系信息。然而&#xff0c;意外的删除、系统更新或硬件故障等问题时常威胁着数据的安全。当数据“高塔”崩塌时&#xff0c;苹果数据恢复要…

海量数据处理商用短链接生成器平台 - 6

第十二章 海量数据下的分库分表技术栈讲解 第1集 大话业界常见数据库分库分表中间件介绍 简介&#xff1a; 大话业界常见分库分表中间件介绍 业界常见分库分表中间件 Cobar&#xff08;已经被淘汰没使用了&#xff09;TDDL 淘宝根据自己的业务特点开发了 TDDL &#xff08;T…

基于JSP的智能仓储系统

你好&#xff0c;我是专注于智能系统开发的码农小野。如果对智能仓储系统感兴趣&#xff0c;欢迎私信交流。 开发语言 Java 数据库 MySQL 技术 JSP技术 工具 MyEclipse、Tomcat 系统展示 首页 [插入论文中的系统首页图片] 管理员功能界面 员工功能界面 供应商功能界…

MATLAB代码下载|蚁群算法|计算一元函数最小值

程序总述 程序使用蚁群优化的方法&#xff0c;计算一元函数&#xff08;单输入单输出非线性函数&#xff09;在定义域内的最小值。 函数形式 待计算最小值的函数形式如下&#xff1a; x 4 − 0.2 ∗ c o s ( 3 x ∗ π ) 0.6 x^4 - 0.2 * cos(3x * \pi) 0.6 x4−0.2∗cos…

AI新应用:概要设计与详细设计自动生成解决方案

近日&#xff0c;CoCode旗下的Co-Project智能项目管理平台V4.0.0升级发布&#xff0c;新增AI生成概要设计和AI生成详细设计功能&#xff0c;大大提高了设计的效率和质量。 CoCode旗下的Co-Project智能项目管理平台 一键智绘蓝图自现 平台设计板块新增概要设计功能&#xff0c;…

有点恶心,但是一周可以拿5个大模型岗offer,非常详细收藏我这一篇就够了

一、基础篇目前主流的开源模型体系有哪些&#xff1f; Transformer体系&#xff1a;由Google提出的Transformer模型及其变体&#xff0c;如BERT、GPT等。 PyTorch Lightning&#xff1a;一个基于PyTorch的轻量级深度学习框架&#xff0c;用于快速原型设计和实验。TensorFlow Mo…

同声传译翻译器哪个好?评测5款实用的同声传译翻译器

想象一下&#xff0c;在国际会议中&#xff0c;演讲者的声音刚落&#xff0c;耳机里便响起清晰的母语翻译&#xff1b;或是观看一部外语电影&#xff0c;无需眼睛离开屏幕&#xff0c;字幕就自动以你熟悉的语言呈现——这不再是科幻电影里的桥段&#xff0c;而是现实生活中同声…

猫头虎分享:CSDN博客最多可以创建多少个专栏?

&#x1f42f; 猫头虎分享&#xff1a;CSDN博客最多可以创建多少个专栏&#xff1f; 摘要 &#x1f4cb; 在CSDN博客平台上&#xff0c;不同级别的用户可以创建的专栏数量有所不同。本文将详细介绍CSDN博客创建专栏的具体数量限制&#xff0c;并且对不同等级用户所能创建的专…

武汉流星汇聚:亚马逊赋能中国卖家,全球市场份额优势引领出海潮流

在全球电商的浩瀚星空中&#xff0c;亚马逊无疑是最耀眼的星辰之一&#xff0c;其卓越的市场占有率不仅巩固了自身在全球电商市场的领导地位&#xff0c;更为中国卖家出海提供了前所未有的机遇与优势。随着中国卖家对海外市场的探索日益深入&#xff0c;亚马逊平台以其独特的优…

最新Thinphp开发的证书查询系统源码/开源版/支持自适应多端PC+手机站+含安装教程

源码简介&#xff1a; 最新Thinphp开发的证书查询系统源码&#xff0c;它是开源版&#xff0c;别小看这个开源版&#xff0c;它可是能自动适应各种屏幕大小&#xff0c;不管是用手机还是电脑&#xff0c;都能轻松查证书。附上了安装教程。 这款精心开发用PHP打造的证书查询系…

vue学习--02天

一、数据绑定 !DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Document</title><script src&q…

React 知识点(二)

文章目录 一、React 组件二、React 组件通信 - 父子通信三、React 组件通信 - 子父通信四、React 组件通信 - 兄弟通信五、React 组件通信 - 跨组件通信(祖先)六、结合组件通信案例七、props-children 属性八、props-类型校验九、React 生命周期十、setState 扩展 一、React 组…

https证书怎么申请?

申请SSL证书的步骤可以因不同的证书颁发机构&#xff08;CA&#xff09;和证书类型&#xff08;如DV SSL、OV SSL、EV SSL&#xff09;而有所差异。以下是一个通用的SSL证书申请流程&#xff0c;以供参考&#xff1a; 1. 选择SSL证书类型 首先&#xff0c;需要根据您的需求选…

SQLE:你的SQL全生命周期质量管理平台

SQLE&#xff1a;你的SQL全生命周期质量管理平台 在数据库管理领域&#xff0c;总有那么几个难题让人头疼。今天要介绍的SQLE&#xff0c;就是解决这些问题的利器。它不仅支持多种数据库&#xff0c;还能在事前控制、事后监督、标准发布等场景中大显身手。本文将为你详细介绍SQ…

【学习方法】高效学习因素 ② ( 学习动机 | 内在学习动机 | 外在学习动机 | 外在学习动机的调整方向 | 保护学习兴趣 | 高考竞争分析 )

文章目录 一、高效学习的其它因素 - 学习动机1、学习动机2、内在学习动机3、外在学习动机4、外在学习动机的问题所在5、外在学习动机的调整方向6、保护学习兴趣7、高考竞争分析 上一篇博客 【学习方法】高效学习因素 ① ( 开始学习 | 高效学习因素五大因素 | 高效学习公式 - 学…

二十八、【人工智能】【机器学习】- 隐马尔可夫模型 (Hidden Markov Models, HMMs)

系列文章目录 第一章 【机器学习】初识机器学习 第二章 【机器学习】【监督学习】- 逻辑回归算法 (Logistic Regression) 第三章 【机器学习】【监督学习】- 支持向量机 (SVM) 第四章【机器学习】【监督学习】- K-近邻算法 (K-NN) 第五章【机器学习】【监督学习】- 决策树…