N3 中文文本分类

news2024/11/27 12:34:42
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊# 前言

前言

前面学习了相关自然语言编码,这周进行相关实战

导入依赖库和设置设备

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import os, PIL, pathlib, warnings

warnings.filterwarnings("ignore")  # 忽略警告
# win10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

这段代码导入了必要的库并设置了设备(GPU或CPU)。

数据预处理和词汇表构建

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import AG_NEWS

train_iter = AG_NEWS(split='train')
tokenizer = get_tokenizer('basic_english')  # 返回分词器函数

def yield_tokens(data_iter):
    for _, text in data_iter:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])  # 设置默认索引,如果找不到单词,则会选择默认索引

这里使用torchtext库加载AG_NEWS数据集,定义了一个分词器并构建了词汇表。

数据处理管道

text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: int(x) - 1
text_pipeline('here is the an example')

定义了两个数据处理管道:text_pipeline用于将文本转化为词汇表中的索引序列,label_pipeline用于将标签转化为整数索引。

定义数据加载器

from torch.utils.data import DataLoader

def collate_batch(batch):
    label_list, text_list, offsets = [], [], [0]

    for (_label, _text) in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        offsets.append(processed_text.size(0))

    label_list = torch.tensor(label_list, dtype=torch.int64)
    text_list = torch.cat(text_list)
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)  # 返回维度dim中输入元素的累计和

    return label_list.to(device), text_list.to(device), offsets.to(device)

dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)

定义了一个collate_batch函数用于将一个批次的数据整合在一起,并创建了一个数据加载器。

定义模型

from torch import nn

class TextClassificationModel(nn.Module):

    def __init__(self, vocab_size, embed_dim, num_class):
        super(TextClassificationModel, self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

num_class = len(set([label for (label, text) in train_iter]))
vocab_size = len(vocab)
em_size = 64
model = TextClassificationModel(vocab_size, em_size, num_class).to(device)

定义了一个文本分类模型TextClassificationModel,包括初始化函数、权重初始化和前向传播函数。模型由一个嵌入层和一个线性层组成。

训练和评估函数

import time

def train(dataloader):
    model.train()  # 切换为训练模式
    total_acc, train_loss, total_count = 0, 0, 0
    log_interval = 500
    start_time = time.time()

    for idx, (label, text, offsets) in enumerate(dataloader):
        predicted_label = model(text, offsets)
        optimizer.zero_grad()  # grad属性归零
        loss = criterion(predicted_label, label)  # 计算网络输出和真实值之间的差距,label为真实值
        loss.backward()  # 反向传播
        optimizer.step()  # 每一步自动更新

        total_acc += (predicted_label.argmax(1) == label).sum().item()
        train_loss += loss.item()
        total_count += label.size(0)

        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('| epoch {:1d} | {:4d}/{:4d} batches '
                  '| train_acc {:4.3f} train_loss {:4.5f}'.format(epoch, idx, len(dataloader),
                                                                  total_acc / total_count, train_loss / total_count))
            total_acc, train_loss, total_count = 0, 0, 0
            start_time = time.time()

def evaluate(dataloader):
    model.eval()  # 切换为测试模式
    total_acc, train_loss, total_count = 0, 0, 0

    with torch.no_grad():
        for idx, (label, text, offsets) in enumerate(dataloader):
            predicted_label = model(text, offsets)
            loss = criterion(predicted_label, label)  # 计算loss值
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            train_loss += loss.item()
            total_count += label.size(0)

    return total_acc / total_count, train_loss / total_count

定义了训练和评估函数,用于训练模型和评估模型性能。

数据集分割和数据加载器创建

from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset

EPOCHS = 10  # epoch
LR = 5  # 学习率
BATCH_SIZE = 64  # batch size for training

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None

train_iter, test_iter = AG_NEWS()  # 加载数据
train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)
num_train = int(len(train_dataset) * 0.95)

split_train_, split_valid_ = random_split(train_dataset,
                                          [num_train, len(train_dataset) - num_train])

train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                             shuffle=True, collate_fn=collate_batch)

加载数据集并将其转换为适用于随机访问的数据集,分割训练集和验证集,并创建相应的数据加载器。

训练和验证模型

for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader)
    val_acc, val_loss = evaluate(valid_dataloader)

    if total_accu is not None and total_accu > val_acc:
        scheduler.step()
    else:
        total_accu = val_acc
    print('-' * 69)
    print('| epoch {:1d} | time: {:4.2f}s | '
          'valid_acc {:4.3f} valid_loss {:4.3f}'.format(epoch,
                                                        time.time() - epoch_start_time,
                                                        val_acc, val_loss))
    print('-' * 69)

进行训练和验证,在每个epoch结束时打印验证准确率和损失,并根据验证结果调整学习率。

测试模型

print('Checking the results of test dataset.')
test_acc, test_loss = evaluate(test_dataloader)
print('test accuracy {:8.3f}'.format(test_acc))

在测试集上评估模型性能并打印测试准确率。

结果

在这里插入图片描述

总结

这个案例实现了一个完整的文本分类流程,从数据预处理、模型定义到训练和评估。使用torchtext加载数据,并利用PyTorch构建和训练深度学习模型,实现了对AG_NEWS数据集的文本分类任务,达到了90.1%的精度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1825726.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

栈的实现详解

目录 1. 栈1.1 栈的概念及结构1.2 栈的实现方式1.3 栈的应用场景 2. 栈的实现2.1 结构体2.2 初始化2.3 销毁2.4 入栈2.5 出栈2.6 获取栈顶元素2.7 判空2.8 获取个数 3. test主函数4. Stack.c文件5. Stack.h文件6. 运行展示 1. 栈 1.1 栈的概念及结构 栈&#xff1a;一种特殊的…

找不到dll文件如何修复,总结多种dll丢失的修复方法

在计算机使用过程中&#xff0c;我们经常会遇到一些错误提示&#xff0c;其中之一就是“缺少DLL文件”。那么&#xff0c;DLL到底是什么呢&#xff1f;为什么计算机会缺失DLL文件&#xff1f;缺失DLL文件会对电脑产生什么具体影响&#xff1f;本文将详细解析这些问题&#xff0…

UC Berkeley简介以及和Stanford的区别与联系

UC Berkeley Source: Google Map 中文版 UC Berkeley&#xff0c;全称University of California, Berkeley&#xff0c;是一所位于美国加利福尼亚州伯克利市的世界知名公立研究型大学。以下是关于UC Berkeley的详细介绍&#xff1a; 学术声誉和排名 学术声誉&#xff1a; U…

Android入门第69天-AndroidStudio中的Gradle使用国内镜像最强教程

背景 AndroidStudio默认连接的是dl.google的gadle仓库。 每次重新build时: 下载速度慢;等待了半天总时build faild;build到一半connection timeout;即使使用了魔法也难以一次build好;这严重影响了我们的学习、开发效率。 当前网络上的使用国内镜像的教程不全 网上的教程…

056、PyCharm 快速代码重构的方法

在实际的编程过程中&#xff0c;如果有一段代码需要在多个地方重复使用&#xff0c;我们应该将这段代码封装成一个函数。这样可以提高代码的可重用性和可维护性。 在PyCharm编辑器里&#xff0c;可以使用以下操作对代码块进行快速的重构。 &#xff08;1&#xff09;、选中一…

【数据分析】推断统计学及Python实现

各位大佬好 &#xff0c;这里是阿川的博客&#xff0c;祝您变得更强 个人主页&#xff1a;在线OJ的阿川 大佬的支持和鼓励&#xff0c;将是我成长路上最大的动力 阿川水平有限&#xff0c;如有错误&#xff0c;欢迎大佬指正 Python 初阶 Python–语言基础与由来介绍 Python–…

Spring Boot整合Redis通过Zset数据类型+定时任务实现延迟队列

&#x1f604; 19年之后由于某些原因断更了三年&#xff0c;23年重新扬帆起航&#xff0c;推出更多优质博文&#xff0c;希望大家多多支持&#xff5e; &#x1f337; 古之立大事者&#xff0c;不惟有超世之才&#xff0c;亦必有坚忍不拔之志 &#x1f390; 个人CSND主页——Mi…

mac终端:定位于当前文件夹位置Terminal设置快捷键,实现快速启动--m系列

mac自动脚本(f3直接搜或者找其他->自动操作) 然后找到如下功能: 代码如下: on run {input, parameters} tell application "Finder" set currFolder to POSIX path of (folder of the front window as string) end tell tell application "Termina…

人大高瓴/腾讯提出QAGCF:用于QA推荐的图形协同过滤

【摘要】问答(Q&A)平台通常推荐问答对来满足用户的知识获取需求,这与仅推荐单个项目的传统推荐不同。这使得用户行为更加复杂,并为Q&A推荐带来了两个挑战,包括:协作信息纠缠,即用户反馈受问题或答案的影响;以及语义信息纠缠,其中问题与其相应的答案相关联,不同问答对之…

软件工程实务:软件产品

目录 1、软件产品的基本概念 2、软件工程是什么&#xff1f; 为什么产生软件工程? 软件工程是做什么的? 3、定制软件和软件产品的工程比较 4 、软件产品的运行模式 5、软件产品开发时需要考虑的两个基本技术因素 6、产品愿景 7、软件产品管理 8、产品原型设计 9、小结…

C/C++:指针用法详解

C/C&#xff1a;指针 指针概念 指针变量也是一个变量 指针存放的内容是一个地址&#xff0c;该地址指向一块内存空间 指针是一种数据类型 指针变量定义 内存最小单位&#xff1a;BYTE字节&#xff08;比特&#xff09; 对于内存&#xff0c;每个BYTE都有一个唯一不同的编号…

DeepSORT(目标跟踪算法)卡尔曼滤波中的贝叶斯定理

DeepSORT&#xff08;目标跟踪算法&#xff09;卡尔曼滤波中的贝叶斯定理 flyfish 从例子中介绍名词 假设我们有一个袋子&#xff0c;里面有5个红球和3个蓝球。我们从袋子里随机抽取一个球。 概率 (Probability) 我们想计算从袋子里抽到红球的概率 P ( R ) P(R) P(R)。 …

4D毫米波雷达技术及发展

文章目录 前言一、4D毫米波雷达是什么&#xff1f;二、毫米波雷达是什么&#xff1f;毫米波雷达的基本原理多普勒效应 前言 现阶段自动驾驶技术中&#xff0c;主要用到的传感器有摄像头、激光雷达和毫米波雷达。 摄像头的光谱从可见光到红外光谱&#xff0c;是最接近人眼的传感…

useEffect的概念以及使用(对接口)

// useEffect的概念以及使用 import {useEffect, useState} from reactconst Url"http://geek.itheima.net/v1_0/channels"function App() {// 创建状态变量const [lustGet,setLustGet]useState([]);// 渲染完了之后执行这个useEffect(() > {// 额外的操作&#x…

Vector VH6501使用CANoe工程CANDisturbanceMain进行模拟干扰测试

系列文章目录 文章目录 系列文章目录一、文档介绍二、打开工程 CANDisturbanceMain三、模拟干扰3.1 CAN_H或CAN_L短接到地3.2 CAN_H和CAN_L短接3.3 CAN_H或CAN_L短接到电源3.4 CAN_H和CAN_L反接3.5 CAN_H和CAN_L之间的电阻/电容值调整一、文档介绍 本文档主要介绍如何使用CANo…

亚马逊测评自养号与机刷的区别

前言&#xff1a; 在亚马逊运营的领域中&#xff0c;经常有人问&#xff1a;测评自养号就是机刷吗&#xff1f;它们两者有什么区别&#xff1f;做自养号太慢、太需要时间了&#xff0c;如果用机刷的话&#xff0c;会不会简单高效一点&#xff1f; 在这篇文章中&#xff0c;我…

【字符串函数2】

5. strncpy 函数的使用和模拟实现 选择性拷贝 char * strncpy ( char * destination, const char * source, size_t num ); 1.拷贝num个字符从源字符串到目标空间。 2.如果源字符串的⻓度⼩于num&#xff0c;则拷⻉完源字符串之后&#xff0c;在⽬标的后边 追加0 &#…

STM32学习和实践笔记(35):内部温度传感器实验

1.STM32F1内部温度传感器介绍 1.1 STM32F1内部温度传感器简介 STM32F1内部含有一个温度传感器&#xff0c;可用来测量 &#xff08;STM32芯片的&#xff09;CPU 及周围的温度(TA)。&#xff08;实际并不用来测周围的温度&#xff0c;仅用来测试CPU的温度&#xff09; 此温度传…

05.VisionMaster 机器视觉 结果 格式化输出

VisionMaster 机器视觉 结果 格式化输出 格式化工具可以把数据整合并格式化成字符串输出&#xff0c;它既可以链接前面模块的结果输出&#xff0c;也可以直接在框内输入字符格&#xff0c;在进行通信输出前通常用格式化工具将数据进行整理&#xff0c; 如下图所示。 前面的文章…

网络标准架构--OSI七层、四层

OSI七层网络架构&#xff0c;以及实际使用的四层网络架构。