经典网络 循环神经网络(一) | RNN结构解析,代码实现

news2024/12/29 10:50:45

文章目录

  • 1 提出背景
  • 2 RNN
    • 2.1 RNN结构
    • 2.2 RNN代码实现
    • 2.3 代码简洁实现

1 提出背景

为什么要引入RNN呢?

非常简单,之前我们的卷积神经网络CNN,全连接神经网络等都是单个神经元计算

但在序列模型中,前一个神经元往往对后面一个神经元有影响

比如

两句话

I like eating apples.

I want to have a apple watch

第一个苹果和第二个苹果的概念是不一样的,第一个苹果是红彤彤的苹果,第二个苹果是苹果公司的意思

如何知道

是因为apple的翻译参考了上下文,第一句话看到了eating这个单词,第二句话看到了watch这个单词

因而可见,对于语言这种时序信息,利用需要参考上下文进行

还有其他原因

  • 拿人类的某句话来说,也就是人类的自然语言,是不是符合某个逻辑或规则的字词拼凑排列起来的,这就是符合序列特性。
  • 语音,我们发出的声音,每一帧每一帧的衔接起来,才凑成了我们听到的话,这也具有序列特性、
  • 股票,随着时间的推移,会产生具有顺序的一系列数字,这些数字也是具有序列特性。

2 RNN

具有时序功能,从某种意义来说,RNN也就具有了记忆功能,好比我们人类自己,为什么会受到过去影响,因为我们具有记忆能力。

同时只有记忆能力是不够的,处理后的信息得储存起来,形成“新的记忆”

对于RNN,可以分为单向RNN,和双向RNN,其中单向的是只利用前面的信息,而双向的RNN既可以利用前面的信息,也可以利用后面的信息。

2.1 RNN结构

RNN的基本单元包含以下关键组件:

  • 输入 ( x t x_t xt ): 表示在时间步 (t) 的输入序列。
  • 隐藏状态 ( h t h_t ht ): 在时间步 (t) 的隐藏状态,是网络在处理序列过程中保留的信息相当于ht里面藏着上下文信息
  • 每一步的输出(Oi):每一个时间步有一个输出Oi,Oi综合了当前时间步和之前的很多信息,那么对于某些特定任务,如分类什么的,就可以直接用Oi去做判断。很多时候直接把隐藏状态拿去做了输出

如下图,图片来自《动手学深度学习》

在这里插入图片描述

那么每一个隐状态是通过怎样的方式得到的呢?

RNN的隐藏状态 (ht ) 的计算通过以下数学公式完成:

$h_t=tanh(W_{ih}x_t+b_{ih}+W_{hh}h_{t−1}+b_{hh}) $

这个公式展示了RNN如何根据当前输入 (xt ) 和前一个时间步的隐藏状态 (ht−1 ) 来计算当前时间步的隐藏状态 (ht )。其中 (tanh) 是双曲正切激活函数,用于引入非线性。

实际中我们可以看到

  • 权重矩阵 ($W_{ih} , W_{hh} $): 分别是输入到隐藏状态和隐藏状态到隐藏状态的权重矩阵。
  • 偏差 ($b_{ih} , b_{hh} $): 对应的偏差。

在这里插入图片描述

第一个问题 是每一个句子的长度不一致,你怎么用统一的矩阵呢?

​ 只实现了一个单层神经元,可以通过获得句子长度知道时间步数t,进一步做相关的调整

2.2 RNN代码实现

代码实现首先实现上图的一个神经元

def rnn(inputs, state, params):
    # inputs的形状:(时间步数量,批量大小,词表大小)
    W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state
    outputs = []
    # X的形状:(批量大小,词表大小)
    for X in inputs:
        H = torch.tanh(torch.mm(X, W_xh) + torch.mm(H, W_hh) + b_h)
        Y = torch.mm(H, W_hq) + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H,)

class RNNModelScratch: #@save
    """从零开始实现的循环神经网络模型"""
    def __init__(self, vocab_size, num_hiddens, device,
                 get_params, init_state, forward_fn):
        self.vocab_size, self.num_hiddens = vocab_size, num_hiddens
        self.params = get_params(vocab_size, num_hiddens, device)
        self.init_state, self.forward_fn = init_state, forward_fn

    def __call__(self, X, state):
        X = F.one_hot(X.T, self.vocab_size).type(torch.float32)
        return self.forward_fn(X, state, self.params)

    def begin_state(self, batch_size, device):
        return self.init_state(batch_size, self.num_hiddens, device)

然后利用循环,根据语句长度做预测判断,损失函数计算优化

def predict_ch8(prefix, num_preds, net, vocab, device):  #@save
    """在prefix后面生成新字符"""
    state = net.begin_state(batch_size=1, device=device)
    outputs = [vocab[prefix[0]]]
    get_input = lambda: torch.tensor([outputs[-1]], device=device).reshape((1, 1))
    for y in prefix[1:]:  # 预热期
        _, state = net(get_input(), state)
        outputs.append(vocab[y])
    for _ in range(num_preds):  # 预测num_preds步
        y, state = net(get_input(), state)
        outputs.append(int(y.argmax(dim=1).reshape(1)))
    return ''.join([vocab.idx_to_token[i] for i in outputs])

2.3 代码简洁实现

往往通过一个nn.RNN来实现

nn.RNN(input_size, hidden_size, num_layers=1, nonlinearity=tanh, bias=True, batch_first=False, dropout=0, bidirectional=False)

参数说明

input_size输入特征的维度, 一般rnn中输入的是词向量,那么 input_size 就等于一个词向量的维度
hidden_size隐藏层神经元个数,或者也叫输出的维度(因为rnn输出为各个时间步上的隐藏状态)
num_layers网络的层数,一般可以默认为1
nonlinearity激活函数
bias是否使用偏置
batch_first输入数据的形式,默认是 False,就是这样形式,(seq(num_step), batch, input_dim),也就是将序列长度放在第一位,batch 放在第二位
dropout是否应用dropout, 默认不使用,如若使用将其设置成一个0-1的数字即可
birdirectional是否使用双向的 rnn,默认是 False
注意某些参数的默认值在标题中已注明

rnn_layer = nn.RNN(input_size=vocab_size, hidden_size=num_hiddens, )

定义模型, 其中vocab_size = 1027, hidden_size = 256

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1388792.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Fpga开发笔记(二):高云FPGA发开发软件Gowin和高云fpga基本开发过程

若该文为原创文章,转载请注明原文出处 本文章博客地址:https://hpzwl.blog.csdn.net/article/details/135620590 红胖子网络科技博文大全:开发技术集合(包含Qt实用技术、树莓派、三维、OpenCV、OpenGL、ffmpeg、OSG、单片机、软硬…

36V/1.6A两通道H桥驱动芯片-SS8812T可替代DRV8812

由工采网代理的SS8812T是一款双通道H桥电流控制电机驱动器;每个 H 桥可提供输出电流 1.6A,可驱动两个刷式直流电机,或者一个双极步进电机,或者螺线管或者其它感性负载;双极步进电机可以以整步、2 细分、4 细分运行&…

旧路由重置新路由设置新路由设置教程|适用于自动获取IP模式

前言 如果你的光猫是直接拨号(路由模式)的,就可以按照本教程进行路由重置或者更换新路由器。 本文章适合电脑小白,请注意每一步哦! 注意事项 开始之前需要确认光猫是桥接模式还是路由模式。如果光猫是路由模式&…

❤ HbuildX使用以及快捷键

❤ HbuildX使用以及快捷键 一、HbuildX使用 HbuildX左侧项目侧边栏 点击视图 > 显示项目左侧即可 二、HBuilder X 快捷键 左移 Shift tab 右移 tab 查找 全局文件搜索:CtrlP 本文档内查找字符串:ctrlf 目录内查找字符串:ctrlaltf 替换:ctrlh 查找下一个字符串:f3 查…

表的增删改查CURD(一)

🎥 个人主页:Dikz12🔥个人专栏:MySql📕格言:那些在暗处执拗生长的花,终有一日会馥郁传香欢迎大家👍点赞✍评论⭐收藏 目录 新增(Create) 全列插入 指定列…

SDK游戏盾是什么?,sdk游戏盾有什么作用

在现今的游戏市场,游戏保护成为了每个游戏开发者都不能忽视的重要环节。恶意破解、作弊和盗版等问题严重影响了游戏的安全性和商业价值。而如何保护自己的游戏免受这些威胁,已经成为游戏开发者们面临的重大挑战。好在SDK游戏盾,它如同保护游戏…

家用小型洗衣机哪款性价比高?好用的内衣洗衣机推荐

现在大多数的上班族,面临的都是早九晚六的工作,而且工作完下班回家还是面对各种各样的家务,特别是清洗需要换洗的洗衣,属实是有点辛苦了。可能很多人为了方便,每次洗衣服的都是把一堆衣服直接丢进洗衣机,直…

Ansible Filter滤波器的使用(一)

一、【说在前面】 Ansible Filter一般被称为滤波器或者叫过滤器。 这个东西初次听到以为是什么科学计算的东西,但是想来ansible不太可能有什么滤波操作,所以这个东西本质是一个数值筛选器,内置函数,本质是一个为了做区别化的工具…

光学 | 联合Ansys Zemax及Lumerical应对AR/VR市场挑战

当前的增强现实和虚拟现实(AR/VR)市场涵盖了广泛的应用趋势,设计人员和各企业在努力寻找非传统解决方案,以满足主流消费者不断变化的需求。 对于AR头戴设备等可穿戴解决方案,设计思路通常源于对小巧轻量化系统的需求&a…

大数据传输慢的真正原因与解决方案

随着企业数据不断增长,大数据传输已成为一项至关重要的任务。然而,许多企业在处理大数据传输时频繁遭遇传输速度慢的问题。本文将深入探讨大数据传输速度慢的根本原因,并提供一些切实有效的解决方案。 大数据传输在企业中的重要性不言而喻&am…

算法竞赛备赛进阶之数位DP训练

数位DP的思想就是对每一位进行DP,计算时记忆化每一位可以有的状态,其作用是减少运算时间,避免重复计算。 数位DP是一种计数用的DP,一般就是要统计一个区间[A,B]内满足一些条件数的个数。 以1e9甚至1e18、1e100的问题为例&#x…

Docker 容器之间的互相通信

Docker容器之间的互相通信 步骤一:创建自定义网络 首先,我们需要创建一个自定义网络,以便容器可以连接到这个网络上,从而实现互相通信。在命令行中执行以下命令: # 创建 docker network create ddz # 查看 docker n…

洛谷 P1523 旅行商简化版【线性dp+npc问题简化版】

原题链接:https://www.luogu.com.cn/problem/P1523 题目背景 欧几里德旅行商(Euclidean Traveling Salesman)问题也就是货郎担问题一直是困扰全世界数学家、计算机学家的著名问题。现有的算法都没有办法在确定型机器上在多项式时间内求出最优解,但是有…

2024 年企业要增强反脆弱性,IT 能够做什么?

新冠疫情被称为黑天鹅事件,而“黑天鹅”这个词的创造者纳西姆尼古拉斯塔勒布在另一本书《反脆弱:从不确定性中获益》( CSDN博主读书笔记《反脆弱:从不确定性中获益》 )中,则给出了面对随时可能出现的黑天鹅…

玖章算术NineData通过阿里云PolarDB产品生态集成认证

近日,玖章算术旗下NineData 云原生智能数据管理平台 (V1.0)正式通过了阿里云PolarDB PostgreSQL版 (V11)产品集成认证测试,并获得阿里云颁发的产品生态集成认证。 测试结果表明,玖章算术旗下NineData数据管理平台 (V1.0&#xff…

Python源码23:海龟画图turtle画小狗狗

---------------turtle源码集合--------------- Python教程43:海龟画图turtle画小樱魔法阵 Python教程42:海龟画图turtle画海绵宝宝 Python教程41:海龟画图turtle画蜡笔小新 Python教程40:使用turtle画一只杰瑞 Python教程39…

❤ Uniapp使用一(文档和 API 篇)

❤ Uniapp使用一(文档和 API 篇) 一、介绍 uni-app官网:https://uniapp.dcloud.io/api/media/image?idpreviewimage 微信小程序官网:https://developers.weixin.qq.com/miniprogram/dev/api/media/image/wx.previewImage.html …

使用WAF防御网络上的隐蔽威胁之SQL注入攻击

SQL注入攻击是一种普遍存在且危害巨大的网络安全威胁,它允许攻击者通过执行恶意的SQL语句来操纵或破坏数据库。 这种攻击不仅能够读取敏感数据,还可能用于添加、修改或删除数据库中的记录。因此,了解SQL注入攻击的机制及其防御策略对于保护网…

Spring基于AOP(面向切面编程)开发

概述 AOP为Aspect Oriented Programming的缩写,意为:面向切面编程,通过预编译方式和运行期间动态代理实现程序功能的统一维护的一种技术。AOP是OOP的延续,是软件开发中的一个热点,也是Spring框架中的一个重要内容&…

使用WAF防御之网络上的隐蔽威胁(XSS攻击)

跨站脚本攻击(XSS)是一种常见且危险的威胁。它允许攻击者在用户浏览器上执行恶意脚本,窃取信息、篡改网页内容,甚至劫持用户会话。 什么是XSS攻击 定义:XSS攻击是一种代码注入技术,攻击者通过在目标网站上…