机器学习入门--循环神经网络原理与实践

news2024/12/24 9:15:30

循环神经网络

循环神经网络(RNN)是一种在序列数据上表现出色的人工神经网络。相比于传统前馈神经网络,RNN更加适合处理时间序列数据,如音频信号、自然语言和股票价格等。本文将介绍RNN的基本数学原理、使用PyTorch和Scikit-Learn数据集实现的代码。

数学原理

RNN是一种带有循环结构的神经网络,其在处理序列数据时将前一次的输出作为当前输入的一部分。这使得RNN能够记住先前的状态和信息,并且在处理长期依赖关系时表现出色。

RNN的基本公式可以表示为:

h t = f ( W h h h t − 1 + W x h x t ) h_t = f(W_{hh}h_{t-1} + W_{xh}x_t) ht=f(Whhht1+Wxhxt)

其中 h t h_t ht是RNN在时间步 t t t的隐藏状态, f f f是激活函数, W h h W_{hh} Whh是隐藏状态的权重矩阵, h t − 1 h_{t-1} ht1是上一次的隐藏状态, W x h W_{xh} Wxh是输入 x t x_t xt和隐藏状态 h t h_t ht之间的权重矩阵, x t x_t xt是时间步 t t t的输入。

在RNN的训练过程中,我们需要使用反向传播算法计算梯度并更新权重。由于RNN具有时间上的依赖关系,每一步的梯度都取决于前一步的梯度,这意味着我们需要使用反向传播算法的变体——反向传播通过时间(BPTT)算法来计算梯度。

代码实现

我们将使用PyTorch和Scikit-Learn数据集实现一个简单的RNN模型,用于预测时间序列数据。以下是代码:

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_boston
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt

# 加载数据集
data = load_boston()
X = data.data
y = data.target

# 数据标准化
scaler = StandardScaler()
X = scaler.fit_transform(X)
y = y.reshape(-1, 1)

# 转换为PyTorch张量,并增加时间步维度
X = torch.tensor(X, dtype=torch.float32).unsqueeze(1)
y = torch.tensor(y, dtype=torch.float32)

# 定义RNN模型
class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleRNN, self).__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out, _ = self.rnn(x)
        out = self.fc(out[:, -1, :])
        return out

# 创建模型实例
input_size = X.shape[2]  # 更新input_size的值
hidden_size = 32
output_size = 1
model = SimpleRNN(input_size, hidden_size, output_size)

# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 启用异常检测
torch.autograd.set_detect_anomaly(True)

# 训练模型
num_epochs = 10000
# 记录损失
loss_list = []

for epoch in range(num_epochs):
    optimizer.zero_grad()
    outputs = model(X)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()
    if (epoch+1) % 100 == 0:
        loss_list.append(loss.item())
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')

# 关闭异常检测
torch.autograd.set_detect_anomaly(False)

# 可视化损失曲线
plt.plot(range(100), loss_list)
plt.xlabel('num_epochs')
plt.ylabel('loss of RNN Training')
plt.show()
plt.savefig('Loss_of_RNN_Training.png')

# 预测新数据
new_data_point = X[0].reshape(1, 1, -1)  # 假设使用第一个数据点进行预测
prediction = model(new_data_point)
print(f'Predicted value: {prediction.item()}')

上述代码实现了一个简单的循环神经网络(RNN)模型来预测波士顿房价,并可视化训练过程中损失的变化。代码首先加载并标准化了波士顿房价数据集,然后定义了一个包含RNN层和全连接层的SimpleRNN模型,并使用均方误差作为损失函数和Adam优化器进行训练。训练完成后,使用matplotlib库绘制训练过程中损失的变化曲线(如下图所示)。最后,使用训练好的模型对新的数据点进行预测,并输出预测值。这段代码可以为初学者提供一个实现RNN模型的参考,并通过可视化训练过程中的损失曲线来帮助理解模型的性能。
RNN 损失曲线

总结

本文介绍了RNN的基本数学原理、使用PyTorch和Scikit-Learn数据集实现的代码,以及如何解读代码并总结。RNN是一种在序列数据上表现出色的神经网络,常用于处理时间序列数据,如音频信号、自然语言和股票价格等。我们可以使用PyTorch和Scikit-Learn数据集来实现一个简单的RNN模型,并用它来预测未知的时间序列数据。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1452143.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

上位机图像处理和嵌入式模块部署(图像项目处理过程)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 对于一般的图像项目来说,图像处理只是工作当中的一部分。在整个项目处理的过程中有很多的内容需要处理,比如说了解需求、评…

信息安全技术基础知识

一、考点分布 信息安全基础(※※)信息加密解密技术(※※※)密钥管理技术(※※)访问控制及数字签名技术(※※※)信息安全的保障体系 二、信息安全基础 信息安全包括5个基本要素&#…

【COMP337 LEC3】

LEC 3 Mathematical Preliminaries Common Discrete Probability Distributions 1. Bernoulli distribution : 伯努利分布 models binary outcomes (coin flip). 模型二进制结果 P ( X head ) p and P ( X tail ) 1 − p 2. Generalised Bernoulli distribution…

牛客网SQL进阶123:高难度试卷的得分的截断平均值

官网链接: SQL类别高难度试卷得分的截断平均值_牛客题霸_牛客网牛客的运营同学想要查看大家在SQL类别中高难度试卷的得分情况。 请你帮她从exam_。题目来自【牛客题霸】https://www.nowcoder.com/practice/a690f76a718242fd80757115d305be45?tpId240&tqId2180…

《PCI Express体系结构导读》随记 —— 第II篇 第13章 PCI总线与虚拟化技术(6)

接前一篇文章:《PCI Express体系结构导读》随记 —— 第II篇 第13章 PCI总线与虚拟化技术(5) 13.2 ATS(Address Translation Services) 单纯使用IOMMU并不能充分发挥处理器系统的效率,从图13-2中可以发现&…

WordPress站点如何实现发布文章即主动推送到百度快速收录和普通收录?

我们在WordPress后台成功发布文章之后,如果靠搜索引擎来抓取的话,可能会比较慢,所以十分有必要将我们成功发布的文章马上提交到百度、必应等搜索引擎中。下面boke112百科就跟大家说一说WordPress站点如何实现发布文章即主动推送到百度快速收录…

基于SpringBoot的教学管理app的开发65449-计算机毕业设计项目选题推荐(附源码)

摘 要 信息化社会内需要与之针对性的信息获取途径,但是途径的扩展基本上为人们所努力的方向,由于站在的角度存在偏差,人们经常能够获得不同类型信息,这也是技术最为难以攻克的课题。针对教学管理等问题,对其进行研究分…

Electron实战之进程间通信

进程间通信(IPC)并非仅限于 Electron,而是源自甚至早于 Unix 诞生的概念。尽管“进程间通信”这个术语的确创造于何时并不清楚,但将数据传递给另一个程序或进程的理念可以追溯至 1964 年,当时 Douglas McIlroy 在 Unix…

centos中docker操作+安装配置django+mysql5.7并使用simpleui美化管理后台

一、安装docker 确保系统是CentOS 7并且内核版本高于3.10,可以通过uname -r命令查看内核版本。 更新系统软件包到最新版本,可以使用命令yum update -y。 安装必要的软件包,包括yum-utils、device-mapper-persistent-data和lvm2。使用命令yum install -y yum-utils devic…

【51单片机】如何【手搓】定时器寄存器配置【低8位TL0(low)】和【高8位TH0(high)】

前言 大家好吖,欢迎来到 YY 滴单片机系列 ,热烈欢迎! 本章主要内容面向接触过单片机的老铁 本文是【【51单片机】从零开始手把手带你【查手册】配置定时器,并完成小项目(定时器&中断的应用)】博…

C++,stl,常用排序算法,常用拷贝和替换算法

目录 1.常用排序算法 sort random_shuffle merge reverse 2.常用拷贝和替换算法 copy replace replace_if swap 1.常用排序算法 sort 默认从小到大排序 #include<bits/stdc.h> using namespace std;int main() {vector<int> v;v.push_back(1);v.push_ba…

cpp杂项知识点(一)

大小端验证 代码如下&#xff1a; #include <iostream> #include <stdio.h> #include <memory> #include <string.h> #include <string>using namespace std;void hexdump(void *pSrc, int len ) {unsigned char *line;int i;int thisline;in…

Java的集合框架和泛型

文章目录 集合框架什么是集合框架类和接口总览 集合框架的重要性背后所涉及的数据结构以及算法什么是数据结构容器背后对应的数据结构什么是算法 包装类基本数据类型和对应的包装类装箱和拆箱自动装箱和自动拆箱 泛型什么是泛型引出泛型语法泛型类泛型的上界(没有下界)泛型方法…

Vue2学习第三天

Vue2 学习第三天 1. 计算属性 computed 计算属性实现 定义&#xff1a;要用的属性不存在&#xff0c;要通过已有属性计算得来。 原理&#xff1a;底层借助了Objcet.defineproperty方法提供的getter和setter。 get函数什么时候执行&#xff1f; 初次读取时会执行一次。当依赖…

知识图谱:py2neo将csv文件导入neo4j

文章目录 安装py2neo创建节点-连线关系图导入csv文件删除重复节点并连接边 安装py2neo 安装python中的neo4j操作库&#xff1a;pip install py2neo 安装py2neo后我们可以使用其中的函数对neo4j进行操作。 图数据库Neo4j中最重要的就是结点和边&#xff08;关系&#xff09;&a…

数字经济政策 | ZF工作报告-60个文本词频

根据各省政府工作报告&#xff0c;参考金灿阳(2022)和陶长琪(2022)&#xff0c;借助Python软件&#xff0c;统计数字经济相关的关键词词频&#xff0c;分别记为数字经济政策词频A、数字经济政策词频B A文献参考 B文献参考 年度趋势 一、数据介绍 数据名称&#xff1a; 政府工…

OpenAI首个文生视频模型亮相,你觉得咋样?

2月16日凌晨&#xff0c;OpenAI再次扔出一枚深水炸弹&#xff0c;发布了首个文生视频模型Sora。据介绍&#xff0c;Sora可以直接输出长达60秒的视频&#xff0c;并且包含高度细致的背景、复杂的多角度镜头&#xff0c;以及富有情感的多个角色。 目前官网上已经更新了48个视频d…

QGIS004:【08图层工具箱】-导出到电子表格、提取图层范围

摘要&#xff1a;QGIS图层工具箱常用工具有导出到电子表格、提取图层范围等选项&#xff0c;本文介绍各选项的基本操作。 实验数据&#xff1a; 链接&#xff1a;https://pan.baidu.com/s/1ZK4_ShrQ5BsbyWfJ6fVW4A?pwdpiap 提取码&#xff1a;piap 一、导出到电子表格 工具…

集团企业大数据应用:突破痛点,释放数据价值

在数字经济日益崛起的背景下&#xff0c;集团企业以其管理范围广泛、业务领域多元化和分支机构复杂化的特性&#xff0c;在市场竞争中扮演着重要角色。为了维持和提升这种竞争力&#xff0c;大数据应用成为了集团企业不可或缺的战略工具。然而&#xff0c;在实际应用中&#xf…

图表示学习 Graph Representation Learning chapter1 引言

图表示学习 Graph Representation Learning chapter1 引言 前言1.1图的定义1.1.1多关系图1.1.2特征信息 1.2机器学习在图中的应用1.2.1 节点分类1.2.2 关系预测1.2.3 聚类和组织检测1.2.4 图分类、回归、聚类 前言 虽然我并不研究图神经网络&#xff0c;但是我认为图高效的表示…