基于Python的机器学习系列(32):PyTorch - 循环神经网络(RNN)

news2025/1/3 11:24:28

        在本篇文章中,我们将探索循环神经网络(RNN),这是一种特别适用于时间序列数据或文本数据的神经网络模型。在RNN中,当前的输出不仅取决于当前的输入,还受到前一步输出的影响,从而能够捕捉序列数据中的时间依赖性。

1. 循环神经网络(RNN)基础

        RNN的基本单元可以通过以下公式表示:

2. 使用LSTM预测苹果股票价格     

        在这部分中,我们将使用LSTM(长短期记忆网络)来预测苹果股票的未来价格。LSTM是一种改进版的RNN,专门设计用来解决标准RNN在处理长时间序列时遇到的梯度消失或爆炸问题。

步骤概述

  1. 数据加载和预处理:我们将加载包含苹果股票历史数据的CSV文件,并进行必要的预处理,如归一化和时间序列拆分。
  2. 构建LSTM模型:定义一个LSTM模型来处理我们的时间序列数据,包括输入层、LSTM层和全连接层。
  3. 训练模型:使用过去的30天数据来训练模型,并进行预测。
  4. 评估和预测:使用模型对未来的价格进行预测,并评估模型性能。

示例代码

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# 设置设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用的设备: {device}")

# 1. 加载数据
df = pd.read_csv("data/appl_1980_2014.csv")
df.Date = pd.to_datetime(df.Date)
df = df.set_index('Date')

# 数据归一化
data = df[['Close']].values
data = (data - np.min(data)) / (np.max(data) - np.min(data))

# 创建时间序列数据
def create_sequences(data, seq_length):
    sequences = []
    labels = []
    for i in range(len(data) - seq_length):
        sequences.append(data[i:i+seq_length])
        labels.append(data[i+seq_length])
    return np.array(sequences), np.array(labels)

seq_length = 30
X, y = create_sequences(data, seq_length)

# 划分训练集和测试集
train_size = int(len(X) * 0.8)
X_train, X_test = X[:train_size], X[train_size:]
y_train, y_test = y[:train_size], y[train_size:]

# 转换为PyTorch张量
X_train = torch.FloatTensor(X_train).to(device)
y_train = torch.FloatTensor(y_train).to(device)
X_test = torch.FloatTensor(X_test).to(device)
y_test = torch.FloatTensor(y_test).to(device)

# 2. 定义LSTM模型
class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        h0 = torch.zeros(1, x.size(0), self.hidden_size).to(device)
        c0 = torch.zeros(1, x.size(0), self.hidden_size).to(device)
        out, _ = self.lstm(x, (h0, c0))
        out = self.fc(out[:, -1, :])
        return out

input_size = 1
hidden_size = 64
output_size = 1

model = LSTM(input_size, hidden_size, output_size).to(device)

# 3. 训练模型
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    outputs = model(X_train)
    optimizer.zero_grad()
    loss = criterion(outputs, y_train)
    loss.backward()
    optimizer.step()
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

# 4. 评估模型
model.eval()
with torch.no_grad():
    predictions = model(X_test)
    plt.plot(y_test.cpu().numpy(), label='真实值')
    plt.plot(predictions.cpu().numpy(), label='预测值')
    plt.legend()
    plt.show()

结语

        通过使用LSTM,我们能够有效地捕捉时间序列数据中的长期依赖性,从而提高预测的准确性。与传统的RNN相比,LSTM在处理复杂序列数据时表现出色。希望通过本篇文章,大家对LSTM模型有了更深入的理解,并能够在自己的项目中应用这一强大的工具。

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2120440.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java 入门指南:JVM(Java虚拟机)—— Java 类文件结构

文章目录 字节码JVM 与字节码字节码的生成过程 Class 文件结构魔数(Magic Number)Class 文件版本号(Minor&Major Version)常量池(Constant Pool)访问标志(Access Flags)前类(This Class&…

Pygame中Sprite类实现多帧动画3-3

4 使用自定义类MySprite 使用自定义类MySprite实现多帧动画的步骤是首先创建MySprite类的实例,之后使用相关函数对该实例进行操作。 4.1 创建MySprite类的实例 创建MySprite类的实例的代码如图12所示。 图12 创建MySprite类的实例的代码 其中,变量dr…

2024年CCPC网络赛A题题解 —— 军训Ⅰ(gym105336A)

个人认为很唐的一道题,考虑到不少人可能懒得写,我这里给大家发个代码叭,还有一点点题解(因为真的不是很难)。这是题面: 然后我来讲讲怎么做,不觉得会有多少人题目意思都理解不了叭?这…

Javaweb项目-调用接口-如何在服务器端跳转网页后显示并弹出对话框代码

Webapp 项目中在java包下新建一个服务端类 使用JOptionPane框架组件 调用showMessageDialog的方法实现 四个参数null,"这是一个信息对话框","信息",JOptionPane.INFORMATION_MESSAGE 还有确认对话框的代码showConfirmDialog package servlet;import java…

k9s 是什么?有什么功能?

目录 k9s 是什么? 有什么功能? 手动安装 K9s(Windwos) 将 k9s.exe 添加到系统 PATH 启动 K9s k9s 是什么? K9s 是一个命令行工具,用于通过一个图形化的终端界面(类似于图形化用户界面但在命…

【Linux】常用的命令

文章目录 lsls -l / touchcdpwdcatechovim打开文件编辑内容保存退出 mkdirrmmvcpmangreppsnetstat总结 : ls ls > list 列出当前目录下都有哪些内容(文件/目录) 直接输入 ls,是查看当前目录的情况;输入 ls/ 就是看…

检查你的防病毒软件是否可以阻止这 5 个测试恶意软件文件

从网络安全专家到你,每个人都知道你应该使用防病毒软件来保护你的电脑免受黑客、病毒和其他类型的网络威胁。 但即使你这些年来一直在努力使用防病毒程序,你怎么知道它真的有效呢? 安全专家已经想到了这一点,并创建了几种类型的…

TMS320F28335芯片及使用介绍

1、简介 CPU性能的好坏不仅取决于主频大小,还需要看其整体架构集成性能、运算能力与指令体系。TMS320C2000系列DSP集微控制器和高性能 DSP 的特点于一身,具有强大的控制和信号处理能力,能够实现复杂的控制算法。TMS320C2000 系列DSP 片上整合了Flash存储器、快速的AD转换器…

基于微信小程序+Java+SSM+Vue+MySQL的付费自习室预订管理系统

作者:计算机学姐 开发技术:SpringBoot、SSM、Vue、MySQL、JSP、ElementUI等,“文末源码”。 专栏推荐:前后端分离项目源码、SpringBoot项目源码、SSM项目源码 系统展示 基于微信小程序JavaSSMVueMySQL的付费自习室预订管理系统【…

【CMake编译报错小复盘】CMAKE_CUDA_ARCHITECTURES,CMake version,GCC version问题

今天在写大模型量化推理框架时遇到了一些编译上的错误,简单复盘一下问题和解决方案: 问题1:CMAKE_CUDA_ARCHITECTURES 报错信息: CMake Error: CMAKE_CUDA_ARCHITECTURES must be non-empty if set cmake和cuda相关的报错通常都…

linux进程间通信——进程间通信概念、最基本通信——管道文件

前言: 本节内容将要讲解进程间通信。 之前我们说过进程之间是相互独立的, 但是,相互独立并不代表不能进行数据的输送。就好比我和你是相互独立的, 但是我们可以成为朋友, 可以互赠礼物。 而我们一般而言的,…

【C语言】归并排序递归和非递归——动图演示

目录 一、归并排序思想1.1 基本思想1.2 大体思路 二、实现归并排序(递归)三、实现归并排序(非递归)3.1 实现思路:3.2 越界处理3.3 时间复杂度和空间复杂度 总结 一、归并排序思想 1.1 基本思想 归并排序(M…

RTMP和WebRTC使用场景有哪些差别?

省流版先说结论 直播领域,RTMP和WebRTC各有优势。如果直播场景对延迟有一定要求,但更注重稳定性和兼容性,那么RTMP可能是一个更好的选择。如果直播场景需要极低的延迟,并且用户主要在浏览器环境下进行观看和互动,那么…

Leangoo敏捷工具在缺陷跟踪(BUG)管理中的高效应用

在开发过程中,缺陷(BUG)管理一直是项目管理中的一个关键环节。及时发现并修复BUG,不仅能够提高产品质量,还能有效提升团队的工作效率和用户满意度。 在敏捷开发中,快速迭代和频繁交付的特点使得缺陷管理的…

Servlet的特性(一)

Servlet的主要用途: 接受、处理来自浏览器端(BS架构中的B端)的请求和用户输入 响应来自数据库或者服务端(BS架构中的S端)产生的数据到浏览器端,动态构建网页。 手动实现Servlet小程序 实现步骤 自定义一个类型,实现Servlet接口或者继承Ht…

Spring Boot 集成 Redisson 实现消息队列

包含组件内容 RedisQueue:消息队列监听标识RedisQueueInit:Redis队列监听器RedisQueueListener:Redis消息队列监听实现RedisQueueService:Redis消息队列服务工具 代码实现 RedisQueue import java.lang.annotation.ElementTyp…

GD32E230 RTC报警中断功能使用

GD32E230 RTC报警中断使用 GD32E230 RTC时钟源有3个,一个是内部RC振动器产生的40KHz作为时钟源,或者是有外部32768Hz晶振.,或者外部高速时钟晶振分频作为时钟源。 🔖个人认为最难理解难点的就是有关RTC时钟异步预分频和同步预分频的计算。在对…

C++第二节入门 - 缺省参数和函数重载

一、缺省参数 1、概念 缺省参数是声明或定义函数时为函数的参数指定一个缺省值。 在调用该函数的时候&#xff0c;如果没有指定实参则采用该形参的缺省值&#xff0c;否则使用指定的实参&#xff01; #include<iostream> using namespace std;void Func(int a 0) {c…

2024 水博会,国信华源登场,数智创新助力水利高质量发展

9月4日-6日&#xff0c;由中国水利学会和中国水利工程协会共同主办的2024中国水博览会暨第十九届中国&#xff08;国际&#xff09;水务创新技术交流会在重庆国际博览中心召开。 本次水博会以“展水利前沿新技术 览新质生产力场景”为主题&#xff0c;国信华源携最新智能监测预…

【佳学基因检测】如何升级一个不再维护的软件包中的PHP代码?

如何升级一个不再维护的软件包中的PHP代码&#xff1f; 为什么要升级一个不再维护但是仍在使用的软件包中的PHP代码&#xff1f; 升级一个不再维护但仍在使用的软件包中的 PHP 代码是一个复杂但重要的过程。虽然这些软件包可能已经不再活跃地维护或更新&#xff0c;但升级其代…