N4中文分类

news2024/11/16 23:40:23
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊# 前言

前言

上周学习了英文文本分类,这次进行中文分类实战。

1. 数据读取

import pandas as pd
train_data = pd.read_csv('train.csv',sep='\t',header=None)
print(train_data.head())

读取包含训练数据的 CSV 文件。数据被存储在 train_data 中,并输出前几行进行检查。

2. 自定义数据迭代器

def coustom_data_iter(texts, labels):
    for x, y in zip(texts, labels):
        yield x, y

train_iter = coustom_data_iter(train_data[0].values[:], train_data[1].values[:])

定义了一个生成器函数 coustom_data_iter,用于迭代文本和标签。

3. 词汇表构建与文本处理

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import jieba

tokenizer = jieba.lcut

def yield_tokens(data_iter):
    for text, _ in data_iter:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

使用 Jieba 分词器对文本进行分词,然后从分词结果中构建词汇表 vocab<unk> 表示未知词。

4. 数据加载与批处理

from torch.utils.data import DataLoader

def collate_batch(batch):
    label_list, text_list, offsets = [], [], [0]

    for (_text, _label) in batch:
        label_list.append(label_pipeline(_label))

        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)

        offsets.append(processed_text.size(0))

    label_list = torch.tensor(label_list, dtype=torch.int64)
    text_list = torch.cat(text_list)
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)

    return text_list.to(device), label_list.to(device), offsets.to(device)

dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)

定义了一个批处理函数 collate_batch,用于将文本和标签转换成适合模型输入的格式,并创建数据加载器 dataloader

5. 模型定义

class TextClassificationModel(nn.Module):

    def __init__(self, vocab_size, embed_dim, num_class):
        super(TextClassificationModel, self).__init__()

        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

一个简单的文本分类模型,包含嵌入层 nn.EmbeddingBag 和全连接层 nn.Linear

6. 训练与评估函数

def train(dataloader):
    model.train()
    total_acc, train_loss, total_count = 0, 0, 0
    log_interval = 50
    start_time = time.time()

    for idx, (text, label, offsets) in enumerate(dataloader):

        predicted_label = model(text, offsets)

        optimizer.zero_grad()
        loss = criterion(predicted_label, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()

        total_acc += (predicted_label.argmax(1) == label).sum().item()
        train_loss += loss.item()
        total_count += label.size(0)

        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('| epoch {:1d}|{:4d}/{:4d} batches'
                  '| train_acc {:4.3f} train_loss{:4.5f}'.format(epoch, idx, len(dataloader),
                                                                 total_acc / total_count, train_loss / total_count))
            total_acc, train_loss, total_count = 0, 0, 0
            start_time = time.time()

def evaluate(dataloader):
    model.eval()
    total_acc, train_loss, total_count = 0, 0, 0

    with torch.no_grad():
        for idx, (text, label, offsets) in enumerate(dataloader):
            predicted_label = model(text, offsets)

            loss = criterion(predicted_label, label)

            total_acc += (predicted_label.argmax(1) == label).sum().item()
            train_loss += loss.item()
            total_count += label.size(0)

    return total_acc / total_count, train_loss / total_count

这部分代码定义了训练函数 train 和评估函数 evaluate。训练函数执行模型的前向和后向传播以及参数更新,而评估函数用于在验证数据集上评估模型性能。

7. 数据集划分与训练

from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset

EPOCHS = 10
LR = 5
BATCH_SIZE = 64

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None

train_iter = coustom_data_iter(train_data[0].values[:], train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)

split_train_, split_valid_ = random_split(train_dataset,
                                          [int(len(train_dataset) * 0.8), int(len(train_dataset) * 0.2)])

train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)
for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader)
    val_acc, val_loss = evaluate(valid_dataloader)

    lr = optimizer.state_dict()['param_groups'][0]['lr']

    if total_accu is not None and total_accu > val_acc:
        scheduler.step()
    else:
        total_accu = val_acc
    print('-' * 69)
    print('|epoch {:1d} |time: {:4.2f}s | '
          'valid_acc{:4.3f} valid_loss{:4.3f} | lr{:4.6f}'.format(epoch,
                                                                  time.time() - epoch_start_time,
                                                                  val_acc, val_loss, lr))
    print('-' * 69)

将数据集分为训练集和验证集,并创建对应的数据加载器。然后进行模型训练和验证,使用调度器根据验证集的表现调整学习率。

8. 模型预测

def predict(text, text_pipeline):
    with torch.no_grad():
        text = torch.tensor(text_pipeline(text))
        output = model(text, torch.tensor([0]))
        return output.argmax(1).item()

ex_text_str = "还有双鸭山到淮阴的汽车票吗13号的"

model = model.to("cpu")
print("该文本的类别是:%s" % label_name[predict(ex_text_str, text_pipeline)])

在这里插入图片描述

这部分代码定义了一个预测函数 predict,用于对输入文本进行分类预测。最后,示例文本被传入预测函数,并输出预测的类别。

结果

在这里插入图片描述

总结

整个项目实现了一个文本分类任务,流程包括数据读取、分词和词汇表构建、数据加载与批处理、模型定义、训练与验证、最终精度达89%。这周熟悉文字信息的处理流程,为以后学习做了好的铺垫。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1839401.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

qt.qpa.xcb: could not connect to display问题解决

1、问题描述 以服务器pi5作为远程解释器&#xff0c;本地win11使用vscode远程调试视觉时报错如下&#xff1a; qt.qpa.xcb: could not connect to display qt.qpa.plugin: Could not load the Qt platform plugin "xcb" in "xxxxx" even though it was …

室内灰尘对老人小孩危害不容忽视,资深家政推荐除灰尘空气净化器

正所谓“病从口入&#xff0c;尘从窗入”&#xff0c;室内灰尘问题不容小觑。尤其是对老人和小孩来说&#xff0c;灰尘中的有害物质更是威胁健康的重要因素。近期天气炎热&#xff0c;家家户户每天都会开窗通风&#xff0c;然而这也带来了灰尘和毛絮的问题。即使每天打扫&#…

java 线程之间通信-volatile 和 synchronized

你好&#xff0c;我是 shengjk1&#xff0c;多年大厂经验&#xff0c;努力构建 通俗易懂的、好玩的编程语言教程。 欢迎关注&#xff01;你会有如下收益&#xff1a; 了解大厂经验拥有和大厂相匹配的技术等 希望看什么&#xff0c;评论或者私信告诉我&#xff01; 文章目录 一…

LabVIEW Windows与RT系统的比较与选择

LabVIEW是一种系统设计和开发环境&#xff0c;广泛应用于各类工程和科学应用中。LabVIEW Windows和LabVIEW RT&#xff08;Real-Time&#xff09;是LabVIEW的两个主要版本&#xff0c;分别适用于不同的应用场景。以下从多个角度详细分析两者的区别&#xff0c;并提供选择建议。…

国际期货行情相关术语

1&#xff09;合约&#xff1a;期货行情表提供了期货交易的相关信息 &#xff0c;行情表中每一个期货合约都有合约代码&#xff08;由期货合约交易代码和合约到期月份组成&#xff09;来标识。 &#xff08;2&#xff09;开盘价&#xff1a;当日某一期货合约交易开始前五分钟集…

开发者配置项、开发者选项自定义

devOptions.vue源码 <!-- 开发者选项 &#xff08;CtrlAltShiftD&#xff09;--> <template><div :class"$options.name" v-if"visible"><el-dialog:custom-class"sg-el-dialog":append-to-body"true":close-on…

安装pytorch环境

安装&#xff1a;Anaconda3 通过命令行查显卡nvidia-smi 打开Anacanda prompt 新建 conda create -n pytorch python3.6 在Previous PyTorch Versions | PyTorch选择1.70&#xff0c;安装成功&#xff0c;但torch.cuda.is_available 返回false conda install pytorch1.7.0…

银行数仓项目实战(四)--了解银行业务(存款)

文章目录 项目准备存款活期定期整存整取零存整取存本取息教育储蓄定活两便通知存款 对公存款对公账户协议存款 利率 项目准备 &#xff08;贴源层不必写到项目文档&#xff0c;因为没啥操作没啥技术&#xff0c;只是数据。&#xff09; 可以看到&#xff0c;银行的贴源层并不紧…

keil MDK自动生成带版本bin文件

作为嵌入式单片机开发&#xff0c;在Keil MDK&#xff08;Microcontroller Development Kit&#xff09;中开发完编译完后&#xff0c;经常需要手动进行版本号添加用于发版&#xff0c;非常麻烦&#xff0c;如果是对外发行的话&#xff0c;更容易搞错&#xff0c;特此码哥提供一…

基于若依的ruoyi-nbcio流程管理系统增加所有任务功能(一)

更多ruoyi-nbcio功能请看演示系统 gitee源代码地址 前后端代码&#xff1a; https://gitee.com/nbacheng/ruoyi-nbcio 演示地址&#xff1a;RuoYi-Nbcio后台管理系统 http://218.75.87.38:9666/ 更多nbcio-boot功能请看演示系统 gitee源代码地址 后端代码&#xff1a; h…

最好用的智能猫砂盆存在吗?自用分享智能猫砂盆测评!

在现代都市的忙碌生活中&#xff0c;作为一名上班族&#xff0c;经常因为需要加班或频繁出差而忙碌得不可开交。急匆匆地出门&#xff0c;却忘了给猫咪及时铲屎。但是大家要知道&#xff0c;不及时清理猫砂盆会让猫咪感到不适&#xff0c;还会引发各种健康问题&#xff0c;如泌…

无问芯穹Qllm-Eval:制作多模型、多参数、多维度的量化方案

前言 近年来&#xff0c;大语言模型&#xff08;Large Models, LLMs&#xff09;受到学术界和工业界的广泛关注&#xff0c;得益于其在各种语言生成任务上的出色表现&#xff0c;大语言模型推动了各种人工智能应用&#xff08;例如ChatGPT、Copilot等&#xff09;的发展。然而…

2021 hnust 湖科大 C语言课程设计报告+代码+流程图源文件+指导书

2021 hnust 湖科大 C语言课程设计报告代码流程图源文件指导书 目录 报告 下载链接 https://pan.baidu.com/s/14NFsDbT3iS-a-_7l0N5Ulg?pwd1111

D111FCE01LC2NB70带流量调节派克比例阀

D111FCE01LC2NB70带流量调节派克比例阀 派克比例阀&#xff1a;由于采用&#xff08;秉圣135陈工6653询3053&#xff09;电液混合控制技术&#xff0c;响应速度更快、精度更高、控制更平稳。同时&#xff0c;由于采用高质量的材料制造&#xff0c;具有较高的承压能力和抗磨损性…

生活实用口语柯桥成人外语培训机构“客服”用英文怎么说?

● 01. “客服”英语怎么说&#xff1f; ● 我们都知道“客服”就是“客户服务”&#xff0c; 所以Customer Service就是#15857575376客服的意思。 但是这里的“客服”指代的不是客服人员&#xff0c; 而是一种Service服务。 如果你想要表达客服人员可以加上具体的职位&a…

面向对象前置(java)

文章目录 环境配置相关如何在 cmd 任何路径下中 使用 javac(编译) 和 java(运行) 指令path变量的含义javac(编译&#xff09;使用java(运行&#xff09;的使用public class 和 class 的区别 标识符命名规则命名规范 字面量字符串拼接变量的作用域转移字符类型转换接收用户键盘输…

Adobe XD是否收费?试试这几款超值的免费软件吧!

Adobe XD是一站式的 UX/UI 设计平台&#xff0c;设计师可以使用Adobe XD完成移动应用app界面设计、网页设计、原型设计等。Adobe XD也是一款结合原型和设计&#xff0c;提供工业性能的跨平台设计产品。而Adobebe。 XD跨平台的特点得到了很好的弥补 Sketch 没有 Windows 版本的缺…

【Java基础】 线程状态转化

Java 中的线程状态转换是指线程在其生命周期中可以经历的不同状态以及这些状态之间的转换。了解线程的状态转换对于有效地管理和调试多线程应用程序非常重要。Java 提供了 Thread.State 枚举来描述线程的状态。 状态 NEW&#xff08;新建&#xff09;&#xff1a; 线程被创建&…

STemWin在Windows上仿真运行环境配置

文章目录 一、STemWin介绍二、emWin必用的2个工具2.1 PC仿真器2.2 GUIBuilder图形化设计工具三、安装VS2022四、打开emwin仿真工程五、常见的配置修改5.1 运行内存修改5.2 LCD显示屏尺寸的修改一、STemWin介绍 emWin(Embedded Wizard Graphics Library)是Segger公司开发的嵌…

深入二进制安全:全面解析Protobuf

前言 近两年&#xff0c;Protobuf结构体与Pwn结合的题目越来越多。 23年和24年Ciscn都出现了Protobuf题目&#xff0c;24年甚至还出现了2道。 与常规的Pwn题利用相比&#xff0c;只是多套了一层Protobuf的Unpack操作。 本文包含Protobuf环境安装、相关语法、编译运行以及pb…