最简单的RNN预测股票收盘价

news2025/1/11 20:45:45

1.首先,导入必要的库:

import torch
import torch.nn as nn
import numpy as np

2.准备数据。需要准备好包含历史股票收盘价的一维时间序列数据。在这个例子中,我们将使用NumPy模拟一些示例数据

# 示例的股票收盘价时间序列数据
# 假设数据点表示 2021.1.1 到 2021.1.30 每日的股票收盘价
closing_prices = np.array([100.0, 101.2, 102.5, 101.8, 103.0, 102.7, 103.5, 104.2, 104.5, 105.0,
                           104.8, 104.9, 105.2, 106.0, 107.0, 107.5, 107.8, 108.2, 109.0, 109.5,
                           109.8, 110.2, 110.5, 111.0, 112.0, 112.5, 112.8, 113.0, 113.2, 113.5])

# 将数据转换为PyTorch张量
closing_prices = torch.tensor(closing_prices, dtype=torch.float32)

3创建序列数据。为了使用RNN模型,你需要将时间序列数据分割成输入序列和目标序列。在这个示例中,我们将使用前N天的数据来预测下一天的股票收盘价。你可以调整N的值来控制输入序列的长度。

# 定义输入序列的长度
sequence_length = 5  # 使用前5天的数据来预测下一天的收盘价

# 创建输入序列和目标序列
def create_sequences(data, sequence_length):
    sequences = []
    targets = []
    for i in range(len(data) - sequence_length):
        seq = data[i:i + sequence_length]
        target = data[i + sequence_length]
        sequences.append(seq)
        targets.append(target)
    return torch.stack(sequences), torch.stack(targets)

# 创建序列数据
input_sequences, target_sequences = create_sequences(closing_prices, sequence_length)

4定义RNN模型。在这里,我们使用一个简单的RNN模型,它有一个RNN层和一个全连接层:

class StockPredictionRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super(StockPredictionRNN, self).__init__()
        
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        
    def forward(self, x):
        out, _ = self.rnn(x)
        out = self.fc(out[:, -1, :])
        return out

# 定义模型参数
input_size = 1  # 输入特征的维度,这里是股票收盘价
hidden_size = 64  # 隐藏层的维度
output_size = 1  # 输出特征的维度,这里是预测的股票收盘价
num_layers = 1  # RNN的层数

# 创建模型
model = StockPredictionRNN(input_size, hidden_size, output_size, num_layers)

5定义损失函数和优化器

criterion = nn.MSELoss()  # 使用均方误差损失
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

6训练模型。你需要将输入序列传递给模型,计算损失并反向传播进行优化。这里只提供一个简单的示例,实际训练通常需要更多的数据和迭代次数

num_epochs = 4000

for epoch in range(num_epochs):
    optimizer.zero_grad()
    outputs = model(input_sequences.unsqueeze(2))  # 添加额外的维度以匹配模型的输入要求
    loss = criterion(outputs, target_sequences)
    loss.backward()
    optimizer.step()
    
    if (epoch + 1) % 100 == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

7使用训练好的模型进行预测:

# 准备输入数据来预测未来5天的收盘价
input_data = closing_prices[-sequence_length:].unsqueeze(0).unsqueeze(2)

# 使用模型进行预测
predicted_prices = model(input_data).squeeze().detach().numpy()

print("预测的未来1天股票收盘价:", predicted_prices)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1049286.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux面试题汇总1

MySQL数据库 1、MySQL和Oracle的区别 1.Oracle是大型数据库,而MySQL是中小型数据库。但是MySQL是开源的,但是Oracle是收费的,而且比较贵。 2. Oracle的内存占有量非常大,而mysql非常小 3. MySQL支持主键自增长,指定主…

MySQL explain SQL分析工具详解与最佳实践

目录 一、explain工具介绍二、添加示例表和数据用于后续演示三、explain中的列3.1、id列3.2、select_type列3.3、table列3.4、partitions列3.5、type列NULLsystemconsteq_refrefrangeindexALL 3.6、possible_keys列3.7、key列3.8、key_len列3.9、ref列3.10、rows列3.11、filter…

《视觉 SLAM 十四讲》V2 第 11 讲 回环检测【消除累积误差】

待做: 习题整理 相关文献【新的综述】等 P283 文章目录 11.2 词袋 模型11.3.2 Code: 创建字典11.4.2 Code: 相似度 计算训练 自己的字典 报错 习题√ 题1√ 题2题3 DBoW3库题4题5 基于 词袋 的外观式 回环检测 SLAM主体(前端后端): 估计相机…

图片处理后再保存为图片到文件夹中,文件夹下文件名不变改格式保存

首先读取图片; 然后处理,得到cv:Mat类型; 对cv:Mat类型图片写入文件夹,保存到指定路径。 像raw图等不能直接读取显示,需要先进行解码,转换为可以显示的图片。 下面举例读入本来可以显示的图。以下代码加…

哈弗猛龙实力登场,「方盒子猛改派对」掀起越野改装新热潮

9月22日-24日,哈弗猛龙“方盒子猛改派对”在北京751 D-PARK 火车头广场成功举办。活动现场盛况空前,不仅有官方展出的11台不同风格的猛改车型,更吸引了不同领域的博主大咖及越野达人前来参与活动。 与此同时,哈弗猛龙用户大定权益…

【EI会议征稿】第三届信号处理与通信技术国际学术会议(SPCT 2023)

第三届信号处理与通信技术国际学术会议(SPCT 2023) 2023 3rd International Conference on Signal Processing and Communication Technology 第三届信号处理与通信技术国际学术会议(SPCT 2023)将于2023年12月1-3日在长春召开。S…

【AIPOD案例操作教程】斜流风扇轮毂优化

AIPOD是由天洑软件自主研发的一款通用的智能优化设计软件,致力于解决能耗更少、成本更低、重量更轻、散热更好、速度更快等目标的工程设计寻优问题。针对工业设计领域的自动化程度低、数值模拟计算成本高等痛点,基于人工智能技术、自研先进的智能代理学习…

MySQL存储引擎以及InnoDB、MyISAM、Memory特点介绍

存储引擎介绍和基本使用 基本介绍: 存储引擎是数据库的核心,存储引擎就是存储数据、建立索引、更新/查询数据等技术的实现方式 。存储引擎是基于表的,而不是基于库的,所以存储引擎也可被称为表类型。我们可以在创建表的时候&…

U盘植马之基于arduino的badusb实现及思考

引言 曾经有这么一段传说,在某次攻防演练时,某攻击队准备了一口袋U盘前往了目标单位的工作园区,在园区围墙外停下了脚步,然后开始不停扔U盘进去,最后发现有大量的“猎奇者”上线。 U盘植马是常见的近源渗透方式之一&am…

若依不分离+Thymeleaf select选中多个回显

项目中遇到的场景&#xff0c;亲测实用 表单添加时&#xff0c;select选中多个&#xff0c;编辑表单时&#xff0c;select多选回显&#xff0c;如图 代码&#xff1a; // 新增代码 <label class"col-sm-3 control-label">通道&#xff1a;</label><…

再学C++ | std::set 的原理

std::set 是C标准库中的容器之一&#xff0c;它基于红黑树实现。std::set 利用红黑树的特性来实现有序的插入、查找和删除操作&#xff0c;并且具有较好的平均和最坏情况下的时间复杂度。 当向 std::set 插入元素时&#xff0c;它会按照特定的比较函数&#xff08;bool less<…

软件可靠性基础

软件可靠性基础 软件可靠性基本概念串并联系统可靠性计算软件可靠性测试软件可靠性建模软件可靠性管理软件可靠性设计容错&#xff0c;检错的技术 选择题考基本概念&#xff08;MTBF&#xff09;&#xff0c;很少考 非重点 软件可靠性基本概念 这个章节中&#xff0c;第一个…

Leetcode算法题练习(一)

目录 一、前言 二、移动零 三、复写零 四、快乐数 五、电话号码的字母组合 六、字符串相加 一、前言 大家好&#xff0c;我是dbln&#xff0c;从本篇文章开始我就会记录我在练习算法题时的思路和想法。如果有错误&#xff0c;还请大家指出&#xff0c;帮助我进步。谢谢&…

2023-9-27 JZ55 二叉树的深度

题目链接&#xff1a;二叉树的深度 import java.util.*; /** public class TreeNode {int val 0;TreeNode left null;TreeNode right null;public TreeNode(int val) {this.val val;}} */ public class Solution {public int TreeDepth(TreeNode root) {if(root null) ret…

续航605km,价格 11.77 万起带激光雷达,你卷我也卷

9 月 21 日&#xff0c;睿蓝 7 正式上市&#xff0c;新车提供 6 款车型&#xff0c;售价区间 11.77-17.37 万元。 权益方面&#xff0c;提供 701 元订金抵 2000 元车款、2000 元选装基金、终身 24 小时救援服务、10 万 3 年 0 息金融政策、3000 元置换/ 1000 元增购补贴、6 年/…

【Java 进阶篇】MySQL主键约束详解

MySQL是一个强大的关系型数据库管理系统&#xff0c;用于存储和管理大量数据。在数据库中&#xff0c;主键约束是一项非常重要的概念&#xff0c;它有助于确保数据的完整性和唯一性。本文将详细介绍MySQL主键约束&#xff0c;包括什么是主键、为什么需要主键、如何创建主键以及…

自增自减运算符i++与++i的区别

自增自减运算符用作前缀与用作后缀时略有不同。 i和i的区别&#xff1a; 1、i 返回原来的值&#xff0c;i 返回加1后的值。&#xff08; a i 是先给 a 赋值&#xff0c;然后 i 再自增&#xff1b;a i是 i 先自增&#xff0c;然后给 a 赋值。&#xff09; #include<iost…

(2023|ICLR,检索引导,交叉引导,EntityDrawBench)Re-Imagen:检索增强的文本到图像生成器

Re-Imagen: Retrieval-augmented text-to-image generator 公众号&#xff1a;EDPJ&#xff08;添加 VX&#xff1a;CV_EDPJ 或直接进 Q 交流群&#xff1a;922230617 获取资料&#xff09; 目录 0. 摘要 1. 简介 2. 相关工作 3. 模型 3.1 预备知识 3.2 用多模态知识…

msvcp140.dll丢失的解决方法与msvcp140.dll是什么东西详细解析

在使用电脑时&#xff0c;可能会遇到打开软件时提示“找不到 msvcp140.dll&#xff0c;无法继续执行代码”的问题。这通常意味着你的计算机上缺少 Microsoft Visual C Redistributable 的运行时库&#xff0c;或者该库的版本不正确。下面是我找了几天的修复方法&#xff0c;今天…

PBR的应用

项目拓扑与项目需求 项目需求&#xff1a;某企业网络拥有三个出口&#xff0c;分别使用AR1、AR2、AR3链接运营商网络。其中AR1为万兆出口&#xff0c;而AR2、AR3为千兆出口。现在需要实现以下需求&#xff1a; 希望vlan10的流量能够强制通过AR1作为业务的出口&#xff0c;vla…