【深度学习】RNN学习笔记

news2024/11/28 17:49:48

RNN学习笔记

时间序列

 将单词序列转换为向量,这里有五个单词,然后对于每一个单词都进行独热编码,编码成一个特定的向量。

在这里插入图片描述

对于RNN网络,需要一次性读取多个句子,那么涉及到batch_size,这里第二个表达就是:batch,单词,单词的表达方式

在这里插入图片描述

RNN原理

在这里插入图片描述

 这里生成一个5 x 100的向量,对于每一个单词我们都使用100个特征进行表示,然后经过每一个线性层,这100维的向量转换为更小的向量 比如五个向量特征,那么对于一个句子的每一个单词都进行这样的提取特征,那么就需要对每一个单词都进行线性层的变换,最后得到一个5x5的向量

但是上面的做法有两个缺点:

  • 单词数量太多 句子太长
  • 参数两太多 对于每一个单词都需要生成特定的w b的线性层 这样参数过多
  • 只考虑每一个单词 没有考虑语境信息

那么,对于每一个线性层,我们不仅输入每一个单词的向量,还需要输入语境信息h

在这里插入图片描述

现在举一个例子,对于每一个单词我们使用 100个特征进行表示,对于一个句子五个单词就是5 x 100的向量表示,那么我们选择batch_size是3 一次输入三个单词,那么对于每次的输入xt,就是[3,100]的向量,我们初始化语境信息向量h[0,0…],输入网络

在这里插入图片描述

在这里插入图片描述

针对上面的讲述,整理成公式化,对于每一个输出ht 都是由ht-1 和xt得到的,也就是当前层的输入xt和上一层计算的输出作为当前层的输入,然后,使用Whh和WXh分别对输入的xt和ht-1进行特征提取,得到的输出向量经过激活函数进行激活

在这里插入图片描述

如何推导出梯度?

为RNN网络的权重whh和Wxh都是共享的,那么最后的输出一定是关系到每一个时刻的权重,所以我们需要使用损失对每一个时刻的权重进行求导,然后进行累加

在这里插入图片描述

在这里插入图片描述

对于ht求导hi

在这里插入图片描述

RNN Layer的使用

 输入的向量xt是[3,100],那么 Wxh是[100,20]的形状的参数向量,也就是说相乘之后变成了[3,20] 原先单词使用100个特征进行表示,现在使用20个特征进行表示。同样的对于语境信息向量ht[3,20] 那么Whh是[20,20]的向量,想成之后变成了[3,20],然后两个[3,20]的向量相加得到ht+1向量

在这里插入图片描述

rnn=nn.RNN(100,10)  #每个单词用100维表示,memory/hidden为10维
rnn._parameters.keys()
#out:odict_keys(['weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0'])
 
rnn.weight_hh_l0.shape,rnn.weight_ih_l0.shape
#out:(torch.Size([10, 10]), torch.Size([10, 100]))
 
rnn.bias_hh_l0.shape,rnn.bias_ih_l0.shape
out:(torch.Size([10]), torch.Size([10]))

在这里插入图片描述

这里创建了一个RNN模型,其中参数100表示输入数据的特征维度,每个单词用100维的向量表示;参数10表示RNN内部隐层的大小,也称为记忆维度,即隐状态的维度。

nn.RNN模块有四个可训练参数,它们是:

weight_ih_l0: 输入到隐层的权重矩阵,形状为(10, 100),表示从输入层到隐层的权重参数。
weight_hh_l0: 隐层到隐层的权重矩阵,形状为(10, 10),表示从上一时刻隐层到当前时刻隐层的权重参数。
bias_ih_l0: 输入到隐层的偏置向量,形状为(10),表示从输入层到隐层的偏置参数。
bias_hh_l0: 隐层到隐层的偏置向量,形状为(10),表示从上一时刻隐层到当前时刻隐层的偏置参数。

输入到隐层的权重矩阵(weight_ih_l0)的形状为 (hidden_size, input_size),而不是 (input_size, hidden_size)。这里 (10, 100) 表示有 10 个隐状态单元(记忆单元)和每个单词用 100 维的向量表示作为输入特征。

单层的RNN

因为是单层的RNN,那么h向量只有一个

# RNN创建
rnn = nn.RNN(input_size = 100,hidden_size = 20,num_layers=1)
print(rnn)

#  10代表每一个句子的单词数量  3代表三个句子  100 代表每一个单词的维度
x = torch.randn(10,3,100)

#  第一个参数是输入向量x 第二个参数是h向量  只有一个单词 每一层三个句子 每一个单词使用20个维度
out,h = rnn(x,torch.zeros(1,3,20))

# 最后的输出out 是一个 10 x 3 x 20 的向量 也就是被提取成20个维度  然后h是一个1 3 20的向量
print(out.shape,h.shape)

多层的RNN

因为有多层RNN,那么就有多个h向量,但是out输出形状不会变

在这里插入图片描述

rnn = nn.RNN(input_size = 100,hidden_size = 20,num_layers = 4)
print(rnn)

x = torch.randn(10,3,100)

# 这里面的h 是多层的 应该是[4,3,20]
out,h = rnn(x)

print(out.shape,h.shape)

单层RNNCell

 处理三个句子,每一个句子10个单词,每一个单词使用长度100的向量,送入RNN的shape就是[10,3,100]

 如果使用RNNCell 针对每一个时刻都分开处理,这里十个单词就是十个时刻,每次输入的向量是[3,100] 那么计算单元运行十次,

显然,RNNCelll没办法想RNN直接求出网络的输出,那么只需要将最后一层每一个时刻得输出h组合起来。out = torch.stack([h1,h2,…,ht])


#  100代表输入得特征维度  20代表提取的特征维度
cell = nn.RNNCell(100,20)

# 初始化输入  某一个时刻得输入
x = torch.randn(3,100)

# 初始化所有时刻得输入
xs = [torch.randn(3,100) for i in range(10)]

# 初始化隐藏层记忆单元
h = torch.zeros(3,20)


#  针对每一个时刻得输入 传入RNN单元
for xt in xs:
    h = cell(xt,h)

# 查看最后的输出
print(h.shape)


多层RNNCell


#  定义两层计算单元
cell_l0 = nn.RNNCell(100,30)
cell_l1 = nn.RNNCell(30,20)


# 定义两层的隐藏单元
h_l0 = torch.zeros(3,30)
h_l1 = torch.zeros(3,20)


# 初始化原始输入
xs = [torch.randn(3,100) for i in range(4)]

for xt in xs:
    h_l0 = cell_l0(xt,h_l0)
    h_l1 = cell_l1(h_l0,h_l1)


print(h_l0.shape)
print(h_l1.shape)

简单案例

预测正弦曲线的下一段波形

因为是输入一段波形曲线,不同于一个句子,句子中的每一个单词都是不能直接输入,需要做embedding,也就是对每一个单词进行编码成一个向量,但是波形曲线每一个点都是一个数字,不需要做embedding,那么如果给出五十个点,只提供一条曲线,那么seq_len = 50,feature_len =1,那么输入就是[50,1,1],但是batch需要提前,也就是输入[1,50,1],但是因为最后的输出是[seq_len,batch,hidden_len] 我们还是需要添加一个线性层变换一下,不能是hidden_len,需要变换成1 也就是[seq_len,batch,1]

  • 定义网络
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        # RNN
        self.rnn = nn.RNN(
            input_size = 1, # feature_len = 1
            hidden_size = 16, # 隐藏层记忆单元尺寸
            num_layers = 1,# 层数
            batch_first = True  # 确保输入时 按照[batch,seq_len,feature_len]的模式 1 x 50 x 1
        )

        #  对RNN 进行参数初始化
        for p in self.rnn.parameters():
            nn.init.normal_(p,mean=0.0,std=0.001)

        #  输出层 直接使用一个线性变换 吧每一个时刻的记忆单元的Hidden_len的输出为所需要的feature_len = 1 因为是一个数据点
        #  50 x 16 -> 50 x 1
        self.linear = nn.Linear(16,1)


    def forward(self,x,h):
        # 这里的out原始形状是 1 x 50 x 16 最后变成 50 x 16
        # h的原始形状是 1x 1  x 16  
        out,h = self.rnn(x,h)
        # 因为最后需要传递给线性层处理,所以需要展平
        out = out.view(-1,16)
        out = self.linear(out)

        # 在把batch维度添加进去
        out = out.unsqueeze(dim=0)

        return out,h

  • 训练网络
lr = 0.01 
model = Net()
from torch import nn, optim
criterion = nn.MSELoss()  # 均方差损失函数
optimizer = optim.Adam(model.parameters(),lr) # 学习率

# 初始化记忆单元 
h = torch.zeros(1,1,16)

# 生成样本数据
num_points = 50
seq_len = num_points - 1

import numpy as np
for iter in range(6000):
    k = np.random.randint(3, size=1)[0]
    # 取点的区间是[k, k+10],均匀地取num_points个点
    time_steps = np.linspace(k, k + 10, num_points)
    # 在这num_points个时刻上生成函数值数据
    data = np.sin(time_steps)
    # 将数据从shape=(num_points,)转换为shape=(num_points,1)
    data = data.reshape(num_points, 1)  # feature_len=1
    # 输入前49个点(seq_len=49),即下标0~48
    x = torch.tensor(data[:-1]).float().view(1, seq_len, 1)  # batch,seq_len,feature_len
    # 预测后49个点,即下标1~49
    y = torch.tensor(data[1:]).float().view(1, seq_len, 1)  # batch,seq_len,feature_len

    # 至此,生成了x->y的样本对, x和y都是shape如上面所写的序列
    
    # 将数据输入
    out,h = model(x,h)
    # h在循环中被
    h = h.detach()

    # 计算和预期输出之间的损失
    loss = criterion(out,y)

    # 更新网络参数
    model.zero_grad()

    # 计算梯度
    loss.backward()
    # 优化
    optimizer.step()


    if iter % 1000 == 0:
        print("迭代次数:{}, loss:{}".format(iter + 1, loss.item()))



  • 测试
from matplotlib import pyplot as plt
# 先用同样的方式生成一组数据x,y
k = np.random.randint(3, size=1)[0]
time_steps = np.linspace(k, k + 10, num_points)
data = np.sin(time_steps)
data = data.reshape(num_points, 1)  # feature_len=1
x = torch.tensor(data[:-1]).float().view(1, seq_len, 1)  # batch,seq_len,feature_len
y = torch.tensor(data[1:]).float().view(1, seq_len, 1)  # batch,seq_len,feature_len

# 用于记录预测出的点
predictions = []

# 取训练时输入的第一个点,即在x(1,seq_len,1)取seq_len里面第0号的数据
# 这里将输入seq_len'设置为1(而不是49)
# 输入什么长度的数据会自动调整网络结构来给出输出
input = x[:, 0, :]
# 输入的shape变成标准的(batch=1,seq_len'=1,feature_len=1)
input = input.view(1, 1, 1)

# 迭代seq_len次,每次预测出一个点
for _ in range(x.shape[1]):
    # 送入模型得到预测的序列,输入了一个点的序列也就输出了(下)一个点的序列
    pred, h = model(input, h)
    # 这里将预测出的(下一个点的)序列pred当成输入,来给到下一次循环
    input = pred
    # 把里面那个点的数取出来记录到列表里
    # 这里用ravel()而不用flatten(),因为后者是原地操作,会改变pred也就是input
    predictions.append(pred.detach().data.numpy().ravel()[0])

# 绘制预测结果predictions和真实结果y的比较
plt.scatter(time_steps[1:], y.data.numpy().ravel())
plt.scatter(time_steps[1:], predictions, c='r')
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/779780.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Vue上传图片返回base64并在页面展示,并图片上canvas进行红框框选标记

https://www.cnblogs.com/szqtiger/p/12100754.html vue如何显示base64图片_vue显示base64_不断学习的码农的博客-CSDN博客 图片上进行红框框选_时小帅的博客-CSDN博客 设置canvas画布大小_canvas设置画布大小_最凶残的小海豹的博客-CSDN博客 图片回显 结合以上&#xff0…

MySQL8 新特性——窗口函数用法

MySQL8 新特性——窗口函数用法 MySQL 8.0 是 MySQL 数据库管理系统的一个重要版本,引入了许多新特性和改进。以下是 MySQL 8.0 的一些主要新特性: 事务隔离级别改进: MySQL 8.0 引入了新的事务隔离级别 SERIALIZABLE,提供了最高…

妙记多 Mojidoc PC端(Mac 端+windows端)Beta版本正式上线!

你们呼唤了无数次的妙记多 Mojidoc PC客户端 Beta版本正式上线啦! 感谢300位妙友积极参与内测,给予了我们很多非常有效的意见和建议!我们会根据用户反馈不断优化和修复相关功能,在此感谢妙友们一直以来的支持~ PC端拥…

静态html引入ucharts并直接使用组件标签

由于官方不能直接使用qiun-vue-ucharts在静态html页面使用。 DIY可视化对此类库进行了改进,把它的包独立打包成一个可以依赖的JS。 首先定义一个核心JS,用于打包生成uchart import qiunVueUcharts from qiun/vue-ucharts;const install (app) > {…

el-select实现el-option可编辑

鼠标悬浮出现编辑图标 点击编辑图标对选择项进行修改 核心代码如下&#xff0c;注意el-input不要使用focus&#xff0c;会导致el-select面板收起来&#xff1b;使用click.native.stop即可 <el-select v-model"value" placeholder"选择" style"widt…

酷雷曼无人机技能培训考试圆满举办

2023年7月18日、19日&#xff0c;以“向云端起航&#xff0c;让技术落地”为主题的酷雷曼无人机技能提升培训会在酷雷曼北京运营中心隆重举行&#xff0c;来自全国各地的众多合作商参加了本次培训&#xff0c;通过系统、全面的学习成功取得了专业无人机飞行员执照&#xff0c;为…

基于linux下的高并发服务器开发(第三章)- 3.10 死锁

deadlock.c #include <stdio.h> #include <pthread.h> #include <unistd.h>// 全局变量&#xff0c;所有的线程都共享这一份资源。 int tickets 1000;// 创建一个互斥量 pthread_mutex_t mutex;void * sellticket(void * arg) {// 卖票while(1) {// 加锁pt…

十、正则表达式详解:掌握强大的文本处理工具(二)

文章目录 &#x1f340;多字符匹配&#x1f340;匹配规则的代替&#x1f340;特殊的匹配&#x1f340;特殊的匹配plus&#x1f340;总结 &#x1f340;多字符匹配 星号&#xff08;*&#xff09;&#xff1a;匹配0个或者多个字符 import retext 111-222-333 result re.matc…

Cardboard for Pictures(cf)

Mircea有n张照片&#xff0c;第 i 张照片的是边长为si的正方形&#xff0c;他把每张照片都装在一块正方形的硬纸板上&#xff0c;这样每张照片的四周都有一个w厘米的硬纸板边框。他总共用了 c 立方厘米见方的硬纸板。给定图片大小和值c&#xff0c;求w。&#xff08;请注意&…

Java-API简析_java.net.InetSocketAddress类(基于 Latest JDK)(浅析源码)

【版权声明】未经博主同意&#xff0c;谢绝转载&#xff01;&#xff08;请尊重原创&#xff0c;博主保留追究权&#xff09; https://blog.csdn.net/m0_69908381/article/details/131870760 出自【进步*于辰的博客】 因为我发现目前&#xff0c;我对Java-API的学习意识比较薄弱…

Dubbogo 详解

Dubbogo 详解 简介 dubbo功能很强大的微服务开发框架&#xff0c;支持多种通信协议&#xff0c;并具有流量治理的功能。 dubbo在有了大转变&#xff0c;拥抱了云原生&#xff0c;从哪些方面可以体现呢&#xff1f; 推出了自己的Trip协议修复了服务发现的级别&#xff0c;之…

Bug竞技场【已经修复】

目录 1.基础知识 2.最佳组合 2.1 铁男-螳螂 2.2 弟弟组合 海克斯抽卡bug 1.基础知识 背景&#xff1a;美测服-美服-马服-可以有效地减少bug率 复盘是为了更好的战斗&#xff01; 提前观看一些视频资料也是如此。 通过看直播博主的经验&#xff0c;可以让你关注到本来对战的…

利用Canvas根据经纬度绘制轨迹(一)

根据经纬度坐标绘制轨迹图形 一段时间没更新了&#xff0c;主人最近有点懒~ 前段时间有个需求&#xff0c;在uniapp App端实现轨迹绘制&#xff0c;于是先用html实现看看效果~ 效果图 html <canvasid"canvasId"width"300"height"300"style&…

DOS命令(windows)

DOS命令&#xff08;windows&#xff09; 目录 1. 打开命令提示符。2. 切换至根。3. 当前路径。4. 切换至上级路径。5. 查看当前目录。6. 查看文件内容。7. 删除文件。8. 进入长文件夹名时缩写。9. 复制文件。10. 移动文件。 1. 打开命令提示符。 命令&#xff1a;winR 输入&a…

【Go】Go 语言开发工具GoLand 使用(二十二)

往期回顾&#xff1a; Go 语言教程–介绍&#xff08;一&#xff09;Go 语言教程–语言结构&#xff08;二&#xff09;Go 语言教程–语言结构&#xff08;三&#xff09;Go 语言教程–数据类型&#xff08;四&#xff09;Go 语言教程–语言变量&#xff08;五&#xff09;Go …

STM32外设系列—TB6612FNG

本文涉及到定时器和串口的知识&#xff0c;详细内容可见博主STM32速成笔记专栏。 文章目录 一、TB6612简介二、TB6612使用方法2.1 TB6612引脚连接2.2 控制逻辑2.3 电机调速 三、实战项目3.1 项目简介3.2 初始化GPIO3.3 PWM初始化3.3 电机控制程序3.4 串口接收处理函数 一、TB66…

PHP反序列化漏洞之魔术方法

一、魔术方法 PHP魔术方法&#xff08;Magic Methods&#xff09;是一组特殊的方法&#xff0c;它们在特定的情况下会被自动调用&#xff0c;用于实现对象的特殊行为或提供额外功能。这些方法的名称都以双下划线开头和结尾&#xff0c;例如: __construct()、__toString()等。 …

java项目之网络视频播放器(ssm+mysql+jsp)

风定落花生&#xff0c;歌声逐流水&#xff0c;大家好我是风歌&#xff0c;混迹在java圈的辛苦码农。今天要和大家聊的是一款基于ssm的网络视频播放器。技术交流和部署相关看文章末尾&#xff01; 开发环境&#xff1a; 后端&#xff1a; 开发语言&#xff1a;Java 框架&a…

Arduino RP2040 两个CDC虚拟串口通讯

Arduino RP2040 两个CDC虚拟串口通讯 &#x1f3ac;通讯效果演示&#xff1a; &#x1f33f;基于Earle F. Philhower的固件开发平台&#xff1a; https://github.com/earlephilhower/arduino-pico&#x1f516;USB配置参考&#xff1a;https://arduino-pico.readthedocs.io/en/…

【算法基础:数学知识】4.1 质数

文章目录 质数例题列表866. 试除法判定质数&#xff08;质数的判定&#xff09;867. 分解质因数&#xff08;&#xff09;868. 筛质数埃氏筛欧氏筛 / 线性筛 相关链接 质数 定义&#xff1a;质数是指在大于1的自然数中&#xff0c;除了1和它本身以外不再有其他因数的自然数。 …