simple RNN pytorch代码实现

news2025/1/22 18:56:22

simple RNN pytorch代码实现

在写这篇博客之前,博主要说一件事情,网上的simple RNN代码很多都是错误的,博主的也是错误的,为什么呢?
因为simple RNN的梯度下降代码必须自己去写,simple RNN的梯度下降不能使用pytorch的默认机制,否则会直接出现梯度消失,博主做了很多实验,一开始一直以为是代码写错了,后面发现,simple RNN不能使用一般的梯度下降算法去做,必须使用随时间梯度下降算法去实现,也就是如果你要复现simple RNN需要自己去写梯度下降代码,不能直接搭建模型训练。不过呢,使用两三层的simple RNN还可以,因为梯度消失还不严重,这里,我们给出我们的代码:

#coding=gbk

import torch
from torch.autograd import Variable
import os
from torch.utils import data
import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy as np
import torch.nn as nn
sample_num=1000
sequense_num=10
input_length=10
train_de_test=0.8
hidden_size=10
num_epochs=100
batch=32
learning_rate=0.01
#torch.manual_seed(10)

x_data=[]

y_data=[]
for i in range(sample_num):
    if i%2==0:
        x_gene=torch.randint(0,5,(sequense_num,input_length))
     
        y_data.append(1)
   
        
    else:
         x_gene=torch.randint(6,10,(sequense_num,input_length))
        #  y=torch.sum(x_gene)

       #  y_data.append(1)
         y_data.append(0)
   
    x_data.append(x_gene)


x_data=torch.stack((x_data),0)

x_data=x_data.type(dtype=torch.float32)
y_data=torch.as_tensor(y_data)
print(x_data,y_data)
print(x_data.size())

    # 神经网络搭建

 
class sRNN(nn.Module):
    def __init__(self, sequense_num,input_length,hidden_size):
        
        super().__init__()
        self.sequense_num=sequense_num
        self.input_length=input_length
        self.hidden_size=hidden_size

        self.W = torch.nn.Parameter(data=torch.randn(input_length, hidden_size, requires_grad=True))    
        self.U = torch.nn.Parameter(data=torch.randn(input_length, hidden_size, requires_grad=True))   
        self.b = torch.nn.Parameter(data=torch.randn( 1,hidden_size, requires_grad=True))  
        self.V = torch.nn.Parameter(data=torch.randn(hidden_size,1 , requires_grad=True))  
        self.f=nn.Sigmoid()
    def forward(self,input):#d就是整个网络的输入
        hidden_state_pre=torch.zeros(1, hidden_size)
        for i in range(sequense_num):
           # print(torch.matmul(self.W,input[:,i,:]))
            z=torch.matmul(hidden_state_pre,self.U)+torch.matmul(input[:,i,:],self.W)+ self.b
            hidden_state_pre=F.relu(z)
      #  print(hidden_state_pre)
        y=self.f(torch.matmul(hidden_state_pre,self.V))
        return y
            
    def backward(self):
        pass
loss_fn = nn.BCELoss()

def sampling(sample_num):
   
    index_sequense=torch.randperm(sample_num)
    return index_sequense
def get_batch(index_sequense,X_data,Y_data,index,bacth):

    return X_data[index:index+bacth],Y_data[index:index+bacth]

srnn=sRNN(sequense_num,input_length,hidden_size)

loss_fn = nn.BCELoss()
index_sequense=sampling(sample_num)
optimizer = torch.optim.Adam(srnn.parameters(), lr=0.1)
co=0


for param_tensor in srnn.state_dict(): # 字典的遍历默认是遍历 key,所以param_tensor实际上是键值
    print(param_tensor,'\t',srnn.state_dict()[param_tensor].size())
    print(srnn.state_dict()[param_tensor])
acc_list=[]

       
index=0
index_sequense=torch.randperm(sample_num)
loss_list=[]
for k in range(num_epochs):
            
                if index+batch>=sample_num-1:
                     index=0
                     index_sequense=torch.randperm(sample_num)

                x_batch,y_batch=get_batch(index_sequense,x_data,y_data,index,batch)
                y_batch=y_batch.type(dtype=torch.float32)
                y_batch=y_batch.reshape(batch,1)
           #     print(x_batch)
                predict=srnn(x_batch)
                co=torch.sum(torch.abs(y_batch-predict)<0.5)

                loss = loss_fn(predict,y_batch)
                index=index+batch
                optimizer.zero_grad()

                loss.backward()
                optimizer.step()
                for p in srnn.parameters():
               #    p.data.add_(p.grad.data, alpha=-learning_rate)
                  # print("p.data",p.data)
                   print("p.grad.data",p.grad.data)
              #  print(predict)
               # print(y_batch-predict)
            
                print("loss :",loss)

                print("accuracy :",co/batch)
                loss_list.append(loss)
                acc_list.append(co/batch)


              #except:
              #    index=0
              #    index_sequense=torch.randperm(sample_num)



                #optimizer.zero_grad()
                #loss.backward()








#
#print(x_batch,y_batch)
x_batch,y_batch=get_batch(index_sequense,x_data,y_data,0,batch)
print(x_batch.size())
print("#")
print(srnn(x_batch))
#y_batch=y_batch.type(dtype=torch.float32)


#train(x_data,y_data,num_epochs,batch)

epoch_list=list(range(num_epochs))
plt.plot(epoch_list,acc_list,label='adam')
plt.title("loss")
plt.legend()

plt.show()


plt.plot(epoch_list,loss_list,label='adam')
plt.title("loss")
plt.legend()

plt.show()



#print(srnn.parameters())

#print(type(srnn.state_dict()))  # 查看state_dict所返回的类型,是一个“顺序字典OrderedDict”


 
#for param_tensor in srnn.state_dict(): # 字典的遍历默认是遍历 key,所以param_tensor实际上是键值
#    print(param_tensor,'\t',srnn.state_dict()[param_tensor].size())
#    print(srnn.state_dict()[param_tensor])

os.system("pause")

运行结果:
在这里插入图片描述
在这里插入图片描述

上面代码并不时完全正确的simple RNN代码,不过大家需要完善梯度下降算法就可以了。反向的去计算随时间的梯度就可以了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/751382.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

巧用浮动布局、解决高度塌陷实例分享

问题 如图所示&#xff0c;这种效果该怎么实现呢&#xff1f; 面包屑导航和按钮一行两端显示面包屑或编辑栏超出宽度则自动另行显示 实现 采用浮动&#xff0c;绿色块左浮&#xff0c;蓝色块右浮&#xff0c;利用浮动特性实现宽度超出另一行显示的效果&#xff0c;并是动态的…

ModaHub魔搭社区:什么是非结构化数据?

目录 概览 区分结构化、半结构化和非结构化数据 结构化数据示例 欢迎来到向量数据库 101 系列教程。 概览 这是向量数据库 101 系列教程第一课,主要向大家介绍一下非结构化数据。 现在我们每天都会产生新的数据,这无疑是全球一体化和全球经济的关键动力。从腕部佩戴的…

【NLP】国外新动态--LLM模型

一、说明 NLP走势如何&#xff1f;这是关于在实践中使用大型语言模型&#xff08;LLM&#xff09;的系列文章中的一篇文章。在这里&#xff0c;我将介绍LLM&#xff0c;并介绍使用它们的3个级别。未来的文章将探讨LLM的实际方面&#xff0c;例如如何使用OpenAI的公共API&#x…

面试题更新之-css中link和@import的区别

文章目录 导文link是什么&#xff1f;import是什么&#xff1f;css中link和import的区别 导文 面试题更新之-css中link和import的区别 link是什么&#xff1f; CSS Link是用于将外部CSS文件链接到HTML文档中的HTML标签。通过使用CSS Link标签&#xff0c;可以将外部的CSS样式表…

游戏测试与策划的那些事儿

作为一个游戏测试员&#xff0c;和程序、前端、策划之间的沟通交流在所难免。今天就来吐槽一下子啦~ 作为游戏测试的核心机密&#xff0c;可不能被他们知道我们在背后吐槽啦~ 游戏测试&#xff1a;XXX&#xff0c;刚测完这数据怎么和之前的不一样了&#xff1f; 策划&#xff1…

Python分布式任务队列Celery

一、分布式任务队列Celery介绍 Python celery是一个基于Python的分布式任务队列&#xff0c;主要用于任务的异步执行、定时调度和分布式处理。它采用了生产者/消费者模式&#xff0c;通过消息中间件实现多个工作者进程之间的协作。 Python celery的架构主要包括以下组件&…

new和不用new调用构造函数,有什么区别?

new和不用new的构造函数&#xff0c;有什么区别&#xff1f; 下面从有return和无return,如果在有return的情况下&#xff0c;return原始类型数据和return引用应用类型数据等几方面进行论述&#x1f44d;&#x1f44d;&#x1f44d; 区别1&#xff1a;当没有return时 functio…

Bus消息总线(在Spring Cloud整合Bus(idea19版本)

Bus消息总线 所谓消息总线Bus&#xff0c;即我们经常会使用MQ消息代理构建一个共用的Topic&#xff0c;通过这个Topic连接各个微服务实例,MQ广播的消息会被所有在注册中心的微服务实例监听和消费。换言之就是通过一个主题连接各个微服务&#xff0c;打通脉络 Spring Cloud Bus …

FastDDS 源码剖析:DataWriter分析

目录 DataWriter分析 DataWriter 类分析 DataWriterImpl 类分析 关键函数分析 DataWriter分析 DataWriter 类分析 DataWriter 类是 Fast DDS 库中的一个重要类&#xff0c;它用于实现 DDS&#xff08;Data Distribution Service&#xff09;发布-订阅通信模型中的数据写入…

Git #01 操作记录

本篇内容 0. 前期配置1. 仓库1.1 上传本地代码到远程仓库 0. 前期配置 请提前配置好 git 的全局用户名&#xff1a; # xin&#xff1a;账号名 $ git config --global user.name "xin" # xin163.com&#xff1a;账号绑定的邮箱地址 $ git config --global user.emai…

单片机能否替代PLC实现控制和自动化系统?

是的&#xff0c;单片机可以在某些情况下替代PLC&#xff0c;但在其他情况下可能并不适用。以下是对这个问题的详细解释&#xff1a; 我这里刚好有嵌入式、单片机、plc的资料需要可以私我或在评论区扣个6 灵活性和可编程性&#xff1a;PLC相对于单片机来说更具有灵活性和可编…

DolphinScheduler minio(S3支持)开启资源中心

DolphinScheduler 如果是在3.0.5 及之前的版本&#xff0c;没办法支持 S3 的协议的 当你按照文档配置之后&#xff0c;运行启动之后&#xff0c;在master 和 worker 节点&#xff0c;都会出现 缺包的依赖问题。 那这个问题在什么版本修复了呢&#xff1f; 3.0.6... 那 3.0.6 …

每个前端开发者都应知道的10个实用网站

微信搜索 【大迁世界】, 我会第一时间和你分享前端行业趋势&#xff0c;学习途径等等。 本文 GitHub https://github.com/qq449245884/xiaozhi 已收录&#xff0c;有一线大厂面试完整考点、资料以及我的系列文章。 快来免费体验ChatGpt plus版本的&#xff0c;我们出的钱 体验地…

sprinboot企业客户信息反馈平台

企业客户信息反馈平台的开发运用java技术&#xff0c;MIS的总体思想&#xff0c;以及MYSQL等技术的支持下共同完成了该平台的开发&#xff0c;实现了企业客户信息反馈管理的信息化&#xff0c;使用户体验到更优秀的企业客户信息反馈管理&#xff0c;管理员管理操作将更加方便&a…

canal番外篇-otter

前置知识点 主从复制binlogcanal正则dockerjava 前置工具 dockerotter-all 场景描述&#xff08;增量同步&#xff09; 目前项目中使用的是mysql5.5&#xff0c;计划升级为mysql8.1&#xff0c;版本跨度较大&#xff0c;市面上可靠工具选择较少。otter符合预期&#xff0c…

3Ds max入门教程:为男性角色创建服装T 恤

推荐&#xff1a; NSDT场景编辑器助你快速搭建可二次开发的3D应用场景 3ds Max 角色服装教程 在本 3ds Max 教程中&#xff0c;我们将为角色模型创建一个简单的 T 恤。我们提供了一个“human_figure.obj”文件供您导入模型。因此&#xff0c;本教程将重点介绍的是创建服装&…

【VTK】VTK 显示小球例子,在 Windows 上使用 Visual Studio 配合 Qt 构建 VTK

知识不是单独的&#xff0c;一定是成体系的。更多我的个人总结和相关经验可查阅这个专栏&#xff1a;Visual Studio。 关于更多此例子的资料&#xff0c;可以参考&#xff1a;【Visual Studio】在 Windows 上使用 Visual Studio 配合 Qt 构建 VTK。 文章目录 版本环境VTKTest.…

【机器人模拟-01】使用URDF在中创建模拟移动机器人

一、说明 在本教程中,我将向您展示如何使用通用机器人描述格式 (URDF)(机器人建模的标准 ROS 格式)创建模拟移动机器人。 机器人专家喜欢在构建机器人之前对其进行模拟,以测试不同的算法。您可以想象,使用物理机器人犯错的成本可能很高(例如,将移动机器人高速…

SPSS方差分析

参考文章 导入准备好的数据 选择分析方法 选择参数 选择对比&#xff0c;把组别放入因子框中&#xff0c;把红细胞增加数放进因变量列表 勾选“多项式”&#xff0c;等级取默认“线性” &#xff0c;继续 接着点击“事后比较”&#xff0c;弹出对话框&#xff0c;勾选“LSD” …

华为OD机试真题 JavaScript 实现【分糖果】【2022Q2 200分】,附详细解题思路

目录 专栏导读一、题目描述二、输入描述三、输出描述四、解题思路五、JavaScript算法源码六、效果展示 专栏导读 本专栏收录于《华为OD机试&#xff08;JavaScript&#xff09;真题&#xff08;A卷B卷&#xff09;》。 刷的越多&#xff0c;抽中的概率越大&#xff0c;每一题都…