N4 - Pytorch实现中文文本分类

news2024/11/13 14:55:26
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

目录

  • 任务描述
  • 步骤
    • 环境设置
    • 数据准备
    • 模型设计
    • 模型训练
    • 模型效果展示
  • 总结与心得体会


任务描述

在上周的任务中,我们使用torchtext下载了托管的英文的AG News数据集 进行了分类任务。本周我们来对中文的自定义数据集来进行分类任务。

自定义数据集的格式是csv格式,我们先用pandas进行读取,创建数据集对象。然后后面的步骤就和上周基本上一致了。

步骤

环境设置

import torch
import warnings

warnings.filterwarnings('ignore') # 忽略警告

# 创建全局设备对象
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

设备对象

数据准备

使用pandas读取数据

import pandas as pd
train_data = pd.read_csv('train.csv', sep='\t', header=None)
train_data.head()

查看前三条数据
可以看到数据有两列,第一列是文字内容,第二列是所属的标签。

接下来编写一个迭代器函数,每次迭代返回一对内容和标签

def custom_data_iter(texts, labels):
	for x, y in zip(texts, labels):
		yield x, y
train_iter = custom_data_iter(train_data[0].values[:], train_data[1].values[:])

然后创建词典,使用torchtext中的build_vocab_from_iterator工具函数

from torchtext.vocab import build_vocab_from_iterator
import jieba

# 使用jieba库来做分词器
tokenizer = jieba.lcut # lcut直接返回列表, cut 返回一个迭代器

# 编写一个迭代函数,每次返回一句内容的分词结果
def yield_tokens(data_iter):
	for text, _ in data_iter: # 每次返回一句内容和对应标签
		yield tokenizer(text) # 返回该句内容的分词列表

# 创建词典
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=['<unk>'])
vocab.set_default_index(vocab['<unk>'])

# 测试词典
vocab(['我', '想', '看', '和平', '精英', '上', '战神', '必备', '技巧', '的', '游戏', '视频'])

词典结果
获取所有的标签名

label_name = list(set(train_data[1].values[:]))
print(label_name)

标签名列表
编写函数,将内容和标签分别转换成数值

text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: label_name.index(x)

print(text_pipeline("我想看和平精英上战神必备技巧的游戏视频"))
print(label_pipeline('Video-Play'))

转换函数
编写文本的批处理函数,用于数据集与模型之间,将一个批次的文本数据转换为数值,还需要生成EmbeddingBag输入时的offsets参数。

from torch.utils.data import DataLoader

def collate_batch(batch):
	label_list, text_list, offsets = [], [], [0]
	
	for (_text, _label) in batch:
		# 标签列表
		label_list.append(label_pipeline(_label))
		# 文本列表
		processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
		text_list.append(processed_text)
		# 偏移列表
		offsets.append(len(processed_text))
	label_list = torch.tensor(label_list, dtype=torch.int64)
	text_list = torch.cat(text_list)
	offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)

	return text_list.to(device), label_list.to(device), offsets.to(device)

dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)

模型设计

和上节一样,一个EmbeddingBag层跟着一个全连接层就可以了

from torch import nn

class TextClassificationModel(nn.Module):
	# 参数随后设置
	def __init__(self, vocab_size, embed_dim, num_classes):
		super().__init__()
		self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
		self.fc = nn.Linear(embed_dim, num_classes)
		self.init_weights()
	
	# 自定义的权重初始化操作
	def init_weights(self):
		initrange = 0.5
		self.embedding.weight.data.uniform_(-initrange, initrange)
		self.fc.weight.data.uniform_(-initrange, initrange)
		self.fc.bias.data.zero_()
	
	# 向前传播
	def forward(self, text, offsets):
		embedded = self.embedding(text, offsets)
		return self.fc(embedded)

创建模型对象

num_classes = len(label_name) # 分类数量
vocab_size = len(vocab) # 词典大小
embedding_size = 64 # 嵌入向量的维度
model = TextClassificationModel(vocab_size, embedding_size, num_classes).to(device)
model

模型结构
可以看到,这个模型简单的很。

模型训练

首先编写训练和评估函数

def train(dataloader):
	model.train()
	total_acc, train_loss, total_count = 0, 0, 0
	log_interval = 50
	start_time = time.time()
	
	for idx, (text, label, offsets) in enumerate(dataloader):
		predicted_label = model(text, offsets)
		
		optimizer.zero_grad()
		loss = criterion(predicted_label, label)
		loss.backward()
		nn.utils.clip_grad_norm_(model.parameters(), 0.1) #梯度裁剪
		optimizer.step()

		total_acc += (predicted_label.argmax(1) == label).sum().item()
		train_loss += loss.item()
		total_count += label.size(0)

		if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('| epoch {:1d} | {:4d}/{:4d} batches '
                  '| train_acc {:4.3f} train_loss {:4.5f}'.format(epoch, idx, len(dataloader), total_acc/total_count, train_loss, total_count))
            total_acc, train_loss, total_count = 0, 0, 0
            start_time = time.time()

def evaluate(dataloader):
	model.eval()
	total_acc, train_loss, total_count = 0, 0, 0

	with torch.no_grad():
		for idx, (text, label, offsets) in enumerate(dataloader):
			predicted_label = model(text, offsets)

			loss = criterion(predicted_label, label)
			total_acc += (predicted_label.argmax(1) == label).sum().item()
			train_loss += loss.item()
			total_count += label.size(0)

	return total_acc/total_count, train_loss/total_count

开始训练

from torch.utils.data import random_split
from torchtext.data.functional import to_map_style_dataset

# 迭代次数
EPOCHS = 20
# 学习率
LR = 5
# 批次大小
BATCH_SIZE = 64

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)

total_accu = None
train_iter = custom_data_iter(train_data[0].values[:], train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)

train_size = int(len(train_dataset)*0.8)
split_train_, split_valid_ = random_split(train_dataset, [train_size, len(train_dataset) - train_size])

train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)

for epoch in range(1, EPOCHS+1):
    epoch_start_time = time.time()
    train(train_dataloader)
    val_acc, val_loss = evaluate(valid_dataloader)

    lr = optimizer.state_dict()['param_groups'][0]['lr']
    if total_accu is not None and total_accu > val_acc:
        scheduler.step()
    else:
        total_accu = val_acc
    print('-'*69)
    print('| epoch {:1d} | time: {:4.2f}s | '
          'valid_acc {:4.3f} valid_loss {:4.3f} | lr {:4.6f}'.format(epoch, time.time() - epoch_start_time, val_acc, val_loss, lr))
    print('-'*69)

训练过程
训练结束后打印一下模型的准确度

model = model.to(device)
test_acc, test_loss = evaluate(valid_dataloader)
print('模型准确率为: {:5.4f}'.format(test_acc))

模型准确率

模型效果展示

自己写一句话让模型跑一下看看效果

def predict(text, text_pipeline):
    with torch.no_grad():
        text = torch.tensor(text_pipeline(text))
        output = model(text, torch.tensor([0]))
        return output.argmax(1).item()

ex_text_str = '不要让一个男人听懂《水星记》'

# 切换成CPU推理
model = model.to('cpu')
print('文本的分类是: %s' % label_name[predict(ex_text_str, text_pipeline)])

测试结果

总结与心得体会

通过测试,发现这个模型的效果还是不错的。大部分的句子可以给出正确的分类。和上节相比,中文数据集的文本分类任务和英文数据集的文本分类主要差异在tokenizer(分词器)上。英文的分词非常简单,英文的词之间天然有间隔,所以可以直接使用标点和空格来分割。中文就不太一样,中文需要一个好的断句工具才行,jieba库就是这么一个工具。在大部分的中文自然语言处理任务中,都可以看到它的身影。我在想是不是可以直接使用深度学习来进行分词,来达到更好的效果,或者直接使用大语言模型,经过Prompt直接变成分词工具来使用(只不过成本太高了),希望有时间可以尝试一下。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1973316.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Web 框架

Web 框架 Web服务器Web服务器的主要功能常见的Web服务器软件包 Web 框架常用 Python Web 框架选择Python Web框架的考虑因素 WSGIWSGI的主要特点WSGI的工作原理常见的WSGI服务器和框架&#xff1a; 静态资源定义与特点静态资源的类型静态资源的管理与优化 动态资源定义与特点动…

Java入门、进阶、强化、扩展、知识体系完善等知识点学习、性能优化、源码分析专栏分享

场景 作为一名Java开发者&#xff0c;势必经历过从入门到自学、从基础到进阶、从学习到强化的过程。 当经历过几年企业级开发的磨炼&#xff0c;再回头看之前的开发过程、成长阶段发现确实是走了好多的弯路。 作为一名终身学习的信奉者&#xff0c;秉承Java体系需持续学习、…

Java | Leetcode Java题解之第316题去除重复字母

题目&#xff1a; 题解&#xff1a; class Solution {public String removeDuplicateLetters(String s) {boolean[] vis new boolean[26];int[] num new int[26];for (int i 0; i < s.length(); i) {num[s.charAt(i) - a];}StringBuffer sb new StringBuffer();for (in…

arduino程序-MC猜数字2、3、4(基础知识)

arduino程序-MC猜数字2、3、4&#xff08;基础知识&#xff09; 1-20 MC猜数字2-LED数码管数码管LED数码管应用程序示例 1-21 MC猜数字3- while回顾While循环语句Do while循环语句 1-22 MC猜数字4-switch caseIf判断myNumber数字显示If ... else ifSwitch case示例程序产生随机…

域环境的搭建 内网学习不会搭建环境???

今天有空写一下内网环境的搭建的步骤&#xff0c;我下面是我搭建的环境的图。 我搭建的是父子域&#xff0c;DC是父域控&#xff0c;WEB为子域控 然后下面我来说一下我是怎么搭建的。 首先我们要准备一些机器的镜像文件&#xff0c;如果你是复制的虚拟机的话&#xff0c;你要把…

CLIP论文详解

文章目录 前言一、CLIP理论1.CLIP思想2.模型结构 二、CLIP预训练1.数据集2.训练策略3.模型选择 三、Zero-Shot推理四、CLIP伪代码实现五、CLIP局限性总结 前言 CLIP这篇论文是OpenAI团队在2021年2月底提出的一篇论文&#xff0c;名字叫做《Learning Transferable Visual Models…

Markdown与数学公式

在写偏理科的文章的时候&#xff0c;多多少少会涉及到一些公式、函数的输入&#xff0c;本篇就来讲讲如何在 Markdown 中书写数学公式。 在此之前&#xff0c;我们先介绍下 LaTex 和 MashJax&#xff0c;Markdown 就是基于它们来实现数学公式的输入。 LaTex 简介 LaTex &…

JDK-java.nio包详解

JDK-java.nio包详解 概述 一直以来Java三件套&#xff08;集合、io、多线程&#xff09;都是最热门的Java基础技术点&#xff0c;我们要深入掌握好这三件套才能在日常开发中得心应手&#xff0c;之前有编写集合相关的文章&#xff0c;这里出一篇文章来梳理一下io相关的知识点。…

MyBatis 源码学习 | Day 1 | 了解 MyBatis

什么是 MyBatis 在对一项技术进行深入学习前&#xff0c;我们应该先对它有个初步的认识。MyBatis 是一个 Java 持久层框架&#xff0c;用于简化数据库的操作。它通过 XML 或注解的方式配置和映射原始类型、接口和 Java POJO&#xff08;Plain Old Java Objects&#xff0c;普通…

跑深度学习模型I:一文正确使用CUDA

1. 安装显卡驱动NVIDIA 如果出现这个问题&#xff0c;是NVIDIA环境配置原因。一定要注意配置系统环境变量正确。 C:\Users\2605304845>nvcc --version nvcc 不是内部或外部命令&#xff0c;也不是可运行的程序 或批处理文件。 - CSDN文库 2. 安装CUDA 安装时注意版本对应…

C语言--函数

1. 函数定义 语法&#xff1a; 类型标识符 函数名&#xff08;形式参数&#xff09; {函数体代码 } &#xff08;1&#xff09;类型标识符 --- 数据类型&#xff08;函数要带出的结果的类型&#xff09; 注&#xff1a;数组类型不能做函数返回结果的类型&#xff0c;如果函…

pt模型转onnx模型,onnx模型转engine模型,pt模型转engine模型详细教程(TensorRT,jetpack)

背景 背景是需要在nvidia jetpack4.5.1的arm64设备上跑yolov8,用TensorRT加速&#xff0c;需要用*.engine格式的模型&#xff0c;但是手头上的是pt格式模型&#xff0c;众所周知小板子的内存都很小&#xff0c;连安装ultralytics依赖库的容量都没有&#xff0c;所以我想到在wi…

【开源】嵌入式Linux(IMX6U)应用层综合项目(1)--云平台调试APP

目录 1.简介 1.1功能介绍 1.2技术栈介绍 1.3演示视频 1.4硬件介绍 2.软件设计 2.1连接阿里云 2.2云平台调试UI 2.3Ui_main.c界面切换处理文件 2.4.main函数 3.结尾&#xff08;附网盘链接&#xff09; 1.简介 此文章并不是教程&#xff0c;只能当作笔者的学习分享&…

go中的值传递和指针传递

文章目录 1、& 和 *2、空指针3、nil4、用值传递还是指针传递&#xff1f;5、补充 1、& 和 * &后跟一个变量名&#xff0c;得到的是这个变量的内存地址*int类型的变量&#xff0c;代表这个变量里存的值是int类型的变量的内存地址数据类型的指针类型&#xff0c;即在…

顺序表的实现【数据结构】

1.线性表 线性表&#xff08;linear list&#xff09;是n个具有相同特性的数据元素的有线序列。线性表是一种在实际中广泛使用的数据结构&#xff0c;常见的线性表有&#xff1a;顺序表、链表、栈、队列、字符串… 线性表在逻辑上是线性结构&#xff0c;也就是说是连续的一条线…

医院设置(洛谷)

设有一棵二叉树&#xff0c;如图&#xff1a; 其中&#xff0c;圈中的数字表示结点中居民的人口。圈边上数字表示结点编号&#xff0c;现在要求在某个结点上建立一个医院&#xff0c;使所有居民所走的路程之和为最小&#xff0c;同时约定&#xff0c;相邻接点之间的距离为 11。…

C语言实现 -- 单链表

C语言实现 -- 单链表 1.顺序表经典算法1.1 移除元素1.2 合并两个有序数组 2.顺序表的问题及思考3.链表3.1 链表的概念及结构3.2 单链表的实现 4.链表的分类 讲链表之前&#xff0c;我们先看两个顺序表经典算法。 1.顺序表经典算法 1.1 移除元素 经典算法OJ题1&#xff1a;移除…

在服务器上使用Dockerfile创建springboot项目的镜像和踩坑避雷

1. 准备个文件夹 这是我的路径 /usr/local/springboot/docker-daka/docker_files2. 将jar包上传 springboot项目打包——maven的package 这是整个项目打包的模式&#xff0c;也可以分离依赖、配置和程序进行打包&#xff0c;详情看我这篇文章&#xff1a; springboot依赖 配…

java基础 之 集合与栈的使用(四)

文章目录 Queue栈Stack队列和栈的区别小扩展自己写个简单的队列自己写个简单的栈使用栈来实现个队列使用队列来实现个栈写在最后 前文回顾&#xff1a; 戳这里 → java基础 之 集合与栈的使用&#xff08;一&#xff09; 戳这里 → java基础 之 集合与栈的使用&#xff08;二&a…

windows中node版本的切换(nvm管理工具),解决项目兼容问题 node版本管理、国内npm源镜像切换(保姆级教程,值得收藏)

前言 在工作中&#xff0c;我们可能同时在进行2个或者多个不同的项目开发&#xff0c;每个项目的需求不同&#xff0c;进而不同项目必须依赖不同版本的NodeJS运行环境&#xff0c;这种情况下&#xff0c;对于维护多个版本的node将会是一件非常麻烦的事情&#xff0c;nvm就是为…