【NLP】daydayup 循环神经网络基本结构,pytorch实现

news2024/9/25 21:23:40

RNN 循环神经网络

循环神经网络(Recurrent Neural Network,RNN)是一种神经网络结构,专门用于处理序列数据。

RNN结构原理

在这里插入图片描述

RNN架构中,网络通过循环把信息从一个处理步骤传递到下一个,这个循环结构被称为隐藏层状态或者隐藏状态。可以捕捉并储存已经出处理的序列元素信息。

这个过程可以简化为
s t = f ( U ⋅ x t + W ⋅ s t − 1 ) s_t=f(\mathbf{U}·x_t+\mathbf{W}·s_{t-1}) st=f(Uxt+Wst1)
U是输入到隐藏的权重矩阵

W是隐藏到隐藏的权重

在这里插入图片描述

输出层 O t {O}_{t} Ot = g(V s t {s}_{t} st)

V是隐藏层到输出层的矩阵

在这里插入图片描述

import numpy as np
import torch
import torch.nn as nn

# 假设输入3个时间步
x = np.random.rand(3,2) 
# 一个样本的输入,如文本中的一句话,一个样本中的3个特征,一句话有3个词,每个特征的维度是2,词向量的维度是2

# 定义rnn参数
input_size = 2
hidden_size = 3
output_size = 4

# 初始化权重和偏置
W_xh = np.random.rand(input_size,hidden_size) # 输入到隐藏
W_hh = np.random.rand(hidden_size,hidden_size) # 隐藏到隐藏
W_hy = np.random.rand(hidden_size,output_size) # 隐藏到输出

bh = np.zeros(hidden_size) # 隐藏层偏置
by = np.zeros(output_size) # 输出层偏置

# 激活函数
def tanh(x):
    return np.tanh(x)

# 初始化隐藏状态
H_prev = np.zeros(hidden_size)


x1 = x[0] # 得到第一个输入特征 文本序列中的第一个词
H1 = tanh(np.dot(x1,W_xh)+H_prev+bh)
print('隐藏1:',H1)
O1 = np.dot(H1,W_hy)+by
print('输出1:',O1)

x2 = x[1]
H2 = tanh(np.dot(x2,W_xh)+np.dot(H1,W_hh)+bh)
print('隐藏2:',H2)
O2 = np.dot(H2,W_hy)+by
print('输出2:',O2)

x3 = x[1]
H3 = tanh(np.dot(x3,W_xh)+np.dot(H2,W_hh)+bh)
print('隐藏3:',H3)
O3 = np.dot(H3,W_hy)+by
print('输出3:',O2)

RNNcell

PyTorch循环神经网络

import torch
import torch.nn as nn

x = torch.randn(10,6,5) # 10批次大小 6词数 5向量维度
# 一次输入10句话,一句话中有6个词(特征),词向量维度是5(特征维度)

class RNN(nn.Module):
    def __init__(self,input_size,hidden_size,batch_first=True):
        # input_size 输入的词向量维度,特征维度
        # hidden_size 隐藏状态的张量维度
        # batch_first 第一维度是否是batch,如果是,需要维度转换,以符合RNNcell的输入
        super().__init__()
        self.rnn_cell = nn.RNNCell(input_size,hidden_size)
        self.hidden_size = hidden_size
        self.batch_first = batch_first

    def __initialize_hidden(self,batch_size):
        # 初始化隐藏状态  第一个时间步没有隐藏的输入,需要初始化
        return torch.zeros((batch_size,self.hidden_size))

    def forward(self,x,init_hidden=None):

        # 得到数据的各个维度
        if self.batch_first:  # 维度转换 以符合cell输入
            bach_size,seq_size,input_size = x.size()

            x = x.permute(1,0,2)
        else:
            seq_size,bach_size,input_size = x.size()

        hiddens = [] # 储存隐藏状态

        if init_hidden is None: # 如果是第一个输入
            init_hidden = self.__initialize_hidden(bach_size)
            init_hidden = init_hidden.to(x.device) # 同步设备

        hidden_t = init_hidden

        for t in range(seq_size):

            hidden_t = self.rnn_cell(x[t],hidden_t)

            hiddens.append(hidden_t)

        hiddens = torch.stack(hiddens) # 堆叠所有时间步隐藏输出,合并为一个张量

        if self.batch_first:
            hiddens = hiddens.permute(1,0,2)

        print(hiddens)

        return hiddens

model = RNN(5,8) # imput_size 词向量的维度 hidden_size 输出的维度  隐藏状态的张量维度
outputs = model(x)
print(outputs.shape) # torch.Size([10, 6, 8])

**这里并没有进行out的输出,只是获得了隐藏状态,在实际的需求中,需要增加其他的结构如线性层对隐藏状态进行操作 **

RNN

基于pytorch实现

import torch
import torch.nn as nn

# 超参数设置

batch_size,seq_size,input_size = 10,6,5 # 批次 句子长度 词向量维度

hidden_size = 3  # 隐藏状态的张量维度

# 数据
x = torch.rand(batch_size,seq_size,input_size)

# 初始化隐藏状态
h_prev = torch.zeros(batch_size,hidden_size)

# 创建RNN

rnn = nn.RNN(input_size, hidden_size,batch_first=True) # batch_first=True是否转化

out, hide= rnn(x,h_prev.unsqueeze(0))  # 返回值 第一个值为输出  第二个值是状态信息

print(out.shape) # torch.Size([10, 6, 3])
print(hide.shape) # torch.Size([1, 10, 3])

biRNN双向RNN

双向RNN,使得模型能够学习到序列中某一点前后的上下文信息

在这里插入图片描述

import torch
import torch.nn as nn

# 超参数设置

batch_size,seq_size,input_size = 10,6,5 # 批次 句子长度 词向量维度

hidden_size = 3  # 隐藏状态的张量维度

# 数据
x = torch.rand(batch_size,seq_size,input_size)

# 初始化隐藏状态
h_prev = torch.zeros(batch_size,hidden_size)

# 创建RNN

rnn = nn.RNN(input_size, hidden_size,batch_first=True,bidirectional=True) # batch_first=True是否转化

out, hide= rnn(x)  # 返回值 第一个值为输出  第二个值是状态信息

print(out.shape) # torch.Size([10, 6, 6])  这里直接合并了双向的隐藏状态
print(hide.shape) # torch.Size([2, 10, 3]) 输出的是正向和反向的隐藏状态

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2164826.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

10个强大的AI驱动售后服务,助力成功

目录 10个AI驱动售后服务 1. 预测性维护以减少停机时间2. 自动化客户支持3. 个性化产品推荐4. 情感分析以改进反馈管理5. AI驱动自助服务平台6. 主动客户沟通7. 智能保修管理8. AI增强服务团队培训9. 高级分析服务优化10. AI驱动忠诚度计划 利用AI提升售后服务体验 在当今竞争…

探索OpenAI的全新里程碑:o1模型

近期,人工智能领域迎来了一项重要突破——OpenAI发布了其最新的语言模型o1。作为一款专为解决复杂问题设计的新一代大语言模型(LLM),o1标志着该公司在智能推理能力方面迈出了重要的一步。尽管这个新系统仍处于初步阶段&#xff0c…

分析二极管的交流响应(1)——直流分析,Q点的计算

二极管的直流电路分析我们可以用理想模型,恒压降模型和折线模型去近似分析,但是这些模型仅限于我们的信号是直流的情况。如果遇到交流信号,我们该如何去分析呢? 首先我们来理解Q点的概念: 看这个Q点里的“Q”是个什么…

【C++】C++中如何处理多返回值

十四、C中如何处理多返回值 本部分也是碎碎念,因为这些点都是很小的点,构不成一篇文章,所以本篇就是想到哪个点就写哪个点。 1、C中如何处理多个返回值 写过python的同学都知道,当你写一个函数的返回时,那是你想返回…

【Javascript】原生实现deep watch,使用proxy逐层建立数据监听

原理 使用 proxy对象处理数据,添加监听,然后递归再次添加直到全部添加完毕 代码 /*** 给对象递归建立数据监听,可以监测每一层的每个键的变化* * param {*} obj // 目标对象 * param {*} callback //回调。通过key处理对应的变化* param {…

机器学习EDA探查工具Pandas profiling

在最初的数据探查的时候,可以通过pandas的函数,以及matplotlib做图像绘图,这个工作比较重复和低效,所以pandas针对常用的数据列统计和展示,做了EDA工具profiling,可以自动帮助数据分析。 问题1&#xff1a…

java核心基础

文章目录 1. Java开发基础1.1 DOS常用命令:(以MAC常用命令比较)1.2 JVM、JRE、JDK之间的关系1.3 Java开发环境的搭建1.4 Java的注释,标识符、标识符的命名规范1.5 变量和常量的定义及初始化1.6 Java的运算符1.7 三大语句1.8 常用的类1.8.1 ja…

计算机前沿技术-人工智能算法-大语言模型-最新论文阅读-2024-09-21

计算机前沿技术-人工智能算法-大语言模型-最新论文阅读-2024-09-21 1. AIvril: AI-Driven RTL Generation With Verification In-The-Loop Authors: Mubashir ul Islam, Humza Sami, Pierre-Emmanuel Gaillardon, and Valerio Tenace AIVRIL: 人工智能驱动的RTL生成与验证内…

OpenAPI鉴权(二)jwt鉴权

一、思路 前端调用后端可以使用jwt鉴权;调用三方接口也可以使用jwt鉴权。对接多个三方则与每个third parth都约定一套token规则,因为如果使用同一套token,token串用可能造成权限越界问题,且payload交叉业务不够清晰。下面的demo包…

springBoot --> 学习笔记

文章目录 认识 SpringBoot第一个 SpringBoot 程序彩蛋 banner (emmmmm,哈哈哈哈哈哈,牛逼!)SpringBoot 配置配置文件第一个 yaml 配置 成功案例yaml 存在 松散绑定 JSR 303 数据校验多环境配置以及文件位置访问静态资源…

教你制作一个二维码就能查分的系统

学生和家长对于成绩查询的需求日益增长。为了满足这一需求,很多学校和老师开始使用二维码查询系统,以提高效率和保护隐私。以下内容就是如何制作一个简单易用的成绩查询二维码系统的步骤: 1. 准备电子表格 老师需要准备一个包含学生成绩的电…

(已解决)vscode如何传入argparse参数来调试/运行python程序

文章目录 前言调试传入参数运行传入参数延申 前言 以前,我都是用Pycharm专业版的,由于其好像在外网的时候,不能够通过VPN来连接内网服务器,我就改用了vscode。改用了之后,遇到一个问题,调试或者运行python…

基于Qt5.12.2开发 MQTT客户端调试助手

项目介绍 该项目是一个基于 Qt 框架开发的桌面应用程序,主要用于与 MQTT 服务器进行连接和通信。通过该应用,用户可以连接到 MQTT 服务器,订阅主题、发布消息并处理接收到的消息。项目使用 QMqttClient 类来实现 MQTT 协议的客户端功能&…

第128集《大佛顶首楞严经》

《大佛顶如来密因修正了义诸菩萨万行首楞严经》。监院法师慈悲,诸位法师,诸位同学,阿弥陀佛! 请大家打开讲义296面。 庚一、总示阴相(分四:辛一、结前行阴尽相。辛二、正明识阴区宇。辛三、悬示识阴尽相。…

通过frp 免费内网穿透,端口转发

1.准备工作 (1)拥有一台有公网IP的服务器(系统可以是windows/macos/linux),服务器可以使用云厂商购买的服务器 (2)从下面链接下载最新版本的frp安装包,客户端和服务端是同一个tar包 https://github.com/fatedier/frp/releases 服务端机器A-有外网ip的作为服务端 服务端机器B-需…

前端接口415状态码【解决】

前端接口415状态码【解决】 一、概述 415状态码是HTTP协议中的一个标准响应状态码,代表“Unsupported Media Type”(不支持的媒体类型)。当客户端尝试上传或发送一个服务器无法处理的媒体类型时,服务器会返回这个状态码。这通常意…

二维四边形网格生成算法:paving(五)缝合 Seaming 与 闭合检测 Closure Check

欢迎关注更多精彩 关注我,学习常用算法与数据结构,一题多解,降维打击。 参考论文:Paving: A new approach to automated quadrilateral mesh generation 关注公众号回复paving可以获得文章链接 paving(一&#xff0…

python如何将字符转换为数字

python中的字符数字之间的转换函数 int(x [,base ]) 将x转换为一个整数 long(x [,base ]) 将x转换为一个长整数 float(x ) 将x转换到一个浮点数 complex(real [,imag ]) 创建一个复数 str(x ) 将对象 x 转换为字…

Pytest测试实战|执行常用命令

Pytest测试实战 本文章主要详细地阐述下Pytest测试框架执行TestCase常用命令。 按分类执行 在Pytest测试框架中按照分类执行的命令为“-k”,它的主要特点是按照TestCase名字的模式来执行,在编写具体的TestCase的时候,都会编写每个TestCase…

el-table表格点击该行任意位置时也勾选上其前面的复选框

需求&#xff1a;当双击表格某一行任意位置时&#xff0c;自动勾选上其前面的复选框 1、在el-table 组件的每一行添加row-dblclick事件&#xff0c;用于双击点击 <el-table:data"tableData"ref"tableRef"selection-change"handleSelectionChange&q…