长短时记忆网络(LSTM)负荷预测项目(matlab)

news2024/9/25 23:15:46

目录

 

1. LSTM介绍  

2. 数据集准备及预处理

3.  LSTM模型搭建与训练

 4. 预测模型测试

1. LSTM介绍  

     长短期记忆网络 LSTM(long short-term memory)是 RNN 的一种变体,其核心概念在于细胞状态以及“门”结构。细胞状态相当于信息传输的路径,让信息能在序列连中传递下去。你可以将其看作网络的“记忆”。理论上讲,细胞状态能够将序列处理过程中的相关信息一直传递下去。因此,即使是较早时间步长的信息也能携带到较后时间步长的细胞中来,这克服了短时记忆的影响。信息的添加和移除我们通过“门”结构来实现,“门”结构在训练过程中会去学习该保存或遗忘哪些信息。LSTM网络介绍

2. 数据集准备及预处理

       加载、清理和划分数据集。

DateTimeTemperatureHumidityWind Speedgeneral diffuse flowsdiffuse flowsZone 1 Power ConsumptionZone 2  Power ConsumptionZone 3  Power Consumption
1/1/2017 0:006.55973.80.0830.0510.11934055.716128.8820240.96
1/1/2017 0:106.41474.50.0830.070.08529814.6819375.0820131.08
1/1/2017 0:206.31374.50.080.0620.129128.119006.6919668.43
1/1/2017 0:306.121750.0830.0910.09628228.8618361.0918899.28
1/1/2017 0:405.92175.70.0810.0480.08527335.717872.3418442.41
1/1/2017 0:505.85376.90.0810.0590.10826624.8117416.4118130.12
1/1/2017 1:005.64177.70.080.0480.09625998.9916993.3117945.06
1/1/2017 1:105.49678.20.0850.0550.09325446.0816661.417459.28
1/1/2017 1:205.67878.10.0810.0660.14124777.7216227.3617025.54
1/1/2017 1:305.49177.30.0820.0620.11124279.4915939.2116794.22
1/1/2017 1:405.51677.50.0810.0510.10823896.7115435.8716638.07
close all
clear
clc
tbl = readtable("国外负荷预测数据集.csv");%读取负荷预测数据
tbl.DateTime = datetime(tbl.DateTime,'InputFormat','dd/MM/yyyy HH:mm');%修改读取时间的格式

tbl = rmmissing(tbl);%数据预处理
head(tbl)
tbl = tbl(:, [1 end-2:end]);%提取3个中心城区负荷消耗数据
head(tbl)
figure
stackedplot(tbl,'XVariable','DateTime')%绘制趋势分布图
title("国外负荷预测数据集")
data = groupSequences(tbl, "DateTime");
[train_data, val_data, test_data] = splitSequence(data);%划分训练测试验证集
muPredictors = mean(cat(2, train_data{:, 1}), 2);
sigmaPredictors = std(cat(2,train_data{:, 1}), 0, 2);

muResponses = mean(cat(2, train_data{:, 2}), 2);
sigmaResponses = std(cat(2, train_data{:, 2}), 0, 2);

for i = 1:size(train_data, 1)
    train_data{i, 1} = (train_data{i, 1} - muPredictors) ./ sigmaPredictors;
    train_data{i, 2} = (train_data{i, 1} - muResponses) ./ sigmaResponses;

    val_data{i, 1} = (val_data{i, 1} - muPredictors) ./ sigmaPredictors;
    val_data{i, 2} = (val_data{i, 1} - muResponses) ./ sigmaResponses;

    test_data{i, 1} = (test_data{i, 1} - muPredictors) ./ sigmaPredictors;
    test_data{i, 2} = (test_data{i, 1} - muResponses) ./ sigmaResponses;
end
负荷分布

groupSequences程序:

function data = groupSequences(tbl, groupByColumn)
arguments
    tbl table
    groupByColumn (1, 1) string
end

if isa(tbl{1, groupByColumn}, "datetime")
    indexes = unique(dateshift(tbl{:, groupByColumn}, "start", "day"), "rows", "stable");
else
    indexes = unique(tbl{:, groupByColumn}, "rows", "stable");
end
indexes = sort(indexes, "ascend");

numIdxs = length(indexes);
data = cell(numIdxs, 1);
if isa(tbl{1, groupByColumn}, "datetime")
    for idx = 1:numIdxs
        data{idx} = tbl{dateshift(tbl{:, groupByColumn}, "start", "day") == indexes(idx), (tbl.Properties.VariableNames ~= groupByColumn)}';
    end
else
    for idx = 1:numIdxs
        data{idx} = tbl{tbl{:, groupByColumn} == indexes(idx), (tbl.Properties.VariableNames ~= groupByColumn)}';
    end
end

end

splitSequence程序:

function [train, val, test] = splitSequence(data, val_perc, test_perc)
arguments
    data (:, 1) cell
    val_perc double = 0.1
    test_perc double = 0.1
end

len = size(data, 1);

train = cell(len, 2);
val = cell(len, 2);
test = cell(len, 2);

for i = 1:len
    steps = size(data{i}, 2);
    stepsTrain = floor((1 - val_perc - test_perc) * steps);
    stepsVal = floor(val_perc * steps);

    train{i, 1} = data{i}(:, 1:stepsTrain-1);
    train{i, 2} = data{i}(:, 2:stepsTrain);
    
    val{i, 1} = data{i}(:, (stepsTrain + 1):(stepsTrain + stepsVal - 1));
    val{i, 2} = data{i}(:, (stepsTrain + 2):(stepsTrain + stepsVal));

    test{i, 1} = data{i}(:, (stepsTrain + stepsVal + 1):(end - 1));
    test{i, 2} = data{i}(:, (stepsTrain + stepsVal + 2):end);
end

end

3.  LSTM模型搭建与训练

       负荷预测数据集包含3个区域负荷的基础特征。模型搭建:

features = 3;
% Hyperparameters
hidden_units = 256;
max_epochs = 3000;
epoch_drop_period = 30;
batch_size = 32;
grad_thresh = 1;
ilr = 1e-2;%学习率
layers = [
    sequenceInputLayer(features)
    fullyConnectedLayer(hidden_units)
    lstmLayer(hidden_units, "OutputMode", "sequence")
    dropoutLayer(0.5)
    fullyConnectedLayer(features)
    regressionLayer
    ]
模型参数分析

    模型训练超参数设置:优化器选择带动量的随机梯度下降算法

opts = trainingOptions("sgdm", ...
    "MaxEpochs", max_epochs, ...
    "MiniBatchSize", batch_size, ...
    "ValidationData", {val_data(:, 1), val_data(:, 2)}, ...
    "GradientThreshold", grad_thresh, ...
    "InitialLearnRate", ilr, ...
    "LearnRateSchedule", "piecewise", ...
    "LearnRateDropPeriod", epoch_drop_period, ...
    "Shuffle", "every-epoch", ...
    "Plots", "training-progress", ...
    "Verbose", true ...
    )
net = trainNetwork(train_data(:, 1), train_data(:, 2), layers, opts);
训练过程曲线

 4. 预测模型测试

         使用测试数据集进行预测并计算均方根误差(RMSE)。此外,从序列的RMSE绘制直方图,其显示与RMSE矩阵的特定值相对应的误差量。最后,绘制了测试数据集中第一个序列的地面真相和预测,以查看两者之间的差异。

test_preds = predict(net, test_data(:, 1));

rmse = zeros(size(test_preds, 1), 1);
for i = 1:size(test_preds,1)
    rmse(i) = sqrt(mean((test_preds{i} - test_data{i, 2}).^2,"all"));
end
mrmse = mean(rmse);
clear i

figure
histogram(rmse)
xlabel("RMSE")
ylabel("Frequency")
title("Test Mean RMSE := " + num2str(mrmse))

tbl1 = table(test_data{1, 2}(1, :)', test_data{1, 2}(2, :)', test_data{1, 2}(3, :)', 'VariableNames', ["Zone 1", "Zone 2", "Zone 3"]);
tbl2 = table(test_preds{1}(1, :)', test_preds{1}(2, :)', test_preds{1}(3, :)', 'VariableNames', ["Zone 1", "Zone 2", "Zone 3"]);
figure
stackedplot(tbl1)
title( "真实值")
stackedplot(tbl2)
title( "预测值")
save powerConsumptionNet.mat

博客中涉及一些网络资源,如有侵权请联系删除。

该项目实现过程中的不足之处:没有利用天气特征进行负荷预测(后续优化)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/97593.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

每日一道LeetCode(一):两数之和

写在前面 Hello大家好, 我是【麟-小白】,一位软件工程专业的学生,喜好计算机知识。希望大家能够一起学习进步呀!本人是一名在读大学生,专业水平有限,如发现错误或不足之处,请多多指正&#xff0…

刷完这1000道JAVA面试题,让你成功逆袭上岸

内容涵盖:Java、MyBatis、ZooKeeper、Dubbo、Elasticsearch、Memcached、Redis、MySQL、Spring、Spring Boot、Spring Cloud、RabbitMQ、Kafka、Linux 等技术栈。 由于整个笔记比较全面,内容相当的多 ,这里仅展示面经中的面试真题&#xff0…

Keras深度学习实战(41)——语音识别

Keras深度学习实战(41)——语音识别0.前言1. 模型与数据集分析1.1 数据集分析1.2 模型分析2. 语音识别模型2.1 数据加载与预处理2.2 模型构建与训练小结系列链接0.前言 语音识别(Automatic Speech Recognition, ASR,或称语音转录文本)使声音…

openssl加密base64编码

openssl OpenSSL 是一个安全套接字层密码库,囊括主要的密码算法、常用的密钥和证书封装管理功能及SSL协议,并提供丰富的应用程序供测试或其它目的使用。 首先,要安装 openssl: centos命令: sudo yum install openssl-devel ubuntu命令&#x…

WebService基于Baidu OCR和Map API的导航服务

哈尔滨工业大学国家示范性软件学院 《面向服务的软件系统》大作业 项目题目: 基于OCR和地图API的路牌定位与导航服务 项目组成员: 姓名 学号 李启明 120L021920 完成日期: 2022年 12 月 15 日 1.选题 1.1 作业…

NUS CS5477 assignment1

课程链接三维视觉 作业任务任务 课程任务就一个,实现一个Linear Sweep Algorithm,这个算法是用来检测两张图片之间的对应点。 因为SIFT检测如果把检测点的数量增大,可能会存在一些错误错误检测点,所有通常把SIFT检测的点的数量…

内网穿透:在家远程ssh访问学校内部网服务器

注册一个cpolar账号 cpolar官网注册即可(邮箱即可) cpolar支持http/https/tcp协议,不限制流量(花生壳免费只能使用1G流量),也不需要公网ip,只要在服务器上安装客户端即可配置,免费&…

攻防世界-file_include

题目 访问路径获得源码 <?php highlight_file(__FILE__);include("./check.php");if(isset($_GET[filename])){$filename $_GET[filename];include($filename);} ?> 通过阅读php代码&#xff0c;我们明显的可以发现&#xff0c;这个一个文件包含的类型题…

Java项目:ssm校内超市管理系统

作者主页&#xff1a;源码空间站2022 简介&#xff1a;Java领域优质创作者、Java项目、学习资料、技术互助 文末获取源码 项目介绍 本系统分为管理员与普通用户两种角色。采用后端SSM框架&#xff0c;前端BootStrap&#xff08;前后端不分离&#xff09;的系统架构模式&#x…

python中调用命令行执行外部程序

&#x1f31e;欢迎来到python的世界 &#x1f308;博客主页&#xff1a;卿云阁 &#x1f48c;欢迎关注&#x1f389;点赞&#x1f44d;收藏⭐️留言&#x1f4dd; &#x1f31f;本文由卿云阁原创&#xff01; &#x1f320;本阶段属于练气阶段&#xff0c;希望各位仙友顺利完成…

STM32的三种更新固件的方式

说明&#xff1a; stm32有三种更新固件的方式&#xff0c;分别为&#xff08;1&#xff09;DFU模式&#xff08; Development Firmware Upgrade 即“开发固件升级”&#xff09;&#xff1b;&#xff08;2&#xff09;SWD/JLINK 下载 &#xff08;3&#xff09;第三方bootload…

NoSQL数据库原理与应用综合项目——HBase篇

NoSQL数据库原理与应用综合项目——HBase篇 文章目录NoSQL数据库原理与应用综合项目——HBase篇0、 写在前面1、本地数据或HDFS数据导入到HBase2、Hbase数据库表操作2.1 Java API 连接HBase2.2 查询数据2.3 插入数据2.4 修改数据2.5 删除数据3、Windows远程连接HBase4、数据及源…

springboot常用组件集成

今天与大家分享spring-mybatis、reids集成&#xff0c;druid数据库连接池。如果有问题&#xff0c;望指教。 1. 创建项目 File -> New -> project ...Spring Initializr选择项目需要的第三方组件注&#xff1a;可以参考第二次课演示的操作步骤&#xff0c;有详细的拷图…

java药店网站药店系统药店源码刷脸支付源码

简介 首页&#xff0c;搜索商品&#xff0c;详情页&#xff0c;根据不同规格显示不同的商品价格&#xff0c;加入购物车&#xff0c;立即购买&#xff0c;评价列表展示&#xff0c;商品详情展示&#xff0c;商品评分&#xff0c;分类商品&#xff0c;标签查询&#xff0c;更多…

MapReduce 概述原理说明

文章目录MapReduce概述一、MapReduce定义二、MapReduce 优缺点1、MapReduce 优点(1)、MapReduce 易于编程(2)、良好的扩展性(3)、高容错性(4)、适合PB级以上的海量数据的离线处理2、MapReduce 缺点(1)、不擅长实时计算(2)、不擅长流式计算(3)、不擅长DAG(有向图)计算三、MapRed…

二叉树进阶

博主的博客主页&#xff1a;CSND博客 Gitee主页&#xff1a;博主的Gitee 博主的稀土掘金&#xff1a;稀土掘金主页 博主的b站账号&#xff1a;程序员乐 公众号——《小白技术圈》&#xff0c;回复关键字&#xff1a;学习资料。小白学习的电子书籍都在这。 目录根据二叉树创建字…

基于java+springmvc+mybatis+vue+mysql的协同过滤算法的电影推荐系统

项目介绍 基于协同过滤算法的电影推荐系统利用网络沟通、计算机信息存储管理&#xff0c;有着与传统的方式所无法替代的优点。比如计算检索速度特别快、可靠性特别高、存储容量特别大、保密性特别好、可保存时间特别长、成本特别低等。在工作效率上&#xff0c;能够得到极大地…

Hive自定义UDF函数

以下基于hive 3.1.2版本 Hive中自定义UDF函数&#xff0c;有两种实现方式&#xff0c;一是通过继承org.apache.hadoop.hive.ql.exec.UDF类实现&#xff0c;二是通过继承org.apache.hadoop.hive.ql.udf.generic.GenericUDF类实现。 无论是哪种方式&#xff0c;实现步骤都是&…

网上超市系统

开发工具(eclipse/idea/vscode等)&#xff1a; 数据库(sqlite/mysql/sqlserver等)&#xff1a; 功能模块(请用文字描述&#xff0c;至少200字)&#xff1a; 研究内容&#xff1a;设计开发简单购网上超市系统&#xff0c;采用Java语言&#xff0c;使用ySQL数据库&#xff0c; 实…

毕业设计 单片机家用燃气可视化实时监控报警仪 - 物联网 嵌入式 stm32

文章目录0 前言1 简介2 主要器件3 实现效果4 设计原理4.1 硬件部分4.2 软件部分5 部分核心代码6 最后0 前言 &#x1f525; 这两年开始毕业设计和毕业答辩的要求和难度不断提升&#xff0c;传统的毕设题目缺少创新和亮点&#xff0c;往往达不到毕业答辩的要求&#xff0c;这两…