【LSTM回归预测】attention机制LSTM时间序列回归预测【含Matlab源码 1992期】

news2024/12/23 7:06:33

⛄一、attention机制LSTM预测

1 总体框架
数字货币预测模型分为两部分,由LSTM模块和Attention模块组成。

2 LSTM模块
长短期记忆网络(LSTM)是一种特殊的递归神经网络(RNN)模型,是为了解决RNN模型梯度消失的问题而提出的。在传统的RNN模型当中,当时间跨度较长时,网络权重更新会十分缓慢。针对此问题,LSTM采用了“记忆单元”来存储记忆。下图是LSTM的简单示意图。其中,一个LSTM网络主要包括以下几部分:记忆细胞(Ct)、遗忘门(ft)、输入门(it)、输出门(Ot)。

3 Attention模块
在实际过程中,长时间序列特征的重要程度往往存在差异,而LSTM神经网络对于长时间序列输入没有区分。数字货币价格随着各种因素的变化在不断变化,不同时间点的特征对于数字货币价格预测的影响程度是不同的。在时间序列数据的处理中,Attention机制对长短期记忆网络(LSTM)输出的隐藏层向量ht进行加权求和。此处的权重大小可以理解为“不同时间点上的特征的重要程度”。其中,Hi表示第i时刻输入序列的隐藏层的状态值。然后,我们通过相似性函数Score(Hi,Hk)获取最后一个隐藏层输出Hk与其他每个时间点隐藏层输出Hi的相似度得分ei。相似性函数计算公式如下:
在这里插入图片描述
其中,Dot表示采用点积的方式进行计算。然后,通过公式(8)的soft-max函数可以计算出每个时间点的不同的注意力权重ai,在此基础上计算注意力权重ai与隐藏层状态的乘积,即可得到Attention层的输出向量C。
在这里插入图片描述

⛄二、部分源代码

% QRLSTM
% 数据集(列为特征,行为样本数目
% QRLSTM
% 数据集(列为特征,行为样本数目
clc
clear
close all
load(‘./Train.mat’)
%
Train(1,:) =[];
y = Train.demand;
x = Train{:,3:end};
[xnorm,xopt] = mapminmax(x’,0,1);
[ynorm,yopt] = mapminmax(y’,0,1);
x = x’;
xnorm = xnorm(:,1:1000);
ynorm = ynorm(1:1000);

k = 24; % 滞后长度

% 转换成2-D image
for i = 1:length(ynorm)-k

Train_xNorm(:,i,:) = xnorm(:,i:i+k-1);
Train_yNorm(i) = ynorm(i+k-1);
Train_y(i) = y(i+k-1);

end
Train_yNorm= Train_yNorm’;

ytest = Train.demand(1001:1170);
xtest = Train{1001:1170,3:end};
[xtestnorm] = mapminmax(‘apply’, xtest’,xopt);
[ytestnorm] = mapminmax(‘apply’,ytest’,yopt);
xtest = xtest’;
for i = 1:length(ytestnorm)-k
Test_xNorm(:,i,:) = xtestnorm(:,i:i+k-1);
Test_yNorm(i) = ytestnorm(i+k-1);
Test_y(i) = ytest(i+k-1);
end
Test_yNorm = Test_yNorm’;

clear k i x y
%

%% 训练集和验证集划分
TrainSampleLength = length(Train_yNorm);
validatasize = floor(TrainSampleLength * 0.1);
Validata_xNorm = Train_xNorm(:,end - validatasize:end,:);
Validata_yNorm = Train_yNorm(:,TrainSampleLength-validatasize:end);
Validata_y = Train_y(TrainSampleLength-validatasize:end);

Train_xNorm = Train_xNorm(:,1:end-validatasize,:);
Train_yNorm = Train_yNorm(:,1:end-validatasize);
Train_y = Train_y(1:end-validatasize);
%%

[params,~] = paramsInit(numhidden_units1,inputSize,outputSize); % 导入初始化参数

[~,validatastate] = paramsInit(numhidden_units1,inputSize,outputSize); % 导入初始化参数
[~,TestState] = paramsInit(numhidden_units1,inputSize,outputSize); % 导入初始化参数
% 训练相关参数
TrainOptions;
numIterationsPerEpoch = floor((TrainSampleLength-validatasize)/minibatchsize);
LearnRate = 0.01;
%% Loop over epochs.
figure
start = tic;
lineLossTrain = animatedline(‘color’,‘r’);
validationLoss = animatedline(‘color’,‘g’,‘Marker’,‘o’);
xlabel(“Iteration”)
ylabel(“Loss”)

% epoch 更新
iteration = 0;
for epoch = 1 : numEpochs

[~,state] = paramsInit(numhidden_units1,inputSize,outputSize);       % 每轮epoch,state初始化
disp(['Epoch: ', int2str(epoch)])

% batch 更新
for i = 1 : numIterationsPerEpoch

    
    dlX = gpuArray(Train_xNorm(:,idx,:));
    dlY = gpuArray(Train_yNorm(idx));
    [gradients,loss,state] = dlfeval(@Model2,dlX,dlY,params,state);
    
    % L2正则化

% L2regulationFactor = 0.000011;
% gradients = dlupdate( @(g,parameters) L2Regulation(g,parameters,L2regulationFactor),gradients,params);
% gradients = dlupdate(@(g) thresholdL2Norm(g, gradientThreshold),gradients);

    [params,averageGrad,averageSqGrad] = adamupdate(params,gradients,averageGrad,averageSqGrad,iteration,LearnRate);
    
    
    
    % 验证集测试
    if iteration == 1 || mod(iteration,validationFrequency) == 0
        output_Ynorm = ModelPredict(gpuArray(Validata_xNorm),params,validatastate);
        lossValidation = mse(output_Ynorm, gpuArray(Validata_yNorm));
    end

    
    % 作图(训练过程损失图)--------------------------********————————————————————————————————————————————————

end

% 每轮epoch 更新学习率
if mod(epoch,5) == 0
    LearnRate = LearnRate * LearnRateDropFactor;
end

end

%% 训练集
Predict_yNorm = Model2Predict(gpuArray(Train_xNorm),params,TestState);
Predict_yNorm = extractdata(Predict_yNorm);

Predict_y = mapminmax(‘reverse’,Predict_yNorm,yopt);
%
figure
plot(Predict_y,‘r’);
hold on
plot(Train_y,‘g’)
legend(‘训练集预测值’,‘训练集实际值’)
%% 验证集
Predict_yNorm = Model2Predict(gpuArray(Validata_xNorm),params,TestState);
Predict_yNorm = extractdata(Predict_yNorm);

Predict_y = mapminmax(‘reverse’,Predict_yNorm,yopt);
%
figure
plot(Predict_y,‘r’);
hold on
plot(Validata_y,‘g’)
legend(‘验证集预测值’,‘验证集实际值’)

%% predict(测试集)
Predict_yNorm = Model2Predict(gpuArray(Test_xNorm),params,TestState);
Predict_yNorm = extractdata(Predict_yNorm);

Predict_y = mapminmax(‘reverse’,Predict_yNorm,yopt);
%
figure
plot(Predict_y,‘r’);
hold on
plot(Test_y,‘g’)
legend(‘测试集预测值’,‘测试集实际值’)


## ⛄三、运行结果
![在这里插入图片描述](https://img-blog.csdnimg.cn/3c38b271e7d64878a92eee75aada5608.png#pic_center)
![在这里插入图片描述](https://img-blog.csdnimg.cn/6f3ea04d05ea401daf8b3dce1cb3c44d.png#pic_center)
![在这里插入图片描述](https://img-blog.csdnimg.cn/2e3f58b9ddc7484ba0cd95775caf50a9.png#pic_center)

## ⛄四、matlab版本及参考文献
**1 matlab版本**
2014a

**2 参考文献**
[1] 万浩睿,李成林,赵开明,王睿林.基于Attention机制的LSTM数字货币预测模型[J].长江信息通信. 2022,35(05)

**3 备注**
简介此部分摘自互联网,仅供参考,若侵权,联系删除
 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/25960.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Codeforces Round #835 (Div. 4)】A——G题解

文章目录A Medium Number题意思路代码B Atillas Favorite Problem题意思路代码C Advantage题意思路代码D Challenging Valleys题意思路代码E Binary Inversions题意思路代码F Quests题意思路代码G SlavicGs Favorite Problem题意思路代码A Medium Number 题意 三个数&#xf…

窗口-视口转换(详细)

在QPainter中,绘制图像使用逻辑坐标绘制,然后再转化为绘图设备的物理坐标。 窗口(window):表示逻辑坐标下的相同矩形视口(viewport):表示物理坐标下的指定的一个任意矩形默认情况&am…

中国互联网大会天翼云展区大揭秘!

11月15日,由工业和信息化部、深圳市人民政府主办,中国互联网协会、广东省通信管理局、深圳市工业和信息化局等单位承办的2022(第二十一届)中国互联网大会在深圳开幕。本届大会以“发展数字经济 促进数字文明”为主题,聚…

单商户商城系统功能拆解34—应用中心—分销应用

单商户商城系统,也称为B2C自营电商模式单店商城系统。可以快速帮助个人、机构和企业搭建自己的私域交易线上商城。 单商户商城系统完美契合私域流量变现闭环交易使用。通常拥有丰富的营销玩法,例如拼团,秒杀,砍价,包邮…

在IDEA中搭建Spring5.2.x版本源码(~附带完整过程和图示~)

1.开发环境 JDK8IntelliJ IDEA 2019.1.4 gradle 5.6.4git 2.33.0 2.操作步骤 下载并安装git 进入https://git-scm.com/downloads,下载对应操作系统的git版本一直点击next安装即可记得配置环境变量 获取Spring源码 使用clone的方式将源码拉取到本地,方便…

Java递归查询树形结构(详解)

一.数据准备 数据库表结构如下所示, INSERT INTO jg_user.com_type(type_id, parent_id, type_name) VALUES (1, 0, 合伙企业); INSERT INTO jg_user.com_type(type_id, parent_id, type_name) VALUES (2, 0, 有限公司); INSERT INTO jg_user.com_type(type_id, p…

力扣(LeetCode)878. 第 N 个神奇数字(C++)

二分查找数论 数论知识——辗转相除法、容斥原理。 辗转相除求最大公约数&#xff0c;两数相乘除以最大公约数&#xff0c;就是最小公倍数。 容斥原理求最多不重复元素&#xff0c;最大不重复面积。 <小学数奥> 从数据范围里&#xff0c;用容斥原理找 a/ba/ba/b 的倍数个…

Pytorch 下 TensorboardX 使用

这里引用一下在其他地方看到的一些话&#xff1a; tensorboard做为Tensorflow中强大的可视化工具&#xff0c;已经被普遍使用。 但针对其余框架&#xff0c;例如Pytorch&#xff0c;以前一直没有这么好的可视化工具可用&#xff0c;PyTorch框架自己的可视化工具是Visdom&…

实验九 数据微积分与方程数值求解(matlab)

实验九 数据微积分与方程数值求解 1.1实验目的 1.2实验内容 1.3流程图 1.4程序清单 1.5运行结果及分析 1.6实验的收获与体会 1.1实验目的 1&#xff0c;掌握求数值导数和数值积分的方法&#xff1b; 2&#xff0c;掌握代数方程数组求解的方法&#xff1b; 3&a…

【Mysql】Centos 7.6安装Mysql8

这里centos为阿里云默认镜像。 一、卸载历史历史版本 1、检查是否有服务启动 # service mysqld status 2、停止mysql服务 # service mysqld stop 3、查看mysql历史安装组件 # rpm -qa|grep mysqlmysql-libs-5.1.71-1.el6.x86_64 4、卸载组件 # rpm -e --nodeps mysql…

2022世界VR产业大会圆满收官,酷雷曼惊艳亮相!

11月14日&#xff0c;由工业和信息化部、江西省人民政府联合主办的全球VR领域规模最大、规格最高、影响最广的年度盛会——2022世界VR产业大会在江西南昌圆满落下帷幕。 本次大会得到了党中央、国务院的高度重视&#xff0c;国务委员王勇出席大会开幕式并讲话&#xff1b;大会邀…

【转】DNS隧道检测特征

原文链接&#xff1a;DNS隧道检测特征总结 - 知乎 一、摘要 企业内网环境中&#xff0c;DNS协议是必不可少的网络通信协议之一&#xff0c;为了访问互联网和内网资源&#xff0c;DNS提供域名解析服务&#xff0c;将域名和IP地址进行转换。网络设备和边界防护设备在一般的情况…

C++:内存管理:C++内存管理详解

C语言内存管理是指&#xff1a;对系统的分配、创建、使用这一系列操作。在内存管理中&#xff0c;由于是操作系统内存&#xff0c;使用不当会造成很麻烦的后果。本文将从系统内存的分配、创建出发&#xff0c;并且结合例子来说明内存管理不当会造成的结果以及解决方案。 一&am…

【Spring】Spring AOP的实现原理

目录 什么是AOP AOP的作用 AOP的优点 AOP框架 Spring AOP AspectJ 术语 1.Target ——目标类 2.Joinpoint ——连接点 3.Pointcut——切入点 4.Advice——通知/增强 5.Weaving——植入 6.Proxy——代理类 7.Aspect——切面 底层逻辑 开发流程 1.导入依…

八、手把手教你搭建SpringCloudAlibaba之Sentinel服务降级之慢调用

SpringCloud Alibaba全集文章目录&#xff1a; 零、手把手教你搭建SpringCloudAlibaba项目 一、手把手教你搭建SpringCloud Alibaba之生产者与消费者 二、手把手教你搭建SpringCloudAlibaba之Nacos服务注册中心 三、手把手教你搭建SpringCloudAlibaba之Nacos服务配置中心 …

贪心算法应用

1. 算法思想 贪心算法一般分为如下四步&#xff1a; 将问题分解为若干个子问题找出适合的贪心策略求解每一个子问题的最优解将局部最优解堆叠成全局最优解 即选择每一阶段的局部最优&#xff0c;从而达到全局最优。 2. 最大自序和 题目描述 题目链接 给你一个整数数组 n…

用Servlet 编写hello world

第一次接触 Servlet&#xff0c;使用 Servlet 编写代码并配合 Tomcat 在浏览器中展示代码效果&#xff0c;也算是自己这么长时间学习java的一次飞升吧。 本文总结了六步使用 Servlet 编写 helloworld步骤&#xff0c;希望给初学Servlet的同学一些帮助。 0、准备工作 1&#xff…

Python 随机函数random详解

介绍这7个随机数的方法应用&#xff1a; 1、random.random&#xff08;&#xff09;查看源码 说明&#xff1a;用于生成一个0到1的随机符点数: 0 < x < 1.0 import random for i in range(10):print(random.random()) 2、random.uniform&#xff08;&#xff09;查看源码…

深度学习(14)—— 关于Tensorboard

深度学习&#xff08;14&#xff09;—— 关于Tensorboard 文章目录深度学习&#xff08;14&#xff09;—— 关于Tensorboard前言1. “一参数一图”2.“多参数一图”3. “一栏一图”4. “一栏多图”在模型训练过程中loss和acc都会发生变化&#xff0c;常常需要记录这些值&…

Spring Security(2)

您好&#xff0c;我是湘王&#xff0c;这是我的CSDN博客&#xff0c;欢迎您来&#xff0c;欢迎您再来&#xff5e; 前面已经把需要的环境准备好了&#xff0c;包括数据库和SQL语句&#xff0c;现在再来写代码。至于安装MySQL什么的就跳过去了&#xff0c;娘度子里面一大把。 先…