【深度学习--RNN 循环神经网络--附LSTM情感文本分类】

news2024/12/28 18:02:02

deep learning 系列 --RNN 循环神经网络

什么是序列模型

包括了RNN LSTM GRU等网络模型,主要用途是自然语言处理、语音识别等方面,比如生成乐曲,音频转换为文字,文本情感分类,机器翻译等等

标准模型的缺陷

以往的标准模型比如CNN,每次的输入不影响下次的输出,也就是说每次输入的图片都是独立的,没有任何关联,但是很多情况下,我们建立的模型与前项甚至后项的输入是相关的。举个例子,我们要从这两句话中识别出人名:

President Teddy was …
Teddy bear was…

这其中都有一个关键且同样的词Teddy,但是这个Teddy可以是个人名,也可以是泰迪,究竟是哪个呢?通常的识别方法是:

1.从之前学习过的人名识别

我们之前的位置可能提取过Teddy是人名, 但CNN网络,并不能共享从不同位置提取到到的特征,因此不可行

2.从本次输入的上下文出发

比如这里下文有 bear,President,但是CNN也不具备这种序列特性,因此也不可行

另外,CNN的网络对于输入的模型的数据长度都是固定的,但是不同的句子长度并不一致,当然我们可以padding,但是不那么好。
因此才有了循环神经网络架构,它可以克服以上的这些问题

循环神经网络

架构

在这里插入图片描述
图示是循环神经网络架构, 是有循环的,这个图形是最长用于展示的,但实际并不好理解,这个环实际上是多次的输入和输出,下面这种展开的方式更容易理解
在这里插入图片描述
RNN的结构就包括三类输入的权重参数
1.激活值的参数,水平方向的值,每个时间步的激活参数是相同的
2.输入到隐藏层的参数,也就是x到A方向,这个也是每个时间步相同
3.用来预测输出的参数,A到h的方向
在这里插入图片描述
其中y的输出可以用如下表示:
在这里插入图片描述

在这里插入图片描述
预测值y<t>包括了激活值a<t>,而a<t>包括了x<t-1>(注意此处t表示时间步),也就是说每一次的预测输出不仅包括了本次x<t>也包括了上个时间步x<t-1>,依次类推,此次时间步的输出包括了之前所有输入

BPTT 反向传播

实际上在工具中,反向传播都是自动进行的,原理跟普通的反向传播一致
输入序列后,假设预测的值是0.9,但实际是1,产生损失,这个损失可以用交叉熵损失函数来估量
RNN中的反向传播时将之前所有的箭头都反过来,计算出合适的变量,通过导数相关的计算,利用梯度下降算法更新参数,也就是图示:
在这里插入图片描述
这个传播过程中最重要的就是水平方向时间步的反向,因此又叫做穿越时间的反向传播Back Propagate Through time

不同类型的循环神经网络

  • 多对多
    比如机器翻译,输入是多个,输出也是多个,且并不对等,此时经常使用的是encode-decoder,encoder编码器获取输入,并输出,而decoder则使用encode编码的输出,执行decoder,这样输入x和输出y的长度就可以不相同了
    在这里插入图片描述
    当然输入和输出对等的情况也很多
    比如在句子中寻找人名,预测输出y就可以是每个x的对应的每个位置输出y(0表示非人名,1表示是人名)

  • 多对一
    在这里插入图片描述
    比如进行文本的情感分类,这样输入就有可能是一段文本,而输出我们只需要最后一层最后一个时间步的输出即可,一个0和1就足以标识这段文本是positive还是negative

  • 一对多
    在这里插入图片描述
    比如生成类的,输入一个音符或者不输入,就可以产生多个输出

RNN的缺陷 梯度消失和梯度爆炸

  • 现象:

    实际上深度比较大的网络都可能梯度消失或者爆炸,这种现象在在RNN中更加明显
    当我们输入的序列为1000的时候,拿最简单的模型举例 y = wx 经过1000次的传播,y1000 的变化
    在这里插入图片描述
    w仅仅变化一点,经过1000次的传播,变化非常的大

  • 原因:

    发生这样的根本原因是RNN中每一次的输出都将被前面的数据彻底的清洗,而经过长时间的传输,很前面步的影响都后面的影响已经很微弱了,损失的反向调整同样也是,经过长距离的调整,差错已经很难反馈很多个时间步之后了

  • 解决办法:

    梯度爆炸:进行梯度裁剪即可,比如我们发现输出有很多超大的值的时候,进行裁剪
    梯度消失:它很难察觉,也在标准的RNN结构中,可以用GRU或者LSTM解决,而解决梯度消失的实质是通过保留一些前期输入的记忆

LSTM

标准的RNN结构是这样的:
在这里插入图片描述
而标准的LSTM加入了四个门控单元
在这里插入图片描述
这四个门单元分别是:
it, ft,gt, ot 分别是输入门、遗忘门、cell(记忆)门,以及输出门
他们控制 是否输入,对输入的遗忘和记忆,也控制是否输出,从而控制重要信息的传递,不重要信息遗忘

GRU

它比LSTM诞生更晚,是LSTM的变形版本,由于门单元更少,计算简单些,因此训练时间更短一些。
在这里插入图片描述

LSTM实例

本实例以IMDB数据集为例,代码篇幅过长,本文仅列示其中LSTM使用相关的重点,后续会有专门的博客详细解析代码。

  1. 预处理数据
    读取IMDB数据集
def read_imdb(datafolder ='train', dataroot=imdb_zip_path):
    data=[]
    for label in ['pos', 'neg']:
        filepath = os.path.join(imdb_zip_path, datafolder, label)
        for file in tqdm(os.listdir(filepath)):
            with open(os.path.join(filepath,file), 'rb') as f:
                content = f.read().decode('utf-8').replace('\n', ' ').lower()
                data.append([content, 1 if label == 'pos' else 0])
    random.shuffle(data)
    return data

IMDB中的数据分词

def get_tokenized(data):
    def tokenizer(text):
        filters = ['!', '"', '#', '$', '%', '&', '\(', '\)', '\*', '\+', ',', '-', '\.', '/', ':', ';', '<', '=', '>',
                    '\?', '@', '\[', '\\', '\]', '^', '_', '`', '\{', '\|', '\}', '~', '\t', '\n', '\x97', '\x96', '”', '“', ]
        text = re.sub("<.*?>", " ", text, flags=re.S)
        text = re.sub("|".join(filters), " ", text, flags=re.S)
        return [i.strip().lower() for i in text.split()]
    return [tokenizer(context) for context, _ in data]

创建分词后的词典

def get_vocab(data):
    counter = collections.Counter(_flatten(data))
    return vocab.vocab(counter)

封装dataloader

class ImdbLoader(object):
    def __init__(self, set_name='train', batch_size='64'):
        super(ImdbLoader, self).__init__()
        self.data_set = set_name
        self.batch_size = batch_size

    def get_data_loader(self):
        # train_data = [['"dick tracy" is one of our"', 1],
        #               ['arguably this is a  the )', 1],
        #               ["i don't  just to warn anyone ", 0]]
        train_data = read_imdb(self.data_set)
        data = preprocess(train_data)
        #print(data)
        data_set = Data.TensorDataset(*data)
        data_loader = Data.DataLoader(data_set, self.batch_size, shuffle=True)
        return data_loader
  1. 创建模型
    此处创建了一个模型,包括双向的LSTM层和一个全连接层
 class BiRNN(nn.Module):
    def __init__(self, vocabulary, embed_len, hidden_len, num_layer):
        super(BiRNN, self).__init__()
        self.embedding = nn.Embedding(len(vocabulary), embed_len)
        self.encoder = nn.LSTM(input_size=embed_len,
                               hidden_size=hidden_len,
                               num_layers=num_layer,
                               bidirectional=True,
                               dropout = 0.3)

        # 本次使用起始和最终时间步的隐藏状态座位全连接层的输入
        self.decoder = nn.Linear(2*2*hidden_len, 2)

    def forward(self, inputs):
        #print('rnn model py: input_shape: ', inputs.shape)
        embeddings = self.embedding(inputs)

        glove_vab = getGlove()
        net.embedding.weight.data.copy_(load_pretrained_embedding(vo.get_itos(), glove_vab))
        net.embedding.weight.requires_grad = False

        #print('after embed input shape:', embeddings.shape)
        embeddings = embeddings.permute(1, 0, 2)
        output_sequence, _ = self.encoder(embeddings)
        concat_out = torch.cat((output_sequence[0], output_sequence[-1]), -1)
        outputs = self.decoder(concat_out)
        return outputs

3.模型训练

   def train(epoch, imdb_model, lr, train_batch_size):
    imdb_model_device = imdb_model.to(device)
    # 过滤掉不需要计算梯度的embedding的参数
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, imdb_model_device.parameters()), lr=lr)
    loader = ImdbLoader('train', train_batch_size)
    data_loader = loader.get_data_loader()

    for i in range(epoch):
        for idx, (inputs, target) in enumerate(data_loader):
            target = target.to(device)
            inputs = inputs.to(device)
            #print('train.py input shape:', inputs.shape)

            optimizer.zero_grad()
            output = imdb_model(inputs)
            #print('ouput.shape', output.shape)
            criterion = nn.CrossEntropyLoss()
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            if idx % 10 == 0:
                predict = torch.max(output, dim=-1, keepdim=False)[-1]
                acc = predict.eq(target.data).cpu().numpy().mean() * 100
                print('train Epoch:{} processed:[{} / {} ({:.0f}%) Loss: {:.6f}, ACC: {:.6f}]'.format(
                                                                          i,
                                                                          idx * len(inputs),
                                                                          len(data_loader.dataset),
                                                                          100. * idx / len(data_loader),
                                                                          loss.item(),
                                                                          acc))
        torch.save(imdb_model.state_dict(), '../../resources/model_save/imdb_net.pkl')
        torch.save(optimizer.state_dict(), '../../resources/model_save/imdb_optimizer.pkl')

4.模型评估

 def test(imdb_model, test_batch_size):
    imdb_model.eval()
    imdb_model = imdb_model.to(device)
    loader = ImdbLoader('test', test_batch_size)
    data_loader = loader.get_data_loader()
    with torch.no_grad():
        for idx, (inputs, target) in enumerate(data_loader):
            target = target.to(device)
            inputs = inputs.to(device)
            #print(inputs)
            output = imdb_model(inputs)
            criterion = nn.CrossEntropyLoss()
            loss = criterion(output, target)

            predict = torch.max(output, dim=-1, keepdim=False)[-1]
            correct = predict.eq(target.data).sum()
            acc = 100. * predict.eq(target.data).cpu().numpy().mean()
            print('idx: {} loss : {}, accurate: {}/{} {:.2f}'.format(idx,  loss, correct, target.size(0), acc))

最终效果,当我们执行4个epoch后,准确率基本稳定在80%以上

train Epoch:3 processed:[23680 / 25000 (95%) Loss: 0.325240, ACC: 84.375000]
train Epoch:3 processed:[24320 / 25000 (97%) Loss: 0.449456, ACC: 75.000000]
train Epoch:3 processed:[15600 / 25000 (100%) Loss: 0.438567, ACC: 80.000000]
train Epoch:4 processed:[0 / 25000 (0%) Loss: 0.353131, ACC: 85.937500]
train Epoch:4 processed:[640 / 25000 (3%) Loss: 0.345814, ACC: 89.062500]
train Epoch:4 processed:[1280 / 25000 (5%) Loss: 0.195520, ACC: 93.750000]
train Epoch:4 processed:[1920 / 25000 (8%) Loss: 0.269773, ACC: 87.500000]
train Epoch:4 processed:[2560 / 25000 (10%) Loss: 0.287010, ACC: 85.937500]
train Epoch:4 processed:[3200 / 25000 (13%) Loss: 0.291449, ACC: 90.625000]

参考

colah https://colah.github.io/posts/2015-08-Understanding-LSTMs/
吴恩达 deep learning

https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#torch.nn.LSTM

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/886977.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Windows上使用FFmpeg实现本地视频推送模拟海康协议rtsp视频流

场景 Nginx搭建RTMP服务器FFmpeg实现海康威视摄像头预览&#xff1a; Nginx搭建RTMP服务器FFmpeg实现海康威视摄像头预览_nginx rtmp 海康摄像头_霸道流氓气质的博客-CSDN博客 上面记录的是使用FFmpeg拉取海康协议摄像头的rtsp流并推流到流媒体服务器。 如果在其它业务场景…

ppt中线材相交接的地方,如何绘画

ppt中线材相交接的地方&#xff1a; 在ppt中绘画线材相互交接的地方&#xff1a; 1.1绘图工具中的“弧形” 1.2小技巧 “弧形”工具点一下&#xff0c;在ppt中如下 1.3拖动活动点进行调整图形 1.4绘画圆弧 1.5调整“圆弧”的大小&#xff0c;鼠标放在“黄色点”位置&#xf…

爬虫逆向实战(十七)--某某丁简历登录

一、数据接口分析 主页地址&#xff1a;某某丁简历 1、抓包 通过抓包可以发现数据接口是submit 2、判断是否有加密参数 请求参数是否加密&#xff1f; 通过查看“载荷”模块可以发现有一个enPassword加密参数 请求头是否加密&#xff1f; 通过查看请求头可以发现有一个To…

React 高阶组件(HOC)

React 高阶组件(HOC) 高阶组件不是 React API 的一部分&#xff0c;而是一种用来复用组件逻辑而衍生出来的一种技术。 什么是高阶组件 高阶组件就是一个函数&#xff0c;且该函数接受一个组件作为参数&#xff0c;并返回一个新的组件。基本上&#xff0c;这是从 React 的组成…

部署MES管理系统首先要解决什么问题

随着制造业市场竞争的加剧&#xff0c;企业需要更加高效、灵活的生产运营&#xff0c;以提高产品质量和降低成本。在这种情况下&#xff0c;MES管理系统解决方案成为许多企业的选择。然而&#xff0c;在部署MES管理系统之前&#xff0c;必须首先解决一些关键问题&#xff0c;以…

ReactNative进阶(三十四):ipa Archive 阶段报错error: Multiple commands produce问题修复及思考

文章目录 一、前言二、问题描述三、问题解决四、拓展阅读五、拓展阅读 一、前言 在应用RN开发跨平台APP阶段&#xff0c;从git中拉取项目&#xff0c;应用Jenkins进行组包时&#xff0c;发现最终生成的ipa安装包版本号始终与项目中设置的版本号不一致。 二、问题描述 经过仔…

常见排序集锦-C语言实现数据结构

目录 排序的概念 常见排序集锦 1.直接插入排序 2.希尔排序 3.选择排序 4.堆排序 5.冒泡排序 6.快速排序 hoare 挖坑法 前后指针法 非递归 7.归并排序 非递归 排序实现接口 算法复杂度与稳定性分析 排序的概念 排序 &#xff1a;所谓排序&#xff0c;就是使一串记录&#…

PHP 报修管理系统mysql数据库web结构apache计算机软件工程网页wamp

一、源码特点 PHP 报修管理系统 是一套完善的web设计系统&#xff0c;对理解php编程开发语言有帮助&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S模式开发。 下载地址&#xff1a; https://download.csdn.net/download/qq_41221322/88209950 视…

一次网络不通 “争吵” 引发的思考

为啥争吵&#xff0c;吵什么&#xff1f; "你到底在说什么啊&#xff0c;我 K8s 的 ecs 节点要访问 clb 的地址不通和本地网卡有什么关系..." 气愤语气都从电话那头传了过来&#xff0c;这时电话两端都沉默了。过了好一会传来地铁小姐姐甜美的播报声打断了刚刚的沉寂…

【校招VIP】java语言考点之List和扩容

考点介绍&#xff1a; List是最基础的考点&#xff0c;但是很多同学拿不到满分。本专题从两种实现子类的比较&#xff0c;到比较复杂的数组扩容进行分析。 『java语言考点之List和扩容』相关题目及解析内容可点击文章末尾链接查看&#xff01;一、考点题目 1、以下关于集合类…

仪表板展示 | DataEase看中国:2023年中国电影市场分析

背景介绍 随着《消失的她》、《变形金刚&#xff1a;超能勇士崛起》、《蜘蛛侠&#xff1a;纵横宇宙》、《我爱你》等国内外影片的上映&#xff0c;2023年上半年的电影市场也接近尾声。据国家电影专资办初步统计&#xff0c;上半年全国城市院线票房达262亿元&#xff0c;已经超…

气象监测站:用科技感知气象变化

气象监测站是利用科学技术感知当地小气候变化情况的气象观测仪器&#xff0c;可用于农业、林业、养殖业、畜牧业、环境保护、工业等多个领域&#xff0c;提高对环境数据的利用率&#xff0c;促进产业效能不断提升。 气象监测站主要由气象传感器、数据传输系统、电源系统、支架…

Python web实战之Django 的跨站点请求伪造(CSRF)保护详解

关键词&#xff1a;Python、Web、Django、跨站请求伪造、CSRF 大家好&#xff0c;今天我将分享web关于安全的话题&#xff1a;Django 的跨站点请求伪造&#xff08;CSRF&#xff09;保护&#xff0c;介绍 CSRF 的概念、原理和保护方法. 1. CSRF 是什么&#xff1f; CSRF&#…

字符设备驱动实例(LED、按键、input输入子系统)

目录 本章目标 一、LED驱动 二、基于中断的简单按键驱动 三、基于输入子系统的按键驱动 本章目标 本章综合前面的知识&#xff0c;实现了嵌入式系统的常见外设驱动&#xff0c;包括 LED、按键、ADC、PWM和RTC。本章从工程的角度、实用的角度探讨了某些驱动的实现。比如LED …

ZooKeeper的应用场景(命名服务、分布式协调通知)

3 命名服务 命名服务(NameService)也是分布式系统中比较常见的一类场景&#xff0c;在《Java网络高级编程》一书中提到&#xff0c;命名服务是分布式系统最基本的公共服务之一。在分布式系统中&#xff0c;被命名的实体通常可以是集群中的机器、提供的服务地址或远程对象等一这…

ListNode相关

目录 2. 链表相关题目 2.1 合并两个有序链表&#xff08;简单&#xff09;&#xff1a;递归 2.2 删除排序链表中的重复元素&#xff08;简单&#xff09;&#xff1a;一次遍历 2.3 两链表相加&#xff08;中等&#xff09;&#xff1a;递归 2.4 删除链表倒数第N个节点&…

数据结构:栈和队列(超详细)

目录 ​编辑 栈&#xff1a; 栈的概念及结构&#xff1a; 栈的实现&#xff1a; 队列&#xff1a; 队列的概念及结构&#xff1a; 队列的实现&#xff1a; 扩展知识&#xff1a; 以上就是个人学习线性表的个人见解和学习的解析&#xff0c;欢迎各位大佬在评论区探讨&#…

取证的学习

Volatility命令语法 1.判断镜像信息&#xff0c;获取操作系统类型 Volatility -f xxx.vmem imageinfo 在查到操作系统后如果不确定可以使用以下命令查看 volatility - f xxx.vmem --profile [操作系统] volshell 2.知道操作系统类型后&#xff0c;用–profile指定 volat…

redis查看执行的命令

1.SLOWLOG LEN 获取 Slowlog 的长度&#xff0c;以确定 Slowlog 中有多少条记录 2.SLOWLOG GET 获取 Slowlog 中的具体记录。你可以使用 SLOWLOG GET 命令来获取第 n 条记录的详细信息&#xff0c;其中 n 是记录的索引&#xff08;从 0 开始&#xff09; 3.如果你想获取多条最…

68 # 中间层如何请求其他服务

前端 ajax 有跨域问题&#xff0c;可以先访问中间层&#xff0c;在通过 node 去请求别的服务端口&#xff0c;可以解决跨域问题 编写中间层调用 // 中间层的方式const http require("http");// http.get 默认发送 get 请求 // http.request 支持其他请求格式 postl…