【深度学习入门篇 ⑨】循环神经网络实战

news2024/9/23 11:25:08

【🍊易编橙:一个帮助编程小伙伴少走弯路的终身成长社群🍊】

大家好,我是小森( ﹡ˆoˆ﹡ ) ! 易编橙·终身成长社群创始团队嘉宾,橙似锦计划领衔成员、阿里云专家博主、腾讯云内容共创官、CSDN人工智能领域优质创作者 。


今天我们看一下用循环神经网络RNN的原理并且动手应用到案例。

3e012755cfd647aebdf70ff24536d38b.png 

循环神经网络

在普通的神经网络中,信息的传递是单向的,这种限制虽然使得网络变得更容易学习,但在一定程度上也减弱了神经网络模型的能力。特别是在很多现实任务中,网络的输出不仅和当前时刻的输入相关,也和其过去一段时间的输出相关。此外,普通网络难以处理时序数据,比如视频、语音、文本等,时序数据的长度一般是不固定的,而前馈神经网络要求输入和输出的维数都是固定的,不能任意改变。因此,当处理这一类和时序相关的问题时,就需要一种能力更强的模型。

循环神经网络 (RNN)是一类具有短期记忆能力的神经网络。在循环神经网络中,神经元不但可以接受其它神经元的信息,也可以接受自身的信息,形成具有环路的网络结构。  

ab119b30479c4d74bb10bf02ef0d9f34.png 

RNN比传统的神经网络多了一个循环圈,这个循环表示的就是在下一个时间步上会返回作为输入的一部分,我们把RNN在时间点上展开 :

6e2096802ad346c1836d1ede9370a9fe.png

在不同的时间步,RNN的输入都将与之前的时间状态有关 ,具体来说,每个时间步的RNN单元都会接收两个输入:当前时间步的外部输入和前一时间步(隐藏层)的输出状态。通过这种方式,RNN能够学习并理解数据中的长期依赖关系,使得它在处理文本生成、语音识别、时间序列预测等序列数据时表现尤为出色。

此外,RNN的隐藏状态(或称为内部状态)在每次迭代时都会更新,这种更新过程包含了当前输入和前一时间步状态的非线性组合,使得网络能够动态地调整其对序列中接下来内容的预测或理解。

d1ad2acff14b48458791021e8ce8eaa5.png

LSTM和GRU

传统的RNN在处理长序列数据时常常面临梯度消失或梯度爆炸的问题,这限制了其在处理长期依赖关系上的能力。为了克服这一局限性,LSTM(Long Short-Term Memory,长短期记忆网络)作为RNN的一种变体被引入。

LSTM是一种RNN特殊的类型,可以学习长期依赖信息。在很多问题上,LSTM都取得相当巨大的成功,并得到了广泛的应用。

48465d18371741739f23324e0f1f3e05.png

LSTM是通过一个叫做的结构实现,门可以选择让信息通过或者不通过。 这个门主要是通过sigmoid和点乘实现的 ;sigmoid 的取值范围是在(0,1)之间,如果接近0表示不让任何信息通过,如果接近1表示所有的信息都会通过。

  • 遗忘门通过sigmoid函数来决定哪些信息会被遗忘
  • 输入门决定哪些新的信息会被保留。

例如:

我昨天吃了拉面,今天我想吃炒饭,在这个句子中,通过遗忘门可以遗忘拉面,同时更新新的主语为炒饭。

输出门

我们需要决定什么信息会被输出,也是一样这个输出经过变换之后会通过sigmoid函数的结果来决定那些细胞状态会被输出。

  1. 前一次的输出和当前时间步的输入的组合结果通过sigmoid函数进行处理得到O_t

  2. 更新后的细胞状态C_t会经过tanh层的处理,把数据转化到(-1,1)的区间

  3. tanh处理后的结果和O_t进行相乘,把结果输出同时传到下一个LSTM的单元

8ca0b205bcfa44e18c3af5b4f7271880.png 

GRU

GRU是一种LSTM的变形版本, 它将遗忘和输入门组合成一个“更新门”。它还合并了单元状态和隐藏状态,并进行了一些其他更改,由于他的模型比标准LSTM模型简单,所以越来越受欢迎。

664e50357e604f918c707643ca15bc9c.png

b429639b6a994ec099f87d8adf609263.png 

双向LSTM

单向的 RNN,是根据前面的信息推出后面的,但有时候只看前面的词是不够的, 可能需要预测的词语和后面的内容也相关,那么此时需要一种机制,能够让模型不仅能够从前往后的具有记忆,还需要从后往前需要记忆。此时双向LSTM就可以帮助我们解决这个问题

f990226c2e3a4c9da262cc74ff2201e4.png 

由于是双向LSTM,所以每个方向的LSTM都会有一个输出,最终的输出会有2部分,所以往往需要concat的操作。

96f81f98d8e74dadaa1f4925a3406007.pngRNN实现文本情感分类 

torch.nn.LSTM(input_size,hidden_size,num_layers,batch_first,dropout,bidirectional)
  1. input_size:输入数据的形状,即embedding_dim

  2. hidden_size:隐藏层神经元的数量,即每一层有多少个LSTM单元

  3. num_layer :即RNN的中LSTM单元的层数

  4. batch_first:默认值为False,输入的数据需要[seq_len,batch,feature],如果为True,则为[batch,seq_len,feature]

  5. dropout:dropout的比例,默认值为0。dropout是一种训练过程中让部分参数随机失活的一种方式,能够提高训练速度,同时能够解决过拟合的问题。

  6. bidirectional:是否使用双向LSTM,默认是False

实例化LSTM对象之后,不仅需要传入数据,还需要前一次的h_0(前一次的隐藏状态)和c_0

LSTM的默认输出为output, (h_n, c_n)  

  1. output(seq_len, batch, num_directions * hidden_size)--->batch_first=False

  2. h_n:(num_layers * num_directions, batch, hidden_size)

  3. c_n: (num_layers * num_directions, batch, hidden_size)

 4b9843ea2e35484f86a90641afd0fff6.png

LSTM和GRU的使用注意点

  1. 第一次调用之前,需要初始化隐藏状态,如果不初始化,默认创建全为0的隐藏状态

  2. 往往会使用LSTM or GRU 的输出的最后一维的结果,来代表LSTM、GRU对文本处理的结果,其形状为[batch, num_directions*hidden_size]

使用LSTM完成文本情感分类

class IMDBLstmmodel(nn.Module):
    def __init__(self):
        super(IMDBLstmmodel,self).__init__()
        self.hidden_size = 64
        self.embedding_dim = 200
        self.num_layer = 2
        self.bidriectional = True
        self.bi_num = 2 if self.bidriectional else 1
        self.dropout = 0.5


        self.embedding = nn.Embedding(len(ws),self.embedding_dim,padding_idx=ws.PAD) #[N,300]
        self.lstm = nn.LSTM(self.embedding_dim,self.hidden_size,self.num_layer,bidirectional=True,dropout=self.dropout)

        self.fc = nn.Linear(self.hidden_size*self.bi_num,20)
        self.fc2 = nn.Linear(20,2)


    def forward(self, x):
        x = self.embedding(x)
        x = x.permute(1,0,2) 
        h_0,c_0 = self.init_hidden_state(x.size(1))
        _,(h_n,c_n) = self.lstm(x,(h_0,c_0))

        out = torch.cat([h_n[-2, :, :], h_n[-1, :, :]], dim=-1)
        out = self.fc(out)
        out = F.relu(out)
        out = self.fc2(out)
        return F.log_softmax(out,dim=-1)

    def init_hidden_state(self,batch_size):
        h_0 = torch.rand(self.num_layer * self.bi_num, batch_size, self.hidden_size).to(device)
        c_0 = torch.rand(self.num_layer * self.bi_num, batch_size, self.hidden_size).to(device)
        return h_0,c_0

为了提高程序的运行速度,可以考虑把模型放在GPU上运行:

  1. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  2. model.to(device)

train_batch_size = 64
test_batch_size = 5000
imdb_model = IMDBLstmmodel().to(device) 
optimizer = optim.Adam(imdb_model.parameters())
criterion = nn.CrossEntropyLoss()

def train(epoch):
    mode = True
    imdb_model.train(mode)
    train_dataloader =get_dataloader(mode,train_batch_size)
    for idx,(target,input,input_lenght) in enumerate(train_dataloader):
        target = target.to(device)
        input = input.to(device)
        optimizer.zero_grad()
        output = imdb_model(input)
        loss = F.nll_loss(output,target) 
        loss.backward()
        optimizer.step()
        if idx %10 == 0:
            pred = torch.max(output, dim=-1, keepdim=False)[-1]
            acc = pred.eq(target.data).cpu().numpy().mean()*100.

            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\t ACC: {:.6f}'.format(epoch, idx * len(input), len(train_dataloader.dataset),
                       100. * idx / len(train_dataloader), loss.item(),acc))

            torch.save(imdb_model.state_dict(), "model/mnist_net.pkl")
            torch.save(optimizer.state_dict(), 'model/mnist_optimizer.pkl')
            
 def test():
    mode = False
    imdb_model.eval()
    test_dataloader = get_dataloader(mode, test_batch_size)
    with torch.no_grad():
        for idx,(target, input, input_lenght) in enumerate(test_dataloader):
            target = target.to(device)
            input = input.to(device)
            output = imdb_model(input)
            test_loss  = F.nll_loss(output, target,reduction="mean")
            pred = torch.max(output,dim=-1,keepdim=False)[-1]
            correct = pred.eq(target.data).sum()
            acc = 100. * pred.eq(target.data).cpu().numpy().mean()
            print('idx: {} Test set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(idx,test_loss, correct, target.size(0),acc))
            
 if __name__ == "__main__":
    test()
    for i in range(10):
        train(i)
        test()

然后由大家写代码得到模型训练的最终输出,大家可以改变模型来观察不同的结果。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1934073.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

把当前img作为到爷爷的背景图

(忽略图大小不一致,一般UI给的图会刚好适合页面大小,我这网上找的图,难调大小,我行内的就自己随便写的宽高),另外悄悄告诉你最后有简单方法~~ 先来看看初始DOM结构代码 …

【接口自动化_12课_基于Flask搭建MockServer】

知识非核心点,面试题较少。框架搭建的过程中的细节才是面试要点 第三方接口,不方便进行测试, 自己要一个接口去进行模拟。去作为我们项目访问模拟接口。自己写一个接口,需要怎样写 一、flask:轻量级的web应用的框架 安装命令 pip install flask 1、flask-web应用 1)…

【防雷】浪涌保护器的选择与应用

浪涌保护器(SPD)是一种用于保护电气设备免受电力系统突发的电压浪涌或过电压等干扰的重要装置。供电系统由于外部受雷击、过电压影响,内部受大容量设备和变频设备的开、关、重启、短路故障等,都会产生瞬态过电压,带来日…

你下载的蓝光电影,为什么不那么清晰?

1080P 为什么糊 蓝光对应的就是 1080P分辨率为 1920 * 1080 随便抽取一帧画面,得到的就是一张有 1920 * 1080 个像素点的图片大多数电影是每秒播放 24 张图片,也就是一个 24 帧的电影 电影在电脑上的储存 压缩方案 不仅仅有如下两种,还有…

Vue3 + uni-app 微信小程序:仿知乎日报详情页设计及实现

引言 在移动互联网时代,信息的获取变得越来越便捷,而知乎日报作为一款高质量内容聚合平台,深受广大用户喜爱。本文将详细介绍如何利用Vue 3框架结合微信小程序的特性,设计并实现一个功能完备、界面美观的知乎日报详情页。我们将从…

Linux LVM扩容方法

问题描述 VMware Centos环境,根分区为LVM,大小50G,现在需要对根分区扩容。我添加了一块500G的虚拟硬盘(/dev/sdb),如何把这500G扩容到根分区? LVM扩容方法 1. 对新磁盘分区 使用fdisk /dev/sdb命令,进…

C++:类和对象1

1.类的定义 类定义在面向对象编程中是一个核心概念,它定义了对象的结构和行为。在C中,类定义包含类的名称、数据成员(也称为属性或者字段)和成员函数(也称为方法或者操作)多个部分。数据成员定义了对象的状…

2024-07-16 Unity插件 Odin Inspector5 —— Conditional Attributes

文章目录 1 说明2 条件特性2.1 DisableIf / EnableIf2.2 DisableIn / EnableIn / ShowIn / HideIn2.3 DisableInEditorMode / HideInEditorMode2.4 DisableInInlineEditors / ShowInInlineEditors / HideInInlineEditors2.5 DisableInPlayMode / HideInPlayMode2.6 ShowIf / Hi…

docker安装mysql突然无法远程连接

docker安装mysql突然莫名其妙的无法远程连接 docker安装mysql突然无法远程访问问题背景发现问题排查问题解决问题总结 docker安装mysql突然无法远程访问 问题背景 大概一年前在服务器中通过docker安装mysql5.7端口映射关系是3308->3306 前期在服务器上开方了3308端口 fir…

Python用Pyqt5制作音乐播放器

具体效果如下 需要实现的功能主要的几个有: 1、搜索结果更新至当前音乐的列表,这样播放下一首是搜素结果的下一首 2、自动播放 3、滚动音乐文本 4、音乐进度条 5、根据实际情况生成音乐列表。我这里的是下面的情况,音乐文件的格式是 歌…

图——图的遍历(DFS与BFS算法详解)

前面的文章中我们学习了图的基本概念和存储结构,大家可以通过下面的链接学习: 图的定义和基本术语 图的类型定义和存储结构 这篇文章就来学习一下图的重要章节——图的遍历。 目录 一,图的遍历定义: 二,深度优先…

【MySQL】:学习数据库必须要知道的背景知识

客户端—服务器 客户端是一个“客户端—服务器”结构的程序 C(client)—S(server) 客户端和服务器是两个独立的程序,这两个程序之间通过“网络”进行通信(相当于是两种角色) 客户端 主动发起网…

CV12_ONNX转RKNN模型(谛听盒子)

暂时简单整理一下: 1.在边缘设备上配置相关环境。 2.配置完成后,获取模型中间的输入输出结果,保存为npy格式。 3.将onnx格式的模型,以及中间输入输出文件传送到边缘设备上。 4.编写一个python文件用于转换模型格式&#xff0c…

对某根域的一次渗透测试

前言 两个月之前的一个渗透测试项目是基于某网站根域进行渗透测试,发现该项目其实挺好搞的,就纯粹的没有任何防御措施与安全意识所以该项目完成的挺快,但是并没有完成的很好,因为有好几处文件上传没有绕过(虽然从一个…

linux|多线程(一)

主要介绍了为什么要有线程 和线程的调用 和简单的对线程进行封装。 背景知识 a.重谈地址空间 我们知道物理内存的最小单元大小是4kB 物理内存是4G那么这样的单元友1M个 操作系统先描述再组织struct page[1M] 对于32位数据字长的机器,页表有2^32条也就是4G条&#…

springboot的JWT令牌

生成JWT令牌 依赖 <!--jwt令牌--> <dependency> <groupId>io.jsonwebtoken</groupId> <artifactId>jjwt</artifactId> <version>0.9.1</version> </dependency> <dependency> <groupId>javax.xml.bind<…

怎样在 PostgreSQL 中优化对大数据量的分页查询?

&#x1f345;关注博主&#x1f397;️ 带你畅游技术世界&#xff0c;不错过每一次成长机会&#xff01;&#x1f4da;领书&#xff1a;PostgreSQL 入门到精通.pdf 文章目录 《PostgreSQL 中大数据量分页查询的优化之道》一、理解分页查询的基本原理二、优化分页查询的策略&…

2024年06月CCF-GESP编程能力等级认证C++编程七级真题解析

本文收录于专栏《C等级认证CCF-GESP真题解析》&#xff0c;专栏总目录&#xff1a;点这里。订阅后可阅读专栏内所有文章。 一、单选题&#xff08;每题 2 分&#xff0c;共 30 分&#xff09; 第 1 题 下列C代码的输出结果是&#xff08; &#xff09;。 #include <iostr…

SwiftUI 6.0(Xcode 16)新 PreviewModifier 协议让预览调试如虎添翼

概览 用 SwiftUI 框架开发过应用的小伙伴们都知道&#xff0c;SwiftUI 中的视图由各种属性和绑定“扑朔迷离”的缠绕在一起&#xff0c;自成体系。 想要在 Xcode 预览中泰然处之的调试 SwiftUI 视图有时并不是件容易的事。其中&#xff0c;最让人秃头码农们头疼的恐怕就要数如…

Spring Cloud Gateway 自定义断言以及过滤器

1.Spring Cloud gateway介绍 Spring Cloud Gateway 是一个基于 Spring Framework 和 Spring Boot 的 API 网关服务&#xff0c;它利用了 Spring WebFlux 来提供响应式非阻塞式Web请求处理能力。它的核心功能是路由&#xff0c;即根据请求的特定规则将请求转发到后端服务&#…