基于pytorch LSTM 的股票预测

news2025/1/11 21:51:31

学习记录于《PyTorch深度学习项目实战100例》
https://weibaohang.blog.csdn.net/article/details/127365867?ydreferer=aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L20wXzQ3MjU2MTYyL2NhdGVnb3J5XzEyMDM2MTg5Lmh0bWw%2Fc3BtPTEwMDEuMjAxNC4zMDAxLjU0ODI%3D

1.tushare

Tushare是一个免费、开源的Python财经数据接口包。主要用于提供股票及金融市场相关的数据,是国内常用的金融数据分析工具之一。Tushare对于投资研究、教学、项目背景调研等多种需求提供了极大的方便。

Tushare的主要特点有:

数据丰富:Tushare提供了包括实时行情数据、历史行情数据、基本面数据、宏观经济数据、公司基本信息、大盘指数数据、行业数据、新闻和公告等多种数据。

接口简单:使用Python调用Tushare接口相当简单。即便是初学者,只需数行代码,就能获取所需的金融数据。

社区活跃:由于Tushare的免费和开源特性,其社区相当活跃,有大量的开发者和爱好者进行维护和更新。

扩展性强:除了官方提供的数据接口,Tushare也支持自定义数据扩展,可以通过插件方式实现数据的快速扩展。

要使用Tushare,用户通常需要先进行安装,可以通过pip轻松完成。之后,通过简单的API调用,就可以获取所需的数据。

!pip install tushare

2.导入必要的依赖

# 1.加载股票数据
pro = ts.pro_api('your-key')
df = pro.daily(ts_code='000001.SZ', start_date='20130711', end_date='20230711')

df.index = pd.to_datetime(df.trade_date)  # 索引转为日期
df = df.iloc[::-1]  # 由于获取的数据是倒序的,需要将其调整为正序

如何获取your-key

注册网站
https://tushare.pro/register

注册后进入个人主页

在这里插入图片描述

找到接口token

在这里插入图片描述
复制后粘贴 到your-key 里面

在这里插入图片描述

# 创建一个 StandardScaler 实例
scaler = StandardScaler()
scaler_model = StandardScaler()

# 使用 scaler 对数据进行拟合和转换
data = scaler_model.fit_transform(np.array(df[['open', 'high', 'low', 'close']]).reshape(-1, 4))

# 使用 scaler 对 'close' 列进行拟合和转换
scaler.fit_transform(np.array(df['close']).reshape(-1, 1))

分开训练集和测试集

def split_data(data,timestep):
  dataX= [] # 用于存储输入序列
  dataY= [] # 用于存储输出数据

  #将整个窗口的数据保存到X中,将未来的一天保存到Y中
  for index in range(len(data)-timestep):
    # 提取时间窗口内的数据作为输入序列
    dataX.append(data[index: index + timestep])
    # 下一个时间步的数据作为输出
    dataY.append(data[index + timestep][3]) #输出为第四列('close' 列)

  dataX = np.array(dataX)
  dataY = np.array(dataY)

  #获取训练集大小
  train_size = int(np.round(0.8 * dataX.shape[0]))

  # 划分训练集,测试集
  x_train = dataX[: train_size, :].reshape(-1, timestep, 4)
  y_train = dataY[: train_size]

  x_test = dataX[train_size:, :].reshape(-1, timestep, 4)
  y_test = dataY[train_size:]


  return[x_train,y_train,x_test,y_test]
# 3.获取训练数据, x_train:1750,1,4
x_train,y_train,x_test,y_test = split_data(data,timestep=1)

探究数据集

在这里插入图片描述

# 4.将数据转为tensor
x_train_tensor = torch.from_numpy(x_train).to(torch.float32)
y_train_tensor = torch.from_numpy(y_train).to(torch.float32)
x_test_tensor = torch.from_numpy(x_test).to(torch.float32)
y_test_tensor = torch.from_numpy(y_test).to(torch.float32)

在这里插入图片描述

# 5.形成训练数据集
train_data = TensorDataset(x_train_tensor, y_train_tensor)
test_data = TensorDataset(x_test_tensor, y_test_tensor)

# 6.将数据加载成迭代器
batch_size =16
train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size,
                                           True)

test_loader = torch.utils.data.DataLoader(test_data,
                                          batch_size,
                                          False)

LSTM

LSTM,即长短时记忆(Long Short-Term Memory)网络,是一种特殊的循环神经网络(Recurrent Neural Network, RNN)结构。LSTM旨在解决传统RNN在处理长序列数据时容易出现的梯度消失或梯度爆炸问题。

以下是LSTM的主要组成部分和功能:

遗忘门 (Forget Gate): 决定从单元状态中删除什么信息。使用sigmoid函数,输出0表示“完全忘记”,输出1表示“完全保留”。

输入门 (Input Gate): 有两部分组成。第一部分是sigmoid层,决定哪些值我们将更新。第二部分是tanh层,创建一个新的候选值向量,它可能被加到状态中。

单元状态 (Cell State): LSTM的核心部分,其在整个链上都有运行,只有一些少量的线性交互。单元状态类似于传送带,信息可以在其上自由流动,除非受到遗忘门或输入门的影响。

输出门 (Output Gate): 决定基于单元状态输出什么值。首先,sigmoid层决定我们将输出哪些部分。然后,将单元状态通过tanh(得到值在-1到1之间)并乘以sigmoid层的输出,这样只输出我们决定的部分。

LSTM的关键优势在于其能够记住长期的依赖关系。在许多序列任务中,当前的输出不仅仅依赖于前几个步骤,还可能依赖于很早之前的步骤。LSTM相比于普通的RNN更能够捕捉这些长期依赖关系。

实际上,LSTM的变种还有很多,例如GRU(Gated Recurrent Units)。不过,无论其结构如何调整,LSTM的核心思想是利用不同的门控结构来有选择地控制信息流,从而更好地学习和记住长期依赖关系
在这里插入图片描述

class LSTM(nn.Module):
  def __init__(self,input_dim,hidden_dim,num_layers,output_dim):
    super(LSTM,self).__init__()
    self.hidden_dim = hidden_dim #隐藏层大小
    self.num_layers = num_layers #LSTM层数
    # input_dim 为特征维度,就是每个时间点对应的特征数量,这里为4
    self.lstm = nn.LSTM(input_dim,hidden_dim,num_layers,batch_first=True)
    self.fc = nn.Linear(hidden_dim,output_dim)

  def forward(self,x):
    # Initialize hidden state and cell state
    h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim)
    c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim)

    # Forward propagate through LSTM
    output, (h_n, c_n)= self.lstm(x, (h0, c0))
    batch_size, timestep, hidden_dim = output.shape

    # 将output变成 batch_size * timestep, hidden_dim
    output = output.reshape(-1, hidden_dim)
    output = self.fc(output)  # 形状为batch_size * timestep, 1
    output = output.reshape(timestep, batch_size, -1)
    return output[-1]  # 返回最后一个时间片的输出

model = LSTM(input_dim, hidden_dim, num_layers, output_dim)  # 定义LSTM网络

打印模型
在这里插入图片描述

定义损失函数

loss_function = nn.MSELoss()  # 定义损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # 定义优化器

配置信息

timestep = 1  # 时间步长,就是利用多少时间窗口
batch_size = 16  # 批次大小
input_dim = 4  # 每个步长对应的特征数量,就是使用每天的4个特征,最高、最低、开盘、落盘
hidden_dim = 64  # 隐层大小
output_dim = 1  # 由于是回归任务,最终输出层大小为1
num_layers = 3  # LSTM的层数
epochs = 10
best_loss = 0
model_name = 'LSTM'
save_path = './{}.pth'.format(model_name)

train

# 8. model 训练
save_path = '/content/sample_data/best.pth'
for epoch in range(epochs):
  model.train()
  running_loss= 0 #初始loss值
  train_bar = tqdm(train_loader) #通过使用tqdm来展示进度条
  for data in train_bar:
    x_train, y_train = data #
    # print(x_train.shape,y_train.shape) torch.Size([16, 1, 4]) torch.Size([16])
    optimizer.zero_grad()
    y_train_pred = model(x_train)
    loss = loss_function(y_train_pred,y_train.reshape(-1,1))
    loss.backward()
    optimizer.step()

    running_loss += loss.item()
    train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,epochs,loss) #每次你计算 loss 并更新 train_bar.desc,进度条的描述就会更新,从而实时显示你的训练 loss 和当前 epoch。
    torch.save(model.state_dict(),save_path)

加载模型

# 加载权重和偏置
model.load_state_dict(torch.load(save_path))

模型验证

# 模型验证
model.eval()
test_loss = 0
save_path_2="/content/sample_data/best_2.pth"
with torch.no_grad():
   test_bar = tqdm(test_loader)
   for data in test_bar:
      x_test, y_test = data
      y_test_pred = model(x_test)
      test_loss = loss_function(y_test_pred, y_test.reshape(-1, 1))

if test_loss < best_loss:
  best_loss = test_loss
  torch.save(model.state_dict(), save_path_2)

绘制模型

# 9.绘制结果 train 的结果
pred_value =model(x_train_tensor).detach().numpy().reshape(-1, 1)
true_value =y_train_tensor.detach().numpy().reshape(-1, 1)
plt.figure(figsize=(12, 8))
plt.plot(scaler.inverse_transform(pred_value, "b"),label="Preds_value")
plt.plot(scaler.inverse_transform(true_value, "r"),label="Data_value")
plt.legend()
plt.show()

在这里插入图片描述

绘制test的结果

y_test_pred = model(x_test_tensor)
plt.figure(figsize=(12, 8))
plt.plot(scaler.inverse_transform(y_test_pred.detach().numpy()), "b",label="pred_value")
plt.plot(scaler.inverse_transform(y_test_tensor.detach().numpy().reshape(-1, 1)), "r",label='true_value')
plt.legend()
plt.show()

在这里插入图片描述
代码ipynb

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/958050.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C++项目:网络版本在线五子棋对战

目录 1.项目介绍 2.开发环境 3.核心技术 4. 环境搭建 5.websocketpp 5.1原理解析 5.2报文格式 5.3websocketpp常用接口介绍 5.4websocket服务器 6.JsonCpp使用 6.1Json数据格式 6.2JsonCpp介绍 7.MySQL API 7.1MySQL API介绍 7.2MySQL API使用 7.3实现增删改查…

lnmp架构-mysql2

4.mysql 组复制集群 首先对所有的节点重新初始化 因为对节点的数据一致性要求非常高 主从复制的时候 slave只会复制master的binlog日志 就是二进制日志 不会复制relay_log 在server1上 根据实际情况修改主机名和网段 log_slave_updateON 意思就是 当slave的sql线程做完之后…

深度学习基础篇 第一章:卷积

dummy老弟这几天在复习啊我也跟着他重新复习一轮。 这次打算学的细一点&#xff0c;虽然对工作没什么帮助&#xff0c;但是理论知识也能更扎实吧&#xff01; 从0开始的深度学习大冒险。 参考教程&#xff1a; https://www.zhihu.com/question/22298352 https://zhuanlan.zhih…

k8s 启动和删除pod

k8s创建pod pod的启动流程 流程图 运维人员向kube-apiserver发出指令&#xff08;我想干什么&#xff0c;我期望事情是什么状态&#xff09; api响应命令,通过一系列认证授权,把pod数据存储到etcd,创建deployment资源并初始化。(期望状态&#xff09; controller通过list-wa…

C++信息学奥赛1184:明明的随机数

#include <bits/stdc.h> using namespace std; int main() {int n; // 数组长度cin >> n; // 输入数组长度int arr[n]; // 定义整数数组&#xff0c;用于存储输入的整数// 输入数组元素for (int i 0; i < n; i){cin >> arr[i];}int e 0; // 计数器&…

长胜证券:政策暖风不断 静待春暖花开

长胜证券指出&#xff0c;经济数据的逐步企稳上升&#xff0c;能够提振商场对经济复苏的决心&#xff0c;同时弱复苏布景下&#xff0c;政策的刺激力度也将为商场走强供给良好的土壤。暖风持续发布下&#xff0c;多方力量也在悄然间发生变化&#xff0c;重视权重、金融板块回暖…

docker命令学习

docker vscode插件出现的问题 docker命令 docker images &#xff08;查看所有的镜像&#xff09; docker ps -a &#xff08;查看所有的容器&#xff09; docker ps &#xff08;查看运行的容器&#xff09; docker run imageID docker run --gpus all --shm-size8g -it imag…

什么是数字孪生?

推荐&#xff1a;使用 NSDT场景编辑器 快速搭建3D应用场景 走进一家汽车装配厂。看到工人将螺母逐渐减少到螺栓上。听到气动工具的嗡嗡声。观看原始的车身沿着生产线滑行&#xff0c;机器人卷起零件。 现在&#xff0c;在线启动其 3D 数字孪生。看到动画数字人类在完全相同但数…

大数据学习:kafkaManager功能详解

kafkaManager功能详解 一.添加集群 1.1 常用参数说明 下面已常用的选项作说明 1&#xff09;Enable JMX Polling 是否开启 JMX 轮训&#xff0c;该部分直接影响部分 kafka broker 和 topic 监控指标指标的获取&#xff08;生效的前提是 kafka 启动时开启了 JMX_PORT。主要影…

Vue基础1:生命周期汇总(vue2)

Description 生命周期图&#xff1a; 可以理解vue生命周期就是指vue实例从创建到销毁的过程&#xff0c;在vue中分为9个阶段&#xff1a;创建前/后&#xff0c;载入前/后&#xff0c;更新前/后&#xff0c;销毁前/后&#xff0c;其他&#xff1b;常用的有&#xff1a;created&…

Spring容器及实例化

一、前言 Spring 容器是 Spring 框架的核心部分&#xff0c;它负责管理和组织应用程序中的对象&#xff08;Bean&#xff09;。Spring 容器负责创建、配置和组装这些对象&#xff0c;并且可以在需要时将它们提供给应用程序的其他部分。 Spring 容器提供了两种主要类型的容器&…

【Eclipse】搭建python环境;运行第一个python程序helloword

目录 0.环境 1.需准备&搭建思路 2.搭建具体步骤 1&#xff09;查看是否安装过python 2&#xff09;安装eclipse 3&#xff09;安装和配置pyDev 3.创建第一个python程序具体步骤 1&#xff09;新建项目 2&#xff09;输入项目名字&#xff0c;和配置选项 3&#x…

用户角色权限demo后续出现问题和解决

将demo账号给到理解和蒋老师&#xff0c;测试的时候将登录人账号改了&#xff0c;结果登录不了了&#xff0c;后续还需要分配权限无法更改他人的账号和密码 将用户和权限重新分配&#xff08;数据库更改&#xff0c;不要学我&#xff09; 试着登录还是报一样的错&#xff0c;但…

OA项目之用户登录首页展示

目录 本章节目标&#xff1a;完成OA项目用户登录及首页展示 一.用户登录 User.java UserDao.java IUserDao.java UserAction.java login.jsp&#xff08;登录界面&#xff09; userManage.jsp (数据绑定&#xff0c;修改&#xff0c;删除) userEdit.jsp&#xff08;用…

基于相空间重构的混沌背景下微弱信号检测算法matlab仿真,对比SVM,PSO-SVM以及GA-PSO-SVM

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 4.1 SVM 4.2 PSO-SVM 4.3 GA-PSO-SVM 5.算法完整程序工程 1.算法运行效果图预览 SVM: PSO-SVM: GA-PSO-SVM: 以上仿真图参考文献《基于相空间重构的混沌背景下微弱信号检测方法研究》 2.…

AAC和ADTS音频格式解析

1.ADTS是个啥 ADTS全称是(Audio Data Transport Stream),是AAC的一种十分常见的传输格式。 记得第一次做demux的时候,把AAC音频的ES流从FLV封装格式中抽出来送给硬件解码器时,不能播;保存到本地用pc的播放器播时,我靠也不能播。当时崩溃了,后来通过查找资料才知道。一般…

加速关断BJT开关电路

引言&#xff1a;BJT从导通到关闭存在一定的延时&#xff0c;在特定的场景中比如BJT电平转换&#xff0c;高频信号调理&#xff0c;这种延时存在很大的隐患&#xff0c;本节简述如何消除BJT的关断延时。 €1.延时的产生机理 类似于图15-1&#xff0c;晶体管从截止状态切换到导…

干货| ICML2023:作为自适应自进化规划器的扩散模型

点击蓝字 关注我们 AI TIME欢迎每一位AI爱好者的加入&#xff01; 作者介绍 梁志烜 香港大学计算机系直博一年级学生&#xff0c;导师为罗平教授&#xff0c;研究兴趣是生成式机器学习&#xff0c;Embodied AI和Data-centric learning。 报告题目 作为自适应自进化规划器的扩散…

本地部署体验LISA模型(LISA≈图像分割基础模型SAM+多模态大语言模型LLaVA)

GitHub地址&#xff1a;https://github.com/dvlab-research/LISA 该项目论文paper reading&#xff1a;https://blog.csdn.net/Transfattyacids/article/details/132254770 在GitHub上下载源文件&#xff0c;进入下载的文件夹&#xff0c;打开该地址下的命令控制台&#xff0c;…

耐蚀点蚀镀铜工艺

引言 随着5G技术的推出&#xff0c;导致电子电路和IC基板在设计中要求更高的密度。由于5G应用的性质&#xff0c;这些设计中的高可靠性和出色的电气性能也越来越重要。为了满足5G应用和其他下一代设备平台的需求&#xff0c;逐渐建立了使用改良半加成加工(mSAP)制造电路板的制…