RNN实战

news2024/11/17 17:40:00

本主要是利用RNN做多分类任务,在熟悉RNN训练的过程中,我们可以理解
1)超参数 batch_size和pad_size对训练过程的影响。
2)文本处理过程中是如何将文本的文字表示转化为向量表示
3)RNN梯度消失和序列长度的关系
4)利用pytorch如何训练一个网络模型以及保存和加载
5)理解多分类任务中的混淆矩阵

数据集HUCNews中抽取了20万条新闻标题,文本长度在20到30之间。一共10个类别,每类2万条。
类别:财经、房产、股票、教育、科技、社会、时政、体育、游戏、娱乐。

数据集划分

数据集数据量
训练集18万
验证集1万
测试集1万

重要参数如下

self.dropout = 0.3  # 随机失活
self.num_epochs = 7  # epoch数
self.batch_size = 256  # batch size
self.pad_size = 7  # 每句话处理成的长度(短填长切)
self.learning_rate = 1e-3  # 学习率
self.hidden_size = 128  # rnn隐藏层
self.num_layers = 2  # rnn层数,注意RNN中的层数必须大于1,dropout才会生效

RNN.py 模型文件,主要是配置文件和RNN网络模型定义。

# coding: UTF-8
import torch
import torch.nn as nn
import numpy as np


class Config(object):
    """配置参数"""

    def __init__(self, dataset, embedding):
        self.model_name = 'RNN'
        self.train_path = dataset + '/data/train.txt'  # 训练集
        self.dev_path = dataset + '/data/dev.txt'  # 验证集
        self.test_path = dataset + '/data/test.txt'  # 测试集
        self.class_list = [x.strip() for x in open(
            dataset + '/data/class.txt', encoding='utf-8').readlines()]  # 类别名单
        self.vocab_path = dataset + '/data/vocab.pkl'  # 词表
        self.save_path = dataset + '/saved_dict/' + self.model_name + 'ckpt'  # 模型训练结果
        self.log_path = dataset + '/log/' + self.model_name
        self.embedding_pretrained = torch.tensor(
            np.load(dataset + '/data/' + embedding)["embeddings"].astype('float32')) \
            if embedding != 'random' else None  # 预训练词向量
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # 设备

        self.dropout = 0.3  # 随机失活
        self.require_improvement = 10000  # 若超过10000batch效果还没提升,则提前结束训练
        self.num_classes = len(self.class_list)  # 类别数
        self.n_vocab = 0  # 词表大小,在运行时赋值
        self.num_epochs = 7  # epoch数
        self.batch_size = 256  # batch size
        self.pad_size = 7  # 每句话处理成的长度(短填长切)
        self.learning_rate = 1e-3  # 学习率
        self.embed = self.embedding_pretrained.size(1) \
            if self.embedding_pretrained is not None else 300  # 字向量维度, 若使用了预训练词向量,则维度统一
        self.hidden_size = 128  # rnn隐藏层
        self.num_layers = 2  # rnn层数,注意RNN中的层数必须大于1,dropout才会生效


class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        if config.embedding_pretrained is not None:
            self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False)
        else:
            self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)
        self.rnn = nn.RNN(config.embed, config.hidden_size, config.num_layers,
                          batch_first=True, dropout=config.dropout)
        self.fc = nn.Linear(config.hidden_size, config.num_classes)

    def forward(self, x):
        # 将原始数据转化成密集向量表示 [batch_size, seq_len, embedding]
        out = self.embedding(x[0])
        out, hidden_ = self.rnn(out)
        # out[:, -1, :] seq_len最后时刻的输出等价 hidden_
        out = self.fc(out[:, -1, :])
        return out

run_rnn.py文件,主程序入口,指定运行参数以及文本加载过程,最后调用train_eval.py的train函数进行模型训练。

import time
import torch
import numpy as np
from train_eval import train, init_network
from importlib import import_module
import argparse
from utils import build_dataset, build_iterator, get_time_dif

parser = argparse.ArgumentParser(description='Chinese Text Classification')
parser.add_argument('--model', default='RNN', type=str, required=True)
parser.add_argument('--embedding', default='pre_trained', type=str, help='random or pre_trained')
parser.add_argument('--word', default=False, type=bool, help='True for word, False for char')
args = parser.parse_args()

if __name__ == '__main__':
    dataset = 'THUCNews'  # 数据集

    # 搜狗新闻:embedding_SougouNews.npz, 腾讯:embedding_Tencent.npz, 随机初始化:random
    embedding = 'embedding_SougouNews.npz'
    if args.embedding == 'random':
        embedding = 'random'
    model_name = args.model
    x = import_module('models.' + model_name)
    config = x.Config(dataset, embedding)

    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    torch.backends.cudnn.deterministic = True

    start_time = time.time()
    print("Loading data...")
    # args.word 分词方式, True是词级别,默认是False
    vocab, train_data, dev_data, test_data = build_dataset(config, args.word)
    # build_iterator返回格式 [([词/字在词典中的位置] ,label, len(word)), ...]
    train_iter = build_iterator(train_data, config)
    dev_iter = build_iterator(dev_data, config)
    test_iter = build_iterator(test_data, config)

    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)
    # len(vocab)="<PAD>", len(vocab) -1 ="<UNK>"
    config.n_vocab = len(vocab)
    model = x.Model(config).to(config.device)
    init_network(model)
    print(model.parameters)

    train(config, model, train_iter, dev_iter, test_iter)

train_eval.py 文件,主要对模型参数进行初始化,函数train主要是从自定义迭代器中加载数据进行训练。test函数是在模型训练完后对测试数据集进行测试。evaluate函数主要是在训练过程中对验证集数据进行验证。

# coding: UTF-8
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import metrics
import time
from utils import get_time_dif
from tensorboardX import SummaryWriter
import matplotlib.pyplot as plt


# 权重初始化,默认xavier
def init_network(model, method='xavier', exclude='embedding', seed=123):
    for name, w in model.named_parameters():
        if exclude not in name:
            if 'weight' in name:
                if method == 'xavier':
                    nn.init.xavier_normal_(w)
                elif method == 'kaiming':
                    nn.init.kaiming_normal_(w)
                else:
                    nn.init.normal_(w)
            elif 'bias' in name:
                nn.init.constant_(w, 0)
            else:
                pass

def train(config, model, train_iter, dev_iter, test_iter):
    loss_list = []
    start_time = time.time()
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

    total_batch = 0  # 记录进行到多少batch
    dev_best_loss = float('inf')
    last_improve = 0  # 记录上次验证集loss下降的batch数
    flag = False  # 记录是否很久没有效果提升
    writer = SummaryWriter(log_dir=config.log_path + '/' + time.strftime('%m-%d_%H.%M', time.localtime()))
    # dev_acc_list = []
    # dev_loss_list = []
    for epoch in range(config.num_epochs):
        print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs))
        for i, (trains, labels) in enumerate(train_iter):
            outputs = model(trains)
            # 打印tensor的所有数据
            # torch.set_printoptions(threshold=float('inf'))
            model.zero_grad()
            loss = F.cross_entropy(outputs, labels)
            loss_list.append(loss.detach().numpy())
            loss.backward()
            optimizer.step()
            if total_batch % 100 == 0:
                true = labels.data.cpu()
                # 取出每一行最大的那个概率的索引值
                predic = torch.max(outputs.data, 1)[1].cpu()
                train_acc = metrics.accuracy_score(true, predic)
                dev_acc, dev_loss = evaluate(config, model, dev_iter)
                # dev_acc_list.append(dev_acc)
                # dev_loss_list.append(dev_loss)
                if dev_loss < dev_best_loss:
                    dev_best_loss = dev_loss
                    torch.save(model.state_dict(), config.save_path)
                    improve = '*'
                    last_improve = total_batch
                else:
                    improve = ''
                time_dif = get_time_dif(start_time)
                msg = 'Iter: {0:>6},  Train Loss: {1:>5.2},  Train Acc: {2:>6.2%},  Val Loss: {3:>5.2},  Val Acc: {4:>6.2%},  Time: {5} {6}'
                print(msg.format(total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve))
                writer.add_scalar("loss/train", loss.item(), total_batch)
                writer.add_scalar("loss/dev", dev_loss, total_batch)
                writer.add_scalar("acc/train", train_acc, total_batch)
                writer.add_scalar("acc/dev", dev_acc, total_batch)
                model.train()
            total_batch += 1
            if total_batch - last_improve > config.require_improvement:
                # 验证集loss超过10000batch没下降,结束训练
                print("No optimization for a long time, auto-stopping...")
                flag = True
                break
        if flag:
            break
    writer.close()

    size = len(loss_list)
    x_axis = [i for i in range(0, size)]
    plt.plot(x_axis, loss_list, color='red')
    plt.show()

    test(config, model, test_iter)


def test(config, model, test_iter):
    model.load_state_dict(torch.load(config.save_path))
    model.eval()
    start_time = time.time()
    test_acc, test_loss, test_report, test_confusion = evaluate(config, model, test_iter, test=True)
    msg = 'Test Loss: {0:>5.2},  Test Acc: {1:>6.2%}'
    print(msg.format(test_loss, test_acc))
    print("Precision, Recall and F1-Score...")
    print(test_report)
    print("Confusion Matrix...")
    print(test_confusion)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)


def evaluate(config, model, data_iter, test=False):
    model.eval()
    loss_total = 0
    predict_all = np.array([], dtype=int)
    labels_all = np.array([], dtype=int)
    # 模型评估的时候无梯度模式
    with torch.no_grad():
        for texts, labels in data_iter:
            outputs = model(texts)
            loss = F.cross_entropy(outputs, labels)
            loss_total += loss
            labels = labels.data.cpu().numpy()
            predict = torch.max(outputs.data, 1)[1].cpu().numpy()

            labels_all = np.append(labels_all, labels)
            predict_all = np.append(predict_all, predict)

    acc = metrics.accuracy_score(labels_all, predict_all)

    if test:
        report = metrics.classification_report(labels_all, predict_all, target_names=config.class_list, digits=4)
        confusion = metrics.confusion_matrix(labels_all, predict_all)
        return acc, loss_total / len(data_iter), report, confusion
    # 用于训练过程中的验证
    return acc, loss_total / len(data_iter)

model_test.py 是单个文本的推理文件。

utils.py定义了加载数据集函数load_dataset,自定义迭代器将数据转化为tensor格式便于输入到模型。

完整代码github地址

项目结构清晰以后我们主要要记录一下,RNN训练过程中遇到的一些问题,尽管现在已经不怎么使用RNN网络模型了,不过这不影响RNN在时序网络中的地位(LSTM 长短时记忆网络、GRU门控循环单元都是RNN的优化)我们还是有必要好好认识一下RNN的训练过程,以及超参数对损失值的影响。

我们主要参数设置如下,我们只对batch_size和pad_size进行修改看一下模型的损失下降曲线。

self.dropout = 0.3  # 随机失活
self.require_improvement = 10000  # 若超过10000batch效果还没提升,则提前结束训练
self.num_classes = len(self.class_list)  # 类别数
self.n_vocab = 0  # 词表大小,在运行时赋值
self.num_epochs = 7  # epoch数
self.batch_size = 64  # batch size
self.pad_size = 32  # 每句话处理成的长度(短填长切)
self.learning_rate = 1e-3  # 学习率
self.embed = self.embedding_pretrained.size(1) \
if self.embedding_pretrained is not None else 300  # 字向量维度
self.hidden_size = 128  # rnn隐藏层
self.num_layers = 2  # rnn层数,注意RNN中的层数必须大于1,dropout才会生效

batch_size = 64 pad_size = 32 learning_rate = 1e-3

训练过程
在这里插入图片描述

损失函数结果图,可以看出根本就不收敛,pad_size值过大,可能出现出现梯度消失,导致模型参数根本就不更新。

在这里插入图片描述
batch_size = 64 pad_size = 16 learning_rate = 1e-3

训练过程
在这里插入图片描述
从这里足以感性的理解为什么很多人说RNN携带的时序信息走不远,当我们将时序长度pad_size设置16时(其他参数不变)可以看到验证数据集的准确度和损失都还不错的,比pad_size=32要好很多,至少可以知道模型的参数是在更新,且损失值也有下降的趋势。
在这里插入图片描述
混淆矩阵也还可以。 混淆矩阵参考


以上是文本序列长度pad_size对RNN训练的影响。现在我们来看下batch_size大小对RNN训练的影响。为了让模型收敛pad_szie统一取16

batch_size = 128 pad_size = 16 learning_rate = 1e-3

训练过程
在这里插入图片描述

batch_size变大为128更新次数少,每一次迭代考虑的样本更多。每次迭代考虑的样本大了以后,梯度优化的波动变小,下降更平滑。相比batch_size=64,损失图像下下降确实更平滑。混淆矩阵无太大差异。
在这里插入图片描述
batch_size = 256 pad_size = 16 learning_rate = 1e-3

训练过程

在这里插入图片描述
batch_size=256损失值下降更平滑,收敛速度更快,batch_size=64时训练时长在18min左右,而此参数下训练时长仅要5min左右。
在这里插入图片描述
batch_size = 1024 pad_size = 16 learning_rate = 1e-3

训练过程


batch_size=1024时收敛速度更快,而此参数下训练时长仅要2min左右。
在这里插入图片描述

混淆矩阵,可以看出在显存足够大的情况下适当增大batch_size可以达到两点效果1)加快训练的收敛的速度 2)梯度优化的波动减小,收敛过程更加平滑。

在这里插入图片描述

至此我们已经完成了RNN训练中两个比较重要的超参数batch_size和pad_size对训练过程的影响。还有很多其他的超参数这里就不实验了。

pad_size由32变成16时候,显然只用到了一半的数据信息,无论怎么进行超参数的优化都不可能达到最好的结果。如果使用32又会出现梯度消失,从而模型不收敛。LSTM模型就有效的改进了这个缺陷。下一篇文章我们使用同样的超参数和数据集构造一个LSTM模型实验这个改进有多大。

参考
https://github.com/649453932/Chinese-Text-Classification-Pytorch

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1509582.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

企商在线CTO楼炜:论云计算与产业互联网

024年全国两会召开之际&#xff0c;3月4日&#xff0c;全国政协委员、京东集团技术委员会主席曹鹏提交了《发挥产业互联网平台作用 打造实体产业数字化转型直效通道》提案&#xff0c;提出了产业互联网平台在整合供应链、资金、技术、资讯、培训、人才等各类资源的重要作用。云…

python实现生成树

生成树 生成树&#xff08;Spanning Tree&#xff09;是一个连通图的生成树是图的极小连通子图&#xff0c;它包含图中的所有顶点&#xff0c;并且只含尽可能少的边。这意味着对于生成树来说&#xff0c;若砍去它的一条边&#xff0c;则会使生成树变成非连通图&#xff1b;若给…

ChatGpt只能看,但无法发送消息的解决办法

这几天发现chatgpt没法发送消息了,我以为是网络问题,又过了几天还是不能发,我以为是梯子的问题,可给我急坏了,于是我用无痕模式发现可以访问额. 但是无痕模式毕竟不是长久之计,于是找到了一个方法 1.首先把电脑缓存全清除了 第一种方法: 快捷键是 : ctrlshiftdel (这会吧浏览…

电脑切屏卡顿,尤其是打游戏时切屏卡顿问题解决方法

博主在打游戏时喜欢切后台但是最近发现切屏尤其慢&#xff0c;异常卡顿&#xff0c;但是是新换的电脑&#xff0c;所以苦恼了半天&#xff0c;上网搜也没有结果&#xff0c;说的都是些配置低&#xff0c;系统文件损坏等问题&#xff0c;所以再检查分辨率时发现问题所在 屏幕分辨…

Visual Studio 2022 配置“Debug|x64”的 Designtime 生成失败。IntelliSense 可能不可用。

今天写代码&#xff0c;无缘无故就给我整个这个错误出来&#xff0c;我一头雾水。 经过我几个小时的奋战&#xff0c;终于解决问题 原因就是这个Q_INTERFACES(&#xff09;宏&#xff0c;我本想使用Q_DECLARE_INTERFACE Q_INTERFACES这两个Qt宏实现不继承QObject也能使用qobjec…

jmeter压测实战

1,设置HTTP请求默认值 2,设置全局变量 3,新建线程组 4,设置私钥 5,每个接口新建一个事务控制器 6,新建Java请求 对于有sign签名的需要将jar包放在apache-jmeter-5.4.1\apache-jmeter-5.4.1\lib\ext目录下,然后引入进来。 除此之外,还需要下载bouncycastle.jar包放在…

地表径流量分布数据/水文站点分布数据

天然河川径流资料对于认识水文自然规律、国家水资源可持续利用以及适应气候变化政策制定具有重要意义。我国现有的天然河川径流资料存在时间缺失率高、水文站点密度不足等问题&#xff0c;在年际和季节变化尺度上存在较大的流量偏差。 引言 大气降水落到地面后&#xff0c;一部…

【数据分析】数据分析介绍

专栏文章索引&#xff1a;【数据分析】专栏文章索引 目录 一、介绍 二、生活中的数据分析 1.无处不在的数据 2.为什么要进行数据分析&#xff1f; 三、数据挖掘案例 1.案例分析 一、介绍 数据采集&#xff1a;数据采集是指从不同来源收集原始数据的过程&#xff0c;包括…

golang学习随便记16-反射

为什么需要反射 下面的例子中编写一个 Sprint 函数&#xff0c;只有1个参数&#xff08;类型不定&#xff09;&#xff0c;返回和 fmt.Fprintf 类似的格式化后的字符串。实现方法大致为&#xff1a;如果参数类型本身实现了 String() 方法&#xff0c;那调用 String() 方法即可…

web | http 的一些问题 | get/post的区别 | http版本 | http与https的区别 | session、cookie、token

怎么来说呢&#xff1f;这应该算一个大类了&#xff0c;基本上设计网络的应用层 当然重要的是从网络层----->应用层 &#xff08;杠精勿杠&#xff0c;知道中间还有其他层&#xff09; 先来讲一下http的结构 都知道http 有三部分&#xff0c;头部、请求头和body 头部&#x…

51单片机基础篇系列-点亮一个LED发光管基础知识搭建

&#x1f308;个人主页: 会编辑的果子君 &#x1f4ab;个人格言:“成为自己未来的主人~” LED发光二极管 它是半导体二极管的一种&#xff0c;可以把电能转化成光能&#xff0c;常简写为LED&#xff0c;发光二极管与普通二极管一样是由一个PN结组成&#xff0c;也具有单向…

Jenkins Pipeline实现Golang项目的CI/CD

Jenkins Pipeline实现Golang项目的CI/CD 背景 最近新增了一个Golang实现的项目&#xff0c;需要接入到现有的流水线架构中。 流程图 这边流程和之前我写过的一篇《基于Jenkins实现的CI/CD方案》差不多&#xff0c;不一样的是构建现在是手动触发的&#xff0c;没有配置webho…

dolphin schedulerAPI调用(二)——创建任务

&#xff08;作者&#xff1a;陈玓玏&#xff09; API文档地址&#xff1a;http://192.168.3.100:21583/dolphinscheduler/swagger-ui/index.html?languagezh_CN&langcn#/task%20definition%20related%20operation/createTaskDefinitionUsingPOST_1 实际使用中&#x…

微信小程序H5设置全局弹窗

微信小程序&H5设置全局弹窗 微信小程序&H5设置全局弹窗效果图1、下载所需库2、创建vue.config.js 文件3、创建全局公告组件头部公告组件弹窗公告组件4、组件注册到全局5、在pages.json文件中配置 insetLoader6、H5需要额外使用render.js7、全局调用(一进入页面就获取弹…

Elasticsearch:使用标记修剪提高文本扩展性能

作者&#xff1a;来自 Elastic Kathleen DeRusso 本博客讨论了 ELSER 性能的令人兴奋的新增强功能&#xff0c;该增强功能即将在 Elasticsearch 的下一版本中推出&#xff01; 标记&#xff08;token&#xff09;修剪背后的策略 我们已经详细讨论了 Elasticsearch 中的词汇和…

《系统架构设计师教程(第2版)》第6章-数据库设计基础知识-02-关系数据库

文章目录 1. 基本概念1.1 基本术语属性 (Attribute)域 (Domain)元数&#xff08;Arity&#xff09; / 目 &#xff08;Cardinality&#xff09;/ 度 (Degree)元组候选码 (Candidate Key)主码 (Primary Key)主属性 (Prime Attribute)外码 (Foreign Key)全码 (All-key)笛卡尔积 1…

大数据队列Kafka

了解什么是kafka之前&#xff0c;首先要了解一下什么是消息队列 一丶kafka的基本概述 消息队列&#xff1a;MQ介绍 定义 官方定义&#xff1a;消息队列是一种异步的服务间通信方式,是分布式系统中重要的组件,主要解决应用耦合,异步消息,流量削锋等问题,实现高性能,高可用,可伸…

WPF 中集合 ObservableCollection<T>的使用

C#集合类ObservableCollection<T> 类似于泛型列表类List<T>&#xff0c;表示一个动态数据收集&#xff0c;该集合在添加或删除项或刷新整个列表时提供通知。 所在命名空间&#xff1a;System.Collections.ObjectModel 继承关系&#xff1a; public class Observ…

SQL 多表查询

文章目录 多表查询的分类等值连接非等值连接自连接非自连接内连接外连接左外连接右外连接满外连接 SQL连接 JOINSQL99 语法新特性 自然连接 NATURAL JOIN & USING 多表查询的分类 等值连接 VS 非等值连接自连接 VS 非自连接内连接 VS 外连接 等值连接 关联的表有连接字段…

2.4_4 死锁的检测和解除

文章目录 2.4_4 死锁的检测和解除&#xff08;一&#xff09;死锁的检测&#xff08;二&#xff09;死锁的解除 总结 2.4_4 死锁的检测和解除 如果系统中既不采取预防死锁的措施&#xff0c;也不采取避免死锁的措施&#xff0c;系统就很可能发生死锁。在这种情况下&#xff0c;…