【Pytorch】学习记录分享9——PyTorch新闻数据集文本分类任务实战

news2025/1/19 17:00:36

【Pytorch】学习记录分享9——PyTorch新闻数据集文本分类任务

      • 1. 认为主流程code
      • 2. NLP 对话和预测基本均属于分类任务详细见
      • 3. Tensorborad

1. 认为主流程code

import time
import torch
import numpy as np
from train_eval import train, init_network
from importlib import import_module
import argparse
from tensorboardX import SummaryWriter

###制定参数 --model TextRNN
parser = argparse.ArgumentParser(description='Chinese Text Classification')
parser.add_argument('--model', type=str, required=True, help='choose a model: TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer')
parser.add_argument('--embedding', default='pre_trained', type=str, help='random or pre_trained')
parser.add_argument('--word', default=False, type=bool, help='True for word, False for char')
args = parser.parse_args()


if __name__ == '__main__':
    dataset = 'THUCNews'  # 数据集

    # 搜狗新闻:embedding_SougouNews.npz, 腾讯:embedding_Tencent.npz, 随机初始化:random
    embedding = 'embedding_SougouNews.npz'
    if args.embedding == 'random':
        embedding = 'random'
    model_name = args.model  #TextCNN, TextRNN,
    if model_name == 'FastText':
        from utils_fasttext import build_dataset, build_iterator, get_time_dif
        embedding = 'random'
    else:
        from utils import build_dataset, build_iterator, get_time_dif

    x = import_module('models.' + model_name)
    config = x.Config(dataset, embedding)
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    torch.backends.cudnn.deterministic = True  # 保证每次结果一样

    start_time = time.time()
    print("Loading data...")
    vocab, train_data, dev_data, test_data = build_dataset(config, args.word)
    train_iter = build_iterator(train_data, config)
    dev_iter = build_iterator(dev_data, config)
    test_iter = build_iterator(test_data, config)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)

    # train
    config.n_vocab = len(vocab)
    model = x.Model(config).to(config.device)
    writer = SummaryWriter(log_dir=config.log_path + '/' + time.strftime('%m-%d_%H.%M', time.localtime()))
    if model_name != 'Transformer':
        init_network(model)
    print(model.parameters)
    train(config, model, train_iter, dev_iter, test_iter,writer)

RNN


class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        if config.embedding_pretrained is not None:
            self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False)
        else:
            self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)
        self.lstm = nn.LSTM(config.embed, config.hidden_size, config.num_layers,
                            bidirectional=True, batch_first=True, dropout=config.dropout)
        self.fc = nn.Linear(config.hidden_size * 2, config.num_classes)

    def forward(self, x):
        x, _ = x
        out = self.embedding(x)  # [batch_size, seq_len, embeding]=[128, 32, 300]
        out, _ = self.lstm(out)
        out = self.fc(out[:, -1, :])  # 句子最后时刻的 hidden state
        return out

在这里插入图片描述
TextRNN h_t 为RNN提取出来的特征

2. NLP 对话和预测基本均属于分类任务详细见

Pytorch学习记录分享9-PyTorch新闻数据集文本分类任务实战

3. Tensorborad

数据可视化操作 code repo

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1339996.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Neural Networks 期刊投稿指南

一 简介 这是国际神经网络学会、欧洲神经网络学会和日本神经网络学会的官方期刊。 论文类型 文章: 原创的、全文长度的文章将被考虑,前提是它们除了摘要形式外尚未发表,并且没有同时在其他地方进行审查。作者可以自愿但不是必须建议一位编辑…

<JavaEE> TCP 的通信机制(二) -- 连接管理(三次握手和四次挥手)

目录 TCP的通信机制的核心特性 三、连接管理 1)什么是连接管理? 2)“三次握手”建立连接 1> 什么是“三次握手”? 2> “三次握手”的核心作用是什么? 3)“四次挥手”断开连接 1> 什么是“…

NineData产品功能重点发布(11月下+12月上)

12 月上半月 1.1 SQL 任务支持 MongoDB 介绍:SQL 任务功能已支持 MongoDB 数据源,可以通过 SQL 任务发起对 MongoDB 的变更申请,支持立即执行或定时执行。 场景: 安全变更:需要对企业成员提交的数据变更进行预审的场…

「Kafka」入门篇

「Kafka」入门篇 基础架构 Kafka 快速入门 集群规划 集群部署 官方下载地址:http://kafka.apache.org/downloads.html 解压安装包: [atguiguhadoop102 software]$ tar -zxvf kafka_2.12-3.0.0.tgz -C /opt/module/修改解压后的文件名称: [a…

哪个品牌的运动耳机比较好?蓝牙无线运动耳机推荐

​在运动时,一副合适的耳机能够让你的运动体验提升到一个新的层次。运动耳机需要具备耐用性、稳定性和优秀的音质,以适应各种运动场景。考虑到这些要求,我将为大家推荐几款在运动场景中表现优异的耳机,它们将是你运动时的理想伴侣…

PowerShell对象——数据的另一个名称

PowerShell对象—数据的另一个名称 实验 要求:需要运行PowerShell v3 或更新版本PowerShell的计算机 任务: 找出生成随机数字的Cmdlet 找出显示当前时间和日期的Cmdlet 任务#2的Cmdlet产生的对象类型是什么?(由Cmdlet产生的对…

【Linux基础开发工具】Linux调试器-gdb

目录 前言 1. 背景 2. 基本使用 总结 前言 GDB(GNU Debugger)是一个功能强大的开源调试器,它用于调试C、C等程序,在Linux环境下软件开发的过程中,调试是一个至关重要的环节。无论是在开发新的软件还是维护现有的代…

linux cuda环境搭建

1,检查驱动是否安装 运行nvidia-smi,如果出现如下界面,说明驱动已经安装 记住cuda版本号 2,安装cudatoolkit 上官网CUDA Toolkit Archive | NVIDIA Developer 根据操作系统选择对应的toolkit 如果已经安装了驱动,选…

Visual Studio 2013 中创建一个基于 Qt 的动态链接库:并在MFC DLL程序中使用

在本地已经安装好 Qt 的情况下,按照以下步骤在 Visual Studio 2013 中创建一个基于 Qt 的动态链接库: 一、新建 Qt 项目: 在 Visual Studio 中,选择 “文件” -> “新建” -> “项目…”。在 “新建项目” 对话框中&#…

性能手机新标杆,一加 Ace 3 发布会定档 1 月 4 日

12 月 27 日,一加宣布将于 1 月 4 日发布新品一加 Ace 3。一加 Ace 系列秉持「产品力优先」理念,从一加 Ace 2、一加 Ace 2V 到一加 Ace 2 Pro,款款都是现象级爆品,得到了广大用户的认可与支持。作为一加 2024 开年之作&#xff0…

立体匹配算法(Stereo correspondence)SGM

SGM(Semi-Global Matching)原理: SGM的原理在wiki百科和matlab官网上有比较详细的解释: wiki matlab 如果想完全了解原理还是建议看原论文 paper(我就不看了,懒癌犯了。) 优质论文解读和代码实现 一位大神自己用c实现…

关于“Python”的核心知识点整理大全44

目录 ​编辑 15.3.4 模拟多次随机漫步 rw_visual.py 注意 15.3.5 设置随机漫步图的样式 15.3.6 给点着色 rw_visual.py 15.3.7 重新绘制起点和终点 rw_visual.py 15.3.8 隐藏坐标轴 rw_visual.py 15.3.9 增加点数 rw_visual.py 15.3.10 调整尺寸以适合屏幕 rw_vi…

介绍几种mfc140u.dll丢失的解决方法,找不到msvcp140.dll要怎么处理

如果你在使用电脑时遇到mfc140u.dll丢失错误时,这可能会导致程序无法正常运行,但是大家不必过于担心。今天的这篇文章本将为你介绍几种mfc140u.dll丢失的解决方法,找不到msvcp140.dll要怎么处理的一些解决方法。 一.mfc140u.dll文件缺失会有什…

Docker自建私人云盘系统

Docker自建私人云盘系统。 有个人云盘需求的人,主要需求有这几类: 文件同步、分享需要。 照片、视频同步需要,尤其是全家人都是用的同步。 影视观看需要(分为家庭内部、家庭外部) 搭建个人网站/博客 云端OFFICE需…

【超图】SuperMap iClient3D for WebGL/WebGPU —— 数据集合并缓存如何控制对象样式

作者:taco 最近在支持的过程中,遇到了一个新问题!之前研究功能的时候竟然没有想到。通常我们控制单个对象的显隐、颜色、偏移的参数都是根据对象所在的图层以及对象单独的id来算的。那么问题来了,合并后的图层。他怎么控制单个对象…

PointNet人工智能深度学习简明图解

PointNet 是一种深度网络架构,它使用点云来实现从对象分类、零件分割到场景语义解析等应用。 它于 2017 年实现,是第一个直接将点云作为 3D 识别任务输入的架构。 本文的想法是使用 Pytorch 实现 PointNet 的分类模型,并可视化其转换以了解模…

【开源】基于JAVA的智能教学资源库系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 数据中心模块2.2 课程档案模块2.3 课程资源模块2.4 课程作业模块2.5 课程评价模块 三、系统设计3.1 用例设计3.2 数据库设计3.2.1 课程档案表3.2.2 课程资源表3.2.3 课程作业表3.2.4 课程评价表 四、系统展示五、核心代…

账号租号平台PHP源码,支持单独租用或合租使用

源码简介 租号平台源码,采用常见的租号模式。 平台的主要功能如下: 支持单独租用或采用合租模式; 采用易支付通用接口进行支付; 添加邀请返利功能,以便站长更好地推广; 提供用户提现功能;…

人工智能 机器学习 深度学习:概念,关系,及区别说明

如果过去几年,您读过科技主题的文章,您可能会遇到一些新词汇,如人工智能(Artificial Intelligence)、机器学习(Machine Learning)和深度学习(Deep Learning)等。这三个词…

【Linux】 last 命令使用

last 命令 用于检索和展示系统中用户的登录信息。它从/var/log/wtmp文件中读取记录,并将登录信息按时间顺序列出。 著者 Miquel van Smoorenburg 语法 last [-R] [-num] [ -n num ] [-adiox] [ -f file ] [name...] [tty...]last 命令 -Linux手册页 选项及作用…