深度学习训练营N1周:Pytorch文本分类入门

news2025/1/23 2:22:34
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

NLP的功能:
在这里插入图片描述

本周使用AG News数据集进行文本分类。实现过程分为前期准备、代码实战、使用测试数据集评估模型和总结四个部分。前期准备中准备环境、加载数据集;代码实战中构建词典、生成数据批次和迭代器,定义模型;使用测试数据集评估模型,对训练好的模型进行测试。

常用函数:

.build_vocab_from_iterator()函数详解
在这里插入图片描述

torchtext.vocab.build_vocab_from_iterator 的作用是从一个可迭代对象中统计token的频次,并返回一个vocab(词汇字典)

上述是官网API接口的定义形式,参数有五个,返回值是Vocab类型实例,五个参数分别是:
●iterator:一个用于创建vocab(词汇字典)的可迭代对象。
●min_freq:最小频数。只有在文本中出现频率大于等于min_freq的token才会被保留下来
●specials:特殊标志,字符串列表。用于在词汇字典中添加一些特殊的token/标记,比如最常用的’',用于代表词汇字典中未存在的token,当然也可以用自己喜欢的符号来代替,具体的意义也取决于用的人。
●special_first:表示是否将specials放到字典的最前面,默认是True
●max_tokens:即限制一下这个词汇字典的最大长度。且这个长度包含的specials列表的长度

代码:


import torch
import torch.nn as nn
import torchvision
import os,PIL,pathlib,warnings
import time
from torchvision import transforms, datasets
from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch import nn
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset

warnings.filterwarnings("ignore")             #忽略警告信息
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
#数据集
train_iter = AG_NEWS(split='train')      # 加载 AG News 数据集
#构建词典
tokenizer  = get_tokenizer('basic_english') # 返回分词器函数
 
def yield_tokens(data_iter):
    for _, text in data_iter:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"]) # 设置默认索引,如果找不到单词,则会选择默认索引

print(vocab(['here', 'is', 'an', 'example']))

text_pipeline  = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: int(x) - 1
print(text_pipeline('here is the an example'))

print(label_pipeline('10'))

from torch.utils.data import DataLoader
 
def collate_batch(batch):
    label_list, text_list, offsets = [], [], [0]
    
    for (_label, _text) in batch:
        # 标签列表
        label_list.append(label_pipeline(_label))
        
        # 文本列表
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        
        # 偏移量,即语句的总词汇量
        offsets.append(processed_text.size(0))
        
    label_list = torch.tensor(label_list, dtype=torch.int64)
    text_list  = torch.cat(text_list)
    offsets    = torch.tensor(offsets[:-1]).cumsum(dim=0) #返回维度dim中输入元素的累计和
    
    return label_list.to(device), text_list.to(device), offsets.to(device)

dataloader = DataLoader(train_iter,
                        batch_size=8,
                        shuffle   =False,
                        collate_fn=collate_batch)

 
class TextClassificationModel(nn.Module):
 
    def __init__(self, vocab_size, embed_dim, num_class):
        super(TextClassificationModel, self).__init__()
        
        self.embedding = nn.EmbeddingBag(vocab_size,   # 词典大小
                                         embed_dim,    # 嵌入的维度
                                         sparse=False) # 
        
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()
 
    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
 
    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

num_class  = len(set([label for (label, text) in train_iter]))
vocab_size = len(vocab)
em_size = 64
model = TextClassificationModel(vocab_size, em_size, num_class).to(device)


 
def train(dataloader):
    model.train()  # 切换为训练模式
    total_acc, train_loss, total_count = 0, 0, 0
    log_interval = 500
    start_time   = time.time()
 
    for idx, (label, text, offsets) in enumerate(dataloader):
        
        predicted_label = model(text, offsets)
        
        optimizer.zero_grad()                    # grad属性归零
        loss = criterion(predicted_label, label) # 计算网络输出和真实值之间的差距,label为真实值
        loss.backward()                          # 反向传播
        optimizer.step()  # 每一步自动更新
        
        # 记录acc与loss
        total_acc   += (predicted_label.argmax(1) == label).sum().item()
        train_loss  += loss.item()
        total_count += label.size(0)
        
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('| epoch {:1d} | {:4d}/{:4d} batches '
                  '| train_acc {:4.3f} train_loss {:4.5f}'.format(epoch, idx, len(dataloader),
                                              total_acc/total_count, train_loss/total_count))
            total_acc, train_loss, total_count = 0, 0, 0
            start_time = time.time()
 
def evaluate(dataloader):
    model.eval()  # 切换为测试模式
    total_acc, train_loss, total_count = 0, 0, 0
 
    with torch.no_grad():
        for idx, (label, text, offsets) in enumerate(dataloader):
            predicted_label = model(text, offsets)
            
            loss = criterion(predicted_label, label)  # 计算loss值
            # 记录测试数据
            total_acc   += (predicted_label.argmax(1) == label).sum().item()
            train_loss  += loss.item()
            total_count += label.size(0)
            
    return total_acc/total_count, train_loss/total_count


EPOCHS     = 10 # epoch
LR         = 5  # 学习率
BATCH_SIZE = 64 # batch size for training
 
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None
 
train_iter, test_iter = AG_NEWS() # 加载数据
train_dataset = to_map_style_dataset(train_iter)
test_dataset  = to_map_style_dataset(test_iter)
num_train     = int(len(train_dataset) * 0.95)
 
split_train_, split_valid_ = random_split(train_dataset,
                                          [num_train, len(train_dataset)-num_train])
 
train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)
test_dataloader  = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)
 
for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader)
    val_acc, val_loss = evaluate(valid_dataloader)
    
    if total_accu is not None and total_accu > val_acc:
        scheduler.step()
    else:
        total_accu = val_acc
    print('| epoch {:1d} | time: {:4.2f}s | '
          'valid_acc {:4.3f} valid_loss {:4.3f}'.format(epoch,
                                           time.time() - epoch_start_time,
                                           val_acc,val_loss))
 



# 评估模型
test_acc, test_loss = evaluate(test_dataloader)
print('test accuracy {:8.3f}'.format(test_acc))





本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/600509.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

chatgpt赋能python:Python冒泡排序详解

Python冒泡排序详解 介绍 Python是一门强大的编程语言&#xff0c;它在数据科学、机器学习、Web开发等领域都有广泛的应用。其中&#xff0c;排序算法是编程中一个重要的话题&#xff0c;冒泡排序也是最基本的排序算法之一。本文将详解Python冒泡排序的实现方法和优化技巧&am…

chatgpt赋能python:利用Python编写模拟器:一种循序渐进的方法

利用Python编写模拟器&#xff1a;一种循序渐进的方法 模拟器是一种用于模拟计算机硬件或软件的程序。它模拟了真实设备的功能&#xff0c;可以帮助开发人员进行测试和调试&#xff0c;以及提供一种环境来设计和验证新的算法和协议。Python是一种广泛使用的编程语言&#xff0…

计讯物联宝贝王手工大赛投票结果正式揭晓,速速围观!

在孩子的想象世界中&#xff0c; 生活中的可爱 可以是专属六一的蛋糕&#xff0c; 可以是创意手绘手摇扇&#xff0c; 可以是萌萌可爱的花束&#xff0c; 可以是未来超智能机器人&#xff0c; 可以是无人航天器模型…… 他们的想象&#xff0c; 是尚未被世俗沾染的赤忱之…

【i阿极送书——第三期】《Hadoop大数据技术基础与应用》

系列文章目录 作者&#xff1a;i阿极 作者简介&#xff1a;Python领域新星作者、多项比赛获奖者&#xff1a;博主个人首页 &#x1f60a;&#x1f60a;&#x1f60a;如果觉得文章不错或能帮助到你学习&#xff0c;可以点赞&#x1f44d;收藏&#x1f4c1;评论&#x1f4d2;关注…

病毒分析丨一款注入病毒

作者丨黑蛋 一、病毒简介 SHA256: de2a83f256ef821a5e9a806254bf77e4508eb5137c70ee55ec94695029f80e45 MD5: 6e4b0a001c493f0fcf8c5e9020958f38 SHA1: bea213f1c932455aee8ff6fde346b1d1960d57ff 云沙箱检测&#xff1a; 二、环境准备 系统 Win7x86Sp1 三、行为监控 打开…

基于GD32开发板的GPS定位模块的使用操作

基于上一章的介绍&#xff0c;本章将介绍如何基于gd32开发板使用gps定位模块。 一、官方代码分析 正点原子的官方测试例程&#xff0c;测试代码的逻辑还是比较简单的&#xff0c;主要就是先调用函数atk_mo1218_init()进行初始化&#xff0c;接着就调用 SkyTraq binary 协议的 A…

mac host学习

参考&#xff1a; SSH中known_hosts文件作用和常见问题及解决方法 https://blog.csdn.net/luduoyuan/article/details/130070120在 Mac 上更改 DNS 设置 https://support.apple.com/zh-cn/guide/mac-help/mh14127/mac mac中有时候你输入的域名&#xff0c;但会跳转到与期望ip不…

Arduino UNO用L9110 电机驱动模块驱动两个直流电机

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 一、简介二、工作原理三、使用前准备四、测试方法五、实验现象 一、简介 L9110电机驱动模块采用推挽式功率放大&#xff0c;设有固定安装孔&#xff0c;适合组装&a…

Linux常用命令——grub命令

在线Linux命令查询工具 grub 多重引导程序grub的命令行shell工具 补充说明 grub命令是多重引导程序grub的命令行shell工具。 语法 grub(选项)选项 --batch&#xff1a;打开批处理模式&#xff1b; --boot-drive<驱动器>&#xff1a;指定stage2的引导驱动器&#x…

霸榜第一框架:工业检测,基于差异和共性的半监督方法用于图像表面缺陷检测...

关注并星标 从此不迷路 计算机视觉研究院 公众号ID&#xff5c;ComputerVisionGzq 学习群&#xff5c;扫码在主页获取加入方式 论文地址&#xff1a;https://arxiv.org/ftp/arxiv/papers/2205/2205.00908.pdf 链接: https://pan.baidu.com/s/1ar2BN1p2jJ-cZx1J5dGRLg 密码: 2l…

Learn From Microsoft Build Ⅲ:低代码

点击蓝字 关注我们 编辑&#xff1a;Alan Wang 排版&#xff1a;Rani Sun 微软 Reactor 为帮助广开发者&#xff0c;技术爱好者&#xff0c;更好的学习 .NET Core, C#, Python&#xff0c;数据科学&#xff0c;机器学习&#xff0c;AI&#xff0c;区块链, IoT 等技术&#xff0…

使用神经网络合成数据生成技术实现电力系统无人机自动巡检

使用神经网络合成数据生成技术实现电力系统无人机自动巡检 美国能源公司 Exelon 正在利用神经网络合成数据生成技术&#xff0c;为电力系统无人机自动巡检项目提供支持。这一技术有助于提高巡检效率和准确性&#xff0c;降低人力和时间成本。 1. 电力系统巡检的挑战 电力系统…

基于知识图谱表示学习的谣言早期检测方法

源自&#xff1a;电子学报 作者&#xff1a;皮德常 吴致远 曹建军 摘 要 社交网络谣言是严重危害社会安全的一个重要问题.目前的谣言检测方法基本上都依赖用户评论数据.为了获取可供模型训练的足量评论数据&#xff0c;需要任由谣言在社交平台上传播一段时间&#xff0c;这…

手机安卓Termux安装MySQL数据库【公网远程数据库】

文章目录 前言1.安装MariaDB2.安装cpolar内网穿透工具3. 创建安全隧道映射mysql4. 公网远程连接5. 固定远程连接地址 转载自cpolar极点云的文章&#xff1a;Android Termux安装MySQL数据库 | 公网安全远程连接【Cpolar内网穿透】 前言 Android作为移动设备&#xff0c;尽管最初…

Android 和 ktor 的 HTTP 块请求

Android 和 ktor 的 HTTP 块请求 在这篇非常短的文章中&#xff0c;我将简要解释什么是块或流式 HTTP 请求&#xff0c;使用它有什么好处&#xff0c;以及它在 Android 中的工作原理。 Android 应用程序使用 HTTP 请求从后端下载数据。此信息在应用程序上存储和处理以使其正常…

计算机内存取证之BitLocker恢复密钥提取还原

BitLocker是微软Windows自带的用于加密磁盘分卷的技术。 通常&#xff0c;解开后的加密卷通过Windows自带的命令工具“manage-bde”可以查看其恢复密钥串&#xff0c;如下图所示&#xff1a; 如图&#xff0c;这里的数字密码下面的一长串字符串即是下面要提取恢复密钥。 在计…

chatgpt赋能python:Python编程教程之抽签程序

Python编程教程之抽签程序 介绍 对于喜欢玩抽签、体育彩票等游戏的人来说&#xff0c;抽签程序是一款非常有用的小工具。抽签程序可以用来随机抽取一定数量的幸运儿&#xff0c;而且运行速度快&#xff0c;结果随机性高&#xff0c;不需要人工干预。 那么&#xff0c;Python…

Spring Boot 3.1 中如何整合Spring Security和Keycloak

在今年2月14日的时候&#xff0c;Keycloak 团队宣布他们正在弃用大多数 Keycloak 适配器。其中包括Spring Security和Spring Boot的适配器&#xff0c;这意味着今后Keycloak团队将不再提供针对Spring Security和Spring Boot的集成方案。 但是&#xff0c;如此强大的Keycloak&am…

一文搞懂Android动画

这里写目录标题 前言一、视图动画1. 补间动画---Animation抽象类动画1.1 AlphaAnimation&#xff1a;控制一个对象透明度的动画。1.1.1 xml实现示例1.1.2 java实现示例 1.2 RotateAnimation&#xff1a;控制一个对象旋转的动画。1.1.1 xml实现示例1.1.2 java实现示例 1.3 Scale…

Linux中查看端口被哪个进程占用、进程调用的配置文件、目录等

1.查看被占用的端口的进程&#xff0c;netstat/ss -antulp | grep :端口号 2.通过上面的命令就可以列出&#xff0c;这个端口被哪些应用程序所占用&#xff0c;然后找到对应的进程PID 3.根据PID查询进程。如果想详细查看这个进程&#xff0c;PID具体是哪一个进程&#xff0c;可…