[自然语言处理]RNN

news2024/11/17 14:52:37

1 传统RNN模型与LSTM

import torch
import torch.nn as nn

torch.manual_seed(8)


def dm01():
    '''
    参数1:输入向量的维数
    参数2:隐藏层神经元的个数
    参数3:隐藏层的层数
    :return:
    '''
    rnn = nn.RNN(5, 6, 1)
    '''
    参数1:句子长度sequence_length
    参数2:一个批次的样本数量batch_size
    参数3:每个单词的向量维数vector_dim
    '''
    input = torch.randn(1, 3, 5)
    '''
    参数1:隐藏层的层数
    参数2:一个批次的样本数量batch_size
    参数3:隐层层神经元个数 
    '''
    h0 = torch.randn(1, 3, 6)
    output, hn = rnn(input, h0)
    print(f'output-->{output.shape} {output}')
    print(f'hn-->{hn.shape} {hn}')
    print(f'rnn模型-->{rnn}')


def dm02():
    rnn = nn.RNN(5, 6, 1)
    input = torch.randn(4, 3, 5)
    h0 = torch.randn(1, 3, 6)
    output, hn = rnn(input, h0)
    print(f'output-->{output.shape} {output}')
    print(f'hn-->{hn.shape} {hn}')
    print(f'rnn模型-->{rnn}')


def dm03():
    rnn = nn.RNN(5, 6, 1)
    input = torch.randn(4, 1, 5)
    print(f'input {input}')
    hidden = torch.zeros(1, 1, 6)
    # 一个一个地送字符
    for i in range(4):
        tmp = input[i][0]
        print(f'tmp.shape {tmp.shape}')
        output, hidden = rnn(tmp.unsqueeze(0).unsqueeze(0), hidden)
        print(f'{i} {output}')
        print(f'{i} {hidden}')
        print('*' * 80)

    hidden = torch.zeros(1, 1, 6)
    output, hn = rnn(input, hidden)
    print(f'output2 {output} {output.shape}')
    print(f'hn {hn} {hn.shape}')


# 改变隐藏层数
def dm04():
    rnn = nn.RNN(5, 6, 2)
    input = torch.randn(4, 3, 5)
    h0 = torch.randn(2, 3, 6)
    output, hn = rnn(input, h0)
    print(f'output-->{output.shape} {output}')
    print(f'hn-->{hn.shape} {hn}')
    print(f'rnn模型-->{rnn}')


# 改变batch_size参数
def dm05():
    rnn = nn.RNN(5, 6, 1, batch_first=True)
    input = torch.randn(3, 4, 5)
    h0 = torch.randn(1, 3, 6)
    output, hn = rnn(input, h0)
    print(f'output-->{output.shape} {output}')
    print(f'hn-->{hn.shape} {hn}')
    print(f'rnn模型-->{rnn}')


# LSTM
def dm06():
    rnn = nn.LSTM(5, 6, 2)
    input = torch.randn(1, 3, 5)
    h0 = torch.randn(2, 3, 6)
    c0 = torch.randn(2, 3, 6)
    output, (hn, cn) = rnn(input, (h0, c0))
    print(f'output {output}')
    print(f'hn {hn}')
    print(f'cn {cn}')


if __name__ == '__main__':
    # dm01()
    # dm02()
    # dm03()
    # dm04()
    # dm05()
    dm06()
D:\nlplearning\nlpbase\python.exe D:\nlpcoding\rnncode.py 
output tensor([[[ 0.0207, -0.1121, -0.0706,  0.1167, -0.3322, -0.0686],
         [ 0.1256,  0.1328,  0.2361,  0.2237, -0.0203, -0.2709],
         [-0.2668, -0.2721, -0.2168,  0.4734,  0.2420,  0.0349]]],
       grad_fn=<MkldnnRnnLayerBackward0>)
hn tensor([[[ 0.1501, -0.2106,  0.0213,  0.1309,  0.3074, -0.2038],
         [ 0.3639, -0.0394, -0.1912,  0.1282,  0.0369, -0.1094],
         [ 0.1217, -0.0517,  0.1884, -0.1100, -0.5018, -0.4512]],

        [[ 0.0207, -0.1121, -0.0706,  0.1167, -0.3322, -0.0686],
         [ 0.1256,  0.1328,  0.2361,  0.2237, -0.0203, -0.2709],
         [-0.2668, -0.2721, -0.2168,  0.4734,  0.2420,  0.0349]]],
       grad_fn=<StackBackward0>)
cn tensor([[[ 0.2791, -0.7362,  0.0501,  0.2612,  0.4655, -0.2338],
         [ 0.7902, -0.0920, -0.4955,  0.3865,  0.0868, -0.1612],
         [ 0.2312, -0.3736,  0.4033, -0.1386, -1.0151, -0.5971]],

        [[ 0.0441, -0.2279, -0.1483,  0.3397, -0.5597, -0.4339],
         [ 0.2154,  0.4119,  0.4723,  0.4731, -0.0284, -1.1095],
         [-0.5016, -0.5146, -0.4286,  1.5299,  0.5992,  0.1224]]],
       grad_fn=<StackBackward0>)

Process finished with exit code 0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2200776.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

九芯电子NVH/NVF语音芯片OTA升级操作方法

OTA&#xff08;Over-The-Air&#xff09;升级是指通过无线网络远程对设备进行软件升级的过程。对于九芯电子NVH/NVF语音芯片&#xff0c;OTA升级可以通过WiFi模组实现&#xff0c;支持MQTT、HTTP等协议&#xff0c;方便快捷‌。 具体操作步骤如下&#xff1a; 1.进入九芯“智…

计算机毕业设计 基于Django的学生选课系统的设计与实现 Python+Django+Vue 前后端分离 附源码 讲解 文档

&#x1f34a;作者&#xff1a;计算机编程-吉哥 &#x1f34a;简介&#xff1a;专业从事JavaWeb程序开发&#xff0c;微信小程序开发&#xff0c;定制化项目、 源码、代码讲解、文档撰写、ppt制作。做自己喜欢的事&#xff0c;生活就是快乐的。 &#x1f34a;心愿&#xff1a;点…

处理Java内存溢出问题(java.lang.OutOfMemoryError):增加JVM堆内存与调优

处理Java内存溢出问题&#xff08;java.lang.OutOfMemoryError&#xff09;&#xff1a;增加JVM堆内存与调优 在进行压力测试时&#xff0c;遇到java.lang.OutOfMemoryError: Java heap space错误或者nginx报错no live upstreams while connecting to upstream通常意味着应用的…

重头开始嵌入式第四十七天(硬件 ARM裸机开发 RS232 RS4885 IIC)

目录 一.什么是RS232&#xff1f; 1. 历史背景&#xff1a; 2. 电气特性&#xff1a; 3. 连接器类型&#xff1a; 4. 通信特点&#xff1a; 5. 应用场景&#xff1a; 二.什么是RS485&#xff1f; 1. 电气特性&#xff1a; 2. 通信模式&#xff1a; 3. 传输距离与速率&…

技术路线图用什么画?用这个在线工具轻松完成绘制!

在当今快速发展的技术世界中&#xff0c;技术路线图已成为企业和团队不可或缺的战略规划工具。它不仅能够清晰地展示技术发展方向&#xff0c;还能帮助团队成员、利益相关者和投资者更好地理解和参与技术战略的制定过程。但不可否认的是&#xff0c;创建一个有效的技术路线图并…

如何免费为域名申请一个企业邮箱

背景 做SEO的是有老是会有一些网站来做验证你的所有权&#xff0c;这个时候&#xff0c;如果你域名对应的企业邮箱就会很方便。zoho为了引导付费&#xff0c;有很多多余的步骤引导&#xff0c;反倒是让不付费的用户有些迷茫&#xff0c;所以会写这个教程&#xff0c;按照教程走…

虚幻引擎GAS入门学习笔记(二)

虚幻引擎GAS入门(二) 学习位置UE5.3 GAS入门教程重置版 小明 MVC框架与技能初始化 让一开始创建的蓝图的基础GameplayAbility蓝图继承我们写好的BaseGameplayAbility类 创建一个库函数&#xff0c;写一些常用的函数在里面第一个得到玩家与玩家控制器 获取角色面对目标的方向…

c++11~c++20 thread_local

线程局部存储是指对象内存在线程开始后分配&#xff0c;线程结束时回收且每个线程有该对象自己的实例&#xff0c;简单地说&#xff0c;线程局部存储的对象都是独立各个线程的。实际上这并不是一个新鲜个概念&#xff0c;虽然C一直没因在语言层面支持它&#xff0c;但是很早之前…

Coggle数据科学 | 全球AI攻防挑战赛:金融场景凭证篡改检测 baseline

本文来源公众号“Coggle数据科学”&#xff0c;仅用于学术分享&#xff0c;侵权删&#xff0c;干货满满。 原文链接&#xff1a;全球AI攻防挑战赛&#xff1a;金融场景凭证篡改检测 baseline 赛题名称&#xff1a;全球AI攻防挑战赛—赛道二&#xff08;AI核身-金融场景凭证篡改…

集智书童 | FMRFT 融合Mamba和 DETR 用于查询时间序列交叉鱼跟踪 !

本文来源公众号“集智书童”&#xff0c;仅用于学术分享&#xff0c;侵权删&#xff0c;干货满满。 原文链接&#xff1a;FMRFT 融合Mamba和 DETR 用于查询时间序列交叉鱼跟踪 ! 鱼的生长、异常行为和疾病可以通过图像处理方法进行早期检测&#xff0c;这对工厂水产养殖具有重…

基于云效流水线Flow | 高效构建企业门户网站

基于云效流水线Flow | 高效构建企业门户网站 基于云效流水线Flow | 高效构建企业门户网站企业门户网站方案架构一键部署方案概览部署准备一键部署 部署服务端&#xff08;云效流水线&#xff09;添加流水线源Java构建上传主机部署 资源删除操作体验1&#xff09; 在体验过程中是…

Redis 5 种基本数据类型的前两个详解

Redis 共有 5 种基本数据类型&#xff1a;String&#xff08;字符串&#xff09;、List&#xff08;列表&#xff09;、Set&#xff08;集合&#xff09;、Hash&#xff08;散列&#xff09;、Zset&#xff08;有序集合&#xff09;。 这 5 种数据类型是直接提供给用户使用的&…

qos在企业网中的设计与实现

1.拓扑 地址规划 业务地址规划 部门 地址空间 vlan 网关 市场部门 10.0.100.0/24 Vlan100 10.0.100.254/24 研发部门 10.0.101.0/24 Vlan101 10.0.101.254/24 财务部门 10.0.102.0/24 Vlan102 10.0.102.254/24 人力部门 10.0.103.0/24 Vlan103 10.0.103.25…

[nmap] 端口扫描工具的下载及详细安装使用过程(附有下载文件)

nmap网络连接端扫描软件&#xff0c;用于主机发现、端口扫描、版本侦测、操作系统侦测 下载链接在文末 下载压缩包后解压 &#xff01;&#xff01;安装路径不要有中文 解压得到文件 双击.exe文件 更改安装路径并点击安装 等待安装 安装完成 nmap-7.95-setup.zip 夸克网盘打开…

pip install kaggle-environments ISSUE:Failed to build vec-noise

ISSUE: error: Microsoft Visual C 14.0 or greater is required. Get it with “Microsoft C Build Tools”: https://visualstudio.microsoft.com/visual-cpp-build-tools/ [end of output]Failed to build vec-noiseC:\ProgramData\miniconda3\include\pyconfig.h(59): fat…

基于Springboot+Vue的家校互联系统(含源码数据库)

1.开发环境 开发系统:Windows10/11 架构模式:MVC/前后端分离 JDK版本: Java JDK1.8 开发工具:IDEA 数据库版本: mysql5.7或8.0 数据库可视化工具: navicat 服务器: SpringBoot自带 apache tomcat 主要技术: Java,Springboot,mybatis,mysql,vue 2.视频演示地址 3.功能 系统中…

信息安全工程师(41)VPN概述

前言 VPN&#xff0c;即Virtual Private Network&#xff08;虚拟专用网络&#xff09;的缩写&#xff0c;是一种通过公共网络&#xff08;如互联网&#xff09;创建私密连接的技术。 一、定义与工作原理 定义&#xff1a;VPN是依靠ISP&#xff08;Internet Service Provider&…

国庆档不太热,影视股“凉”了?

今年国庆档票房止步21亿元&#xff0c;属实有点差强人意。 根据国家电影局统计&#xff0c;2024年国庆档&#xff08;2024年10月1日至7日&#xff09;全国电影票房为21.04亿元&#xff0c;观影人次为5209万&#xff0c;总票房成绩、观影总人次同比均有所下滑。 作为传统观影高…

AS-REP Roasting 实验

1. 实验网络拓扑 kali: 192.168.72.128win2008: 192.168.135.129 192.168.72.139win7: 192.168.72.149win2012:(DC) 192.168.72.131 2. 攻击原理 如果设置了不需要Kerberos预认证&#xff1a; 那么就可以直接发AS_REQ请求TGT票据&#xff0c;由于不要求预身份认证&#xff0…

FLORR.IO画廊(3)

锯齿&#xff08;超级&#xff09; 是florr.io的一种辅助型花瓣&#xff0c;可以用于提升碰撞伤害。玩家装备后&#xff0c;外观会显示出一圈转动的齿轮&#xff0c;就像digdig.io中的玩家一样。不堆叠 圆盘&#xff08;超级&#xff09; 是Florr.io的一种削伤型花瓣&#xff…