N3 - Pytorch文本分类入门

news2024/12/23 18:24:34
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

目录

  • 文本分类的基本流程
  • 常用的数据清洗方法
  • 如何使用jieba实现英文分词
  • 如何构建文本向量
  • 代码实践
    • 数据准备
    • 构建词典
    • 生成数据批次和迭代器
    • 模型设计
    • 模型创建
    • 模型训练
    • 评估模型
  • 总结与心得体会


文本分类的基本流程

文本分类的基本流程

常用的数据清洗方法

如何使用jieba实现英文分词

如何构建文本向量

代码实践

数据准备

使用AG News数据集进行文本分类。

AG News(AG’s News Topic Classification Dataset)是一个广泛用于文本分类任务的数据集,尤其是在新闻领域。该数据集是由AG’s Corpus of News Articles收集整理而来,包含了四个主要类别: 世界、体育、商业和科技。

# 使用torchtext导入数据集
import torch
torch.utils.data.datapipes.utils.common.DILL_AVAILABLE = torch.utils._import_utils.dill_available()

from torchtext.datasets import AG_NEWS
train_iter = AG_NEWS(split='train')

我们通过打印数据内容查看一下数据集的格式

for i, data in enumerate(train_iter):
    print(data)
    if i == 3:
        break

数据集格式
由此可见数据集的每一个条目是一个元组,包含新闻文章所属的类别和新闻文章的文本内容,其中类别是一个整数,从1到4,分别对应 世界、科技、体育和商业。

构建词典

要构建词典,需要一个分词器,将句子分成分散的词后,再创建词典。也就是上图文本分类任务中的:文本清洗、分词、文本向量化这三步做的事情。

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

tokenizer = get_tokenizer('basic_english')

def yield_tokens(data_iter):
	for _, text in data_iter:
		yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=['<unk>'])
# 给未知单词设置一个默认索引,当一个单词不在词库中,就取默认索引,将它表示为<unk>
vocab.set_default_index(vocab['<unk>'])

get_tokenizer用于获取分词器函数,分词器可以将一个字符串转换成一个单词的列表

print(tokenizer('Here is the example'))

分词器
vocab是使用torchtext的函数构建出的字典对象,可以使用它直接将单词转换为对应的词典序号,然后可以将序号转换为词向量(例如使用one-hot编码)。

print(vocab(['here', 'is', 'the', 'example']))

将单词转换为序号

生成数据批次和迭代器

import torch
from torch.utils.data import DataLoader

text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: int(x) - 1

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def collate_batch(batch):
	label_list, text_list, offsets = [], [], [0]
	for _label, _text in batch:
		# 将一个批次的标签汇集起来
		label_list.append(label_pipeline(_label))
		# 将一个批次的文本转换成序号汇集起来
		processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
		text_list.append(processed_text)

		# 当前批次中每个句子的长度
		offsets.append(processed_text.size(0))
	label_list = torch.tensor(label_list, dtype=torch.int64)
	text_list = torch.cat(text_list, dim=0)
	offsets = torch.tensor(offsets[:-1]).cumsum(dim=0) # 把每个句子的长度累计求合,成为真正的偏移量
	return label_list.to(device), text_list.to(device), offsets.to(device)

# 生成DataLoader
dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)

模型设计

模型结构
模型的结构如上图所示,对文本进行嵌入后,将句子的嵌入结果进行均值聚合,也就是使用EmbeddingBag mode为mean

from torch import nn
class TextClassificationModel(nn.Module):
	def __init__(self, vocab_size, embed_dim, num_classes):
		super().__init__()
		self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
		self.fc = nn.Linear(embed_dim, num_class)
		self.init_weights()
	
	def init_weights(self):
		initrange = 0.5
		self.embedding.weight.data.uniform_(-initrange, initrange)
		self.fc.weight.data.uniform_(-initrange, initrange)
		self.fc.bias.data.zero_()

	def forward(self, text, offsets):
		embedded = self.embedding(text, offsets)
		return self.fc(embedded)

以上模型中

  • self.embedding 是词嵌入层。作用是将离散的单词表示 (这里直接是单词的词典序号)映射为固定大小的连续向量(也就是单词的向量化)。这些向量捕捉了单词之间的词义关系,并作为网络的输入。
  • self.embedding.weight 是词嵌入层的权重矩阵,它的形状为(vocab_size, embed_dim),其中vocab_size是词汇表的大小,embed_dim是嵌入向量的维度
  • self.embedding.weight.data 是权重矩阵的数据部分,对它进行操作也就直接操作了底层的权重张量
  • .uniform_(-initrange, initrange) 这代表执行了一个原地操作(in-place operation),用于将权重矩阵的值用一个均匀分布进行初始化。均匀分布的范围是[-initrange,initrange],其中initrange是一个正数。

模型创建

num_class = len(set([label for (label, text) in train_iter]))
vocab_size = len(vocab)
em_size = 64
model = TextClassificationModel(vocab_size, em_size, num_class).to(device)

创建模型对象

模型训练

定义训练函数与评估函数

import time
def train(dataloader):
	model.train()
	total_acc, train_loss, total_count = 0, 0, 0
	log_interval = 500
	start_time = time.time()
	
	for idx, (label, text, offsets) in enumerate(dataloader):
		predicted_label = model(text, offsets)
		
		optimizer.zero_grad()
		loss = criterion(predicted_label, label)
		loss.backward()
		optimizer.step()

		total_acc += (predicted_label.argmax(1) == label).sum().item()
		train_loss += loss.item()
		total_count += label.size(0)

		if idx % log_interval == 0 and idx > 0:
			elapsed = time.time() - start_time
			print('| epoch {:1d} | {:4d}/{:4d} batches'
				  '| train_acc {:4.3f} train_loss {:4.5f}'.format(epoch, idx, len(dataloader), total_acc/total_count, train_loss/total_count))
			total_acc, train_loss, total_count = 0, 0, 0
			start_time = time.time()

def evaluate(dataloader):
	model.eval()
	total_acc, train_loss, total_count = 0, 0, 0
	
	with torch.no_grad():
		for idx, (label, text, offsets) in enumerate(dataloader):
			predicted_label = model(text, offsets)
			loss = criterion(predicted_label, label)
			total_acc += (predicted_label.argmax(1) == label).sum().item()
			train_loss += loss.item()
			total_count += label.size(0)

	return total_acc/total_count, train_loss/total_count

开始训练

from torch.utils.data import random_split
from torchtext.data.functional import to_map_style_dataset

# 超参数
EPOCHS = 10
LR = 5
BATCH_SIZE = 64

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None

train_iter, test_iter = AG_NEWS()
train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)
num_train = int(len(train_dataset) * 0.95)

split_train_, split_valid_ = random_split(train_dataset, [num_train, len(train_dataset) - num_train])

train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)

for epoch in range(1, EPOCHS + 1):
	epoch_start_time = time.time()
	train(train_dataloder)
	val_acc, val_loss = evaluate(valid_dataloader)
	
	if total_accu is not None and total_accu > val_acc:
		scheduler.step()
	else:
		total_accu = val_acc
	print('-'*69)
	print('| epoch {:1d} | time: {:4.2f}s | '
		  'valid_acc {:4.3f} valid_loss {:4.3f}'.format(epoch, time.time() - epoch_start_time, val_acc, val_loss))
	print('-'*69)

在上面的代码中,to_map_style_dataset的作用是将一个迭代的数据集(iterable-style dataset)转换为映射式的数据集(Map-style dataset)。这个转换使得我们可以通过索引更方便地访问数据集中的元素。

在pytorch中,数据集可以分成两种类型,Iterable-style和Map-style。Iterable-style数据集实现了__iter__()方法,可以迭代访问数据集中的元素,但不支持通过索引访问。而Map-style数据集实现了__getitem__()__len__()方法,可以直接通过索引访问特定元素,并能获取数据集的大小。

torchtext是pytorch的一个扩展库,专注于处理文本数据。torchtext.data.functional中的to_map_style_dataset函数可以帮助我们将一个Iterable-style的数据集转换为一个易于操作的Map-style的数据集。然后就可以通过索引直接访问数据集中的特定样本,从而简化训练、验证和测试过程中的数据处理。
训练过程如下:

| epoch 1 |  500/1782 batches| train_acc 0.907 train_loss 0.00429
| epoch 1 | 1000/1782 batches| train_acc 0.906 train_loss 0.00431
| epoch 1 | 1500/1782 batches| train_acc 0.909 train_loss 0.00421
---------------------------------------------------------------------
| epoch 1 time: 6.77s | valid_acc 0.913 valid_loss 0.004
---------------------------------------------------------------------
| epoch 2 |  500/1782 batches| train_acc 0.921 train_loss 0.00369
| epoch 2 | 1000/1782 batches| train_acc 0.920 train_loss 0.00375
| epoch 2 | 1500/1782 batches| train_acc 0.918 train_loss 0.00376
---------------------------------------------------------------------
| epoch 2 time: 6.80s | valid_acc 0.917 valid_loss 0.004
---------------------------------------------------------------------
| epoch 3 |  500/1782 batches| train_acc 0.930 train_loss 0.00323
| epoch 3 | 1000/1782 batches| train_acc 0.926 train_loss 0.00334
| epoch 3 | 1500/1782 batches| train_acc 0.925 train_loss 0.00343
---------------------------------------------------------------------
| epoch 3 time: 6.93s | valid_acc 0.860 valid_loss 0.006
---------------------------------------------------------------------
| epoch 4 |  500/1782 batches| train_acc 0.943 train_loss 0.00267
| epoch 4 | 1000/1782 batches| train_acc 0.945 train_loss 0.00263
| epoch 4 | 1500/1782 batches| train_acc 0.946 train_loss 0.00265
---------------------------------------------------------------------
| epoch 4 time: 6.83s | valid_acc 0.926 valid_loss 0.004
---------------------------------------------------------------------
| epoch 5 |  500/1782 batches| train_acc 0.947 train_loss 0.00256
| epoch 5 | 1000/1782 batches| train_acc 0.947 train_loss 0.00258
| epoch 5 | 1500/1782 batches| train_acc 0.947 train_loss 0.00261
---------------------------------------------------------------------
| epoch 5 time: 6.76s | valid_acc 0.921 valid_loss 0.004
---------------------------------------------------------------------
| epoch 6 |  500/1782 batches| train_acc 0.948 train_loss 0.00253
| epoch 6 | 1000/1782 batches| train_acc 0.949 train_loss 0.00253
| epoch 6 | 1500/1782 batches| train_acc 0.951 train_loss 0.00241
---------------------------------------------------------------------
| epoch 6 time: 6.93s | valid_acc 0.926 valid_loss 0.004
---------------------------------------------------------------------
| epoch 7 |  500/1782 batches| train_acc 0.949 train_loss 0.00248
| epoch 7 | 1000/1782 batches| train_acc 0.949 train_loss 0.00250
| epoch 7 | 1500/1782 batches| train_acc 0.949 train_loss 0.00248
---------------------------------------------------------------------
| epoch 7 time: 6.85s | valid_acc 0.926 valid_loss 0.004
---------------------------------------------------------------------
| epoch 8 |  500/1782 batches| train_acc 0.948 train_loss 0.00247
| epoch 8 | 1000/1782 batches| train_acc 0.950 train_loss 0.00250
| epoch 8 | 1500/1782 batches| train_acc 0.951 train_loss 0.00243
---------------------------------------------------------------------
| epoch 8 time: 6.76s | valid_acc 0.926 valid_loss 0.004
---------------------------------------------------------------------
| epoch 9 |  500/1782 batches| train_acc 0.951 train_loss 0.00239
| epoch 9 | 1000/1782 batches| train_acc 0.948 train_loss 0.00259
| epoch 9 | 1500/1782 batches| train_acc 0.951 train_loss 0.00244
---------------------------------------------------------------------
| epoch 9 time: 6.87s | valid_acc 0.926 valid_loss 0.004
---------------------------------------------------------------------

评估模型

print('Checking the results of test dataset.')
test_acc, test_loss = evaluate(test_dataloader)
print('test accuracy {:8.3f}'.format(test_acc))

评估结果

总结与心得体会

文本分类任务,关键的是前面对文本的处理,合并,嵌入,最后的分类反而非常简直,直接使用了一层全连接层就可以达到不错的效果了。

在复现的过程中,由于使用的库版本不一致导致torchtext库部分代码无法正常运行,卡了好久,后面搜索了一些之前打卡的同学的博客,才找到解决方案。复现模型时尽量不要使用最新的版本,而是使用原来的版本,先运行起来,再改动。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1948712.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【vluhub】zabbix漏洞

介绍&#xff1a; zabbix是对服务器资源状态例如、内存空间、CPU、程序运行状态进行检测、设置预警值、短信设置等功能等一款开源工具。配置不当存在未授权,SQL注入漏洞 弱口令 nameadmin&passwordzabbix nameguest&password POST /index.php HTTP/1.1 Host: 192.1…

[C++实战]日期类的实现

&#x1f496;&#x1f496;&#x1f496;欢迎来到我的博客&#xff0c;我是anmory&#x1f496;&#x1f496;&#x1f496; 又和大家见面了 欢迎来到C探索系列 作为一个程序员你不能不掌握的知识 先来自我推荐一波 个人网站欢迎访问以及捐款 推荐阅读 如何低成本搭建个人网站…

软件测试--测试管理与缺陷管理

文章目录 目标重点/难点 案例引入软件测试管理定义测试组织的定义独立组织测试的优缺点 测试管理——测试计划定义测试计划的持续活动 测试管理的准则出口准则入口准则 软件测试管理 | 测试用例的管理测试用例管理的重要性测试用例管理要解决的问题如何组织测试用例如何报告测试…

关于if return的组合来实现if else效果

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 关于if return的组合来实现if else效果 前言一、if return 前言 提示&#xff1a;以下是本篇文章正文内容&#xff0c;下面案例可供参考 一、if return // 在链表中插入节点…

Apache POI-Excel入门与实战

目录 一、了解Apache POI 1.1 什么是Apache POI 1.2 为什么要使用ApaChe POI 1.3 Apache POI应用场景 1.4 Apache POI 依赖 二、Apache POI-Excel 入门案例 2.1 写入Excel文件 2.2 读取文件 四、Apache POI实战 4.1 创建一个获取天气的API 4.2高德天气请求API与响应…

vs code解决报错 (c/c++的配置环境 远端机器为Linux ubuntu)

参考链接&#xff1a;https://blog.csdn.net/fightfightfight/article/details/82857397 https://blog.csdn.net/m0_38055352/article/details/105375367 可以按照步骤确定那一步不对&#xff0c;如果一个可以就不用往下看了 目录 一、检查一下文件扩展名 二、安装扩展包并…

秒杀案例-分布式锁Redisson、synchronized、RedLock

模拟秒杀 源码地址前期准备创建数据库表导入数据dependenciespomControllerTSeckillProductTseckillProductServiceTseckillProductServiceImplTseckillProductMapperTseckillProductMapper.xml使用JMeter压力测试开始测试超卖现象原因解决办法更改数据库库存500进行JMeter压力…

linux_top命令打印结果_PID USER PR NI VIRT RES SHR S 什么意思

top命令输出结果 含义 top 命令是 Linux 和 Unix 系统中用于实时显示系统中各个进程的资源占用情况的工具。当你运行 top 命令并查看输出结果时&#xff0c;会看到类似下面的列&#xff08;具体的列可能因 top 的版本和配置而有所不同&#xff09;&#xff1a; PID: 进程ID&a…

NSS [NSSRound#4 SWPU]ez_rce

NSS [NSSRound#4 SWPU]ez_rce CVE-2021-41773 Apache Httpd Server 路径穿越漏洞 POC: GET /cgi-bin/.%2e/%2e%2e/%2e%2e/%2e%2e/bin/sh HTTP/1.1 Host: node4.anna.nssctf.cn:28690 Cache-Control: max-age0 Upgrade-Insecure-Requests: 1 User-Agent: Mozilla/5.0 (Window…

C/C++教程合集(完)

C初级教程(非常基础&#xff0c;适合入门)入门C语言只需一个星期&#xff08;星期一&#xff09;入门C语言只需一个星期&#xff08;星期二&#xff09;入门C语言只需一个星期&#xff08;星期三&#xff09;入门C语言只需一个星期&#xff08;星期四)入门C语言只需一个星期&am…

NSS [NSSRound#13 Basic]flask?jwt?

NSS [NSSRound#13 Basic]flask?jwt? 开题 注册一下 要admin才能拿flag 看看是如何进行身份验证的 是flask session flask-unsign --decode --cookie .eJwtzjESwyAMBMC_UKfghJCEP-MRICZp7bjK5O9xkX6L_aR9HXE-0_Y-rnik_TXTlsiXEhUXleKGGGuG1jbmogrCEmNirZ7BEB-VJbTfIi-26hQD…

数据库实例迁移实践

背景 随着业务发展&#xff0c;数据库实例磁盘逐渐升高&#xff0c;告警频繁&#xff0c;且后续可能会对DDL产生影响&#xff08;尤其是借助ghost等工具执行的DDL&#xff09;。 该实例有多个库&#xff0c;则需要迁移其中的一个或几个单库到其他实例&#xff0c;为什么不做分…

【NPU 系列专栏 3.1 -- - NVIDIA 的 Orin 和 Altan 和 Thor 区别】

请阅读【嵌入式及芯片开发学必备专栏】 文章目录 NVIDIA Orin、Altan 和 ThorNVIDIA Orin 简介NVIDIA Orin 主要特点NVIDIA Orin 应用场景 NVIDIA Altan 简介NVIDIA Altan 主要特点NVIDIA Altan 应用场景 NVIDIA Thor 简介NVIDIA Thor 主要特点NVIDIA Thor 应用场景 与 Hopper …

CTF-NSSCTF题单[GKCTF2020]

[GKCTF 2020]CheckIN 这道题目考察&#xff1a;php7-gc-bypass漏洞 打开这道题目&#xff0c;开始以为考察反序列化&#xff0c;但实际并不是&#xff0c;这里直接用$_REQUEST传入了参数便可以利用了。这里出现了一个eval&#xff08;&#xff09;函数&#xff0c;猜测考察命…

暑期C++ 缺省参数

有任何不懂的问题可以评论区留言&#xff0c;能力范围内都会一一回答 1.缺省参数的概念 缺省参数是是声明或定义参数时为函数的参数指定一个缺省值。在调用该函数值时&#xff0c;如果没有指定实参则采用该形参的缺省值&#xff0c;否则使用指定的实参 看了上面定义后&#…

CogVLMv2环境搭建推理测试

引子 之前写过一篇CogVLM的分享&#xff0c;感兴趣的移步CogVLM/CogAgent环境搭建&推理测试-CSDN博客&#xff0c;前一阵子&#xff0c;CogVLMv2横空出世&#xff0c;支持视频理解功能&#xff0c;OK&#xff0c;那就让我们开始吧。 一、模型介绍 CogVLM2 系列模型开源了…

基于Vision Transformer的mini_ImageNet图片分类实战

【图书推荐】《PyTorch深度学习与计算机视觉实践》-CSDN博客 PyTorch计算机视觉之Vision Transformer 整体结构-CSDN博客 mini_ImageNet数据集简介与下载 mini_ImageNet数据集节选自ImageNet数据集。ImageNet是一个非常有名的大型视觉数据集&#xff0c;它的建立旨在促进视觉…

旗晟机器人仪器仪表识别AI智慧算法

在当今迅猛发展的工业4.0时代&#xff0c;智能制造和自动化运维已然成为工业发展至关重要的核心驱动力。其中智能巡检运维系统扮演着举足轻重的角色。工业场景上不仅要对人员行为监督进行监督&#xff0c;对仪器仪表识别分析更是不可缺少的一个环节。那么我们说说旗晟仪器仪表识…

AI模型大比拼:Claude 3系列 vs GPT-4系列最新模型综合评测

AI模型大比拼&#xff1a;Claude 3系列 vs GPT-4系列最新模型综合评测 引言 人工智能技术的迅猛发展带来了多款强大的语言模型。本文将对六款领先的AI模型进行全面比较&#xff1a;Claude 3.5 Sonnet、Claude 3 Opus、Claude 3 Haiku、GPT-4、GPT-4o和GPT-4o Mini。我们将从性能…

【Gin】精准应用:Gin框架中工厂模式的现代软件开发策略与实施技巧(下)

【Gin】精准应用&#xff1a;Gin框架中工厂模式的现代软件开发策略与实施技巧(下) 大家好 我是寸铁&#x1f44a; 【Gin】精准应用&#xff1a;Gin框架中工厂模式的现代软件开发策略与实施技巧(下)✨ 喜欢的小伙伴可以点点关注 &#x1f49d; 前言 本次文章分为上下两部分&…