回归预测 | MATLAB实现PSO-DNN粒子群算法优化深度神经网络的数据多输入单输出回归预测

news2024/12/27 10:49:02

回归预测 | MATLAB实现PSO-DNN粒子群算法优化深度神经网络的数据多输入单输出回归预测

目录

    • 回归预测 | MATLAB实现PSO-DNN粒子群算法优化深度神经网络的数据多输入单输出回归预测
      • 效果一览
      • 基本介绍
      • 模型描述
      • 程序设计
      • 参考资料

效果一览

1
2
3
4

5
6

基本介绍

回归预测 | MATLAB实现PSO-DNN粒子群算法优化深度神经网络的数据多输入单输出回归预测
MATLAB实现PSO-DNN粒子群算法优化深度神经网络的数据多输入单输出回归预测(Matlab完整程序和数据)
输入7个特征,输出1个,即多输入单输出;
运行环境Matlab2018及以上,运行主程序main即可,其余为函数文件无需运行,所有程序放在一个文件夹,data为数据集;
命令窗口输出RMSE、MAE、R2、MAPE。

模型描述

粒子群算法(Particle Swarm Optimization, PSO)是一种启发式优化算法,通常用于解决复杂的非线性优化问题。而分组卷积神经网络(Grouped Convolutional Neural Network, GCNN)是一种卷积神经网络的变体,可以将输入数据分组进行卷积操作,从而减少计算量和参数数量。
针对数据多输入单输出回归预测问题,可以使用粒子群算法来优化分组卷积神经网络的结构和参数,以提高预测准确性。具体步骤如下:

  • 定义目标函数
    首先需要定义一个目标函数,用于衡量深度神经网络的预测准确性。一般来说,可以使用均方根误差(RMSE)或者均方误差(MSE)作为目标函数。

  • 确定网络结构和参数
    接下来需要确定深度神经网络的结构和参数。可以将这些参数作为粒子的维度,每个粒子代表一个网络结构和参数组合。

  • 初始化粒子群
    随机生成一定数量的粒子,每个粒子代表一个网络结构和参数组合。每个粒子都有一个速度和位置,速度表示它的移动方向和速度,位置表示它的当前位置。

  • 更新粒子速度和位置
    根据当前粒子的位置和速度,计算下一时刻的速度和位置。更新公式如下:

v_i(t+1) = w * v_i(t) + c1 * rand() * (pbest_i - x_i(t)) + c2 * rand() * (gbest - x_i(t))
x_i(t+1) = x_i(t) + v_i(t+1)
其中,v_i(t)表示粒子i在时刻t的速度,x_i(t)表示粒子i在时刻t的位置,pbest_i表示粒子i历史上最优的位置,gbest表示整个粒子群历史上最优的位置,w、c1和c2是常数,rand()是一个在[0,1]范围内均匀分布的随机数。

  • 计算适应度
    根据当前粒子的位置,计算目标函数的值作为粒子的适应度。如果当前位置比历史最优位置更优,则更新历史最优位置。

  • 判断停止条件
    如果达到预设的停止条件(如达到最大迭代次数或目标函数值足够小),则停止算法并返回历史最优位置对应的网络结构和参数。

  • 重复步骤4-6
    如果没有达到停止条件,则重复步骤4-6,直到达到停止条件为止。
    最终得到的历史最优位置对应的网络结构和参数即为经过粒子群算法优化后的分组卷积神经网络,可以用于数据多输入单输出回归预测问题。

程序设计

  • 完整程序和数据下载:后台私信PSO-DNN粒子群算法优化深度神经网络的数据多输入单输出回归预测
%%  参数设置
% ----------------------  修改模型参数时需对应修改fical.m中的模型参数  --------------------------
options = trainingOptions('adam', ...           % Adam 梯度下降算法
         'MaxEpochs', 500, ...                  % 最大训练次数 500
         'InitialLearnRate', best_lr, ...       % 初始学习率 best_lr
         'LearnRateSchedule', 'piecewise', ...  % 学习率下降
         'LearnRateDropFactor', 0.5, ...        % 学习率下降因子 0.1
         'LearnRateDropPeriod', 400, ...        % 经过 400 次训练后 学习率为 best_lr * 0.5
         'Shuffle', 'every-epoch', ...          % 每次训练打乱数据集
         'ValidationPatience', Inf, ...         % 关闭验证
         'L2Regularization', best_l2, ...       % 正则化参数
         'Plots', 'training-progress', ...      % 画出曲线
         'Verbose', false);

%%  训练模型
net = trainNetwork(p_train, t_train, layers, options);

%%  仿真验证
t_sim1 = predict(net, p_train);
t_sim2 = predict(net, p_test );

%%  数据反归一化
T_sim1 = mapminmax('reverse', t_sim1, ps_output);
T_sim2 = mapminmax('reverse', t_sim2, ps_output);
T_sim1=double(T_sim1);
T_sim2=double(T_sim2);
%%  均方根误差
error1 = sqrt(sum((T_sim1 - T_train).^2) ./ M);
error2 = sqrt(sum((T_sim2 - T_test ).^2) ./ N);
%% 参数初始化
popsize=pop;              %种群规模
lenchrom=dim;              %变量字串长度
fun = fobj;  %适应度函数
pc=0.7;                  %设置交叉概率
pm=0.3;                  %设置变异概率
if(max(size(ub)) == 1)
   ub = ub.*ones(dim,1);
   lb = lb.*ones(dim,1);  
end
maxgen=Max_iter;   % 进化次数  

function [gbest,g,Convergence_curve]=PSO(N,T,lb,ub,dim,fobj)
%% 定义粒子群算法参数
% N 种群 T 迭代次数 
%% 随机初始化种群
D=dim;                   %粒子维数
c1=1.5;                 %学习因子1
c2=1.5;                 %学习因子2
w=0.8;                  %惯性权重

Xmax=ub;                %位置最大值
Xmin=lb;               %位置最小值
Vmax=ub;                %速度最大值
Vmin=lb;               %速度最小值
%%
%%%%%%%%%%%%%%%%初始化种群个体(限定位置和速度)%%%%%%%%%%%%%%%%

x=rand(N,D).*(Xmax-Xmin)+Xmin;
v=rand(N,D).*(Vmax-Vmin)+Vmin;
%%%%%%%%%%%%%%%%%%初始化个体最优位置和最优值%%%%%%%%%%%%%%%%%%%
p=x;
pbest=ones(N,1);
for i=1:N
    pbest(i)=fobj(x(i,:)); 
end
%%%%%%%%%%%%%%%%%%%初始化全局最优位置和最优值%%%%%%%%%%%%%%%%%%
g=ones(1,D);
gbest=inf;
for i=1:N
    if(pbest(i)<gbest)
        g=p(i,:);
        gbest=pbest(i);
    end
end
%%%%%%%%%%%按照公式依次迭代直到满足精度或者迭代次数%%%%%%%%%%%%%
for i=1:T
    i
    for j=1:N
        %%%%%%%%%%%%%%更新个体最优位置和最优值%%%%%%%%%%%%%%%%%
        if (fobj(x(j,:))) <pbest(j)
            p(j,:)=x(j,:);
            pbest(j)=fobj(x(j,:)); 
        end
        %%%%%%%%%%%%%%%%更新全局最优位置和最优值%%%%%%%%%%%%%%%
        if(pbest(j)<gbest)
            g=p(j,:);
            gbest=pbest(j);
        end
        %%%%%%%%%%%%%%%%%跟新位置和速度值%%%%%%%%%%%%%%%%%%%%%
        v(j,:)=w*v(j,:)+c1*rand*(p(j,:)-x(j,:))...
            +c2*rand*(g-x(j,:));
        x(j,:)=x(j,:)+v(j,:);
        %%%%%%%%%%%%%%%%%%%%边界条件处理%%%%%%%%%%%%%%%%%%%%%%
        if length(Vmax)==1
            for ii=1:D
                if (v(j,ii)>Vmax)  |  (v(j,ii)< Vmin)
                    v(j,ii)=rand * (Vmax-Vmin)+Vmin;
                end
                if (x(j,ii)>Xmax)  |  (x(j,ii)< Xmin)
                    x(j,ii)=rand * (Xmax-Xmin)+Xmin;
                end
            end           
        else
            for ii=1:D
                if (v(j,ii)>Vmax(ii))  |  (v(j,ii)< Vmin(ii))
                    v(j,ii)=rand * (Vmax(ii)-Vmin(ii))+Vmin(ii);
                end
                if (x(j,ii)>Xmax(ii))  |  (x(j,ii)< Xmin(ii))
                    x(j,ii)=rand * (Xmax(ii)-Xmin(ii))+Xmin(ii);
                end
            end
        end
            
    end
    %%%%%%%%%%%%%%%%%%%%记录历代全局最优值%%%%%%%%%%%%%%%%%%%%%
   Convergence_curve(i)=gbest;%记录训练集的适应度值

end

参考资料

[1] https://blog.csdn.net/kjm13182345320/article/details/129215161
[2] https://blog.csdn.net/kjm13182345320/article/details/128105718

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/717721.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

一文教你搞懂性能测试常见指标

目录 1. 性能指标分类 2. 系统性能指标 2.1 响应时间 2.2 系统处理能力 2.3 吞吐量 2.4 并发用户数 2.5 错误率 3. 资源性能指标 3.1 CPU 3.2 内存 3.3 磁盘吐吞量 3.4 网络吐吞量 4. 中间件指标 5. 数据库指标 6. 稳定性指标 7. 可扩展性指标 8. 可靠性…

谷歌浏览器Crx插件库-https://crxdl.com/

地址&#xff1a;https://crxdl.com/ postman插件&#xff1a;talend json插件库&#xff1a;csdn插件 抓取视频&#xff1a;猫抓

C++primer(第五版)第十章(泛型算法)

10.1概述 大多数算法定义在头文件algorithm中.另外头文件numeric中定义了一组数值泛型算法. 一般情况下算法不直接操作容器,而是通过迭代器来对元素进行处理,因此迭代器令算法不依赖容器,但算法依赖于元素类型的操作. 泛型算法本身不会执行容器的操作,它们只会运行于迭代器之…

XD教程笔记

一、快捷键 选择&#xff1a; V 粘贴外观&#xff1a; ctrl alt V 矩形&#xff1a; R 组件&#xff1a; ctrl K 椭圆&#xff1a; E 向某一方向对齐&#xff1a; ctrl shift 方向键 钢笔&#xff1a; P 100%显示&#xff1a; ctrl 1 文本&#xff1a; T 锁定&a…

SAP ABAP 查表数据接口

查 SAP 表数据的接口 1.使用范例&#xff1a; 字段注释QUERY_TABLE查询的表名FIELDNAME查询的字段ROWCOUNT查询的行数ROWCOUNT查询的行数OPTIONS查询条件FIELDS查询字段的释义和字符长度DATA查询的数据TOTALROWS符合条件数据的行数 FIELDS 结果&#xff1a; 外围系统接口调用…

图像像素操作与二值化

目录 1、图像像素比较 1.1 比较函数 1.2 图像最大值最小值寻找 2、图像像素逻辑操作 3、图像二值化 3.1 固定阈值二值化 3.2 自适应阈值二值化 1、图像像素比较 1.1 比较函数 1.2 图像最大值最小值寻找 Mat img imread("F:/testMap/bijiao.png");Mat white i…

Bootstrap - 【echart】 统计图表基本使用

一. 前言 Bootstrap是一个流行的前端框架&#xff0c;而ECharts是一个流行的可视化库。 Bootstrap可以用来设计网站和应用程序的用户界面&#xff0c;而ECharts可以用来创建交互式和可视化的图表。 chart.js中文文档&#xff1a;http://www.bootcss.com/p/chart.js/docs/ 二.…

MYSQL的基础架构

了解MySQL&#xff08;超详细的MySQL工作原理 体系结构&#xff09; 1.MySQL体系结构 2.MySQL内存结构 3.MySQL文件结构 4.innodb体系结构 一、了解MySQL前你需要知道的 引擎是什么: MySQL中的数据用各种不同的技术存储在文件(或者内存)中。这些技术中的每一种技术都使用不同…

现货白银投资技巧实战教程

交易的实战技巧是指一些能让交易者获利的方法&#xff0c;当中一般都包含重要的操作纪律以及资金的配置策略&#xff0c;目标是要让投资者以合理的风险控制&#xff0c;来赢得持续的利润。现货白银投资技巧实战教程主要有以下几方面的内容&#xff1a; 1、充分了解交易细则。交…

腾讯云服务器新手入门_省钱入口_搭建网站全流程

腾讯云服务器新手指南从云服务器创建、远程连接到云服务器、安装操作系统、使用阿里云服务器建站教程等全流程&#xff0c;腾讯云服务器网分享腾讯云服务器从创建、使用到搭建网站全流程指南&#xff1a; 目录 一&#xff1a;腾讯云服务器创建 二&#xff1a;腾讯云服务器远…

leetcode84. 柱状图中最大的矩形(单调栈-java)

柱状图中最大的矩形 leetcode84. 柱状图中最大的矩形题目描述单调栈加数组优化栈结构解题代码演示用数组来优化栈结构,时间会更快 单调栈专题 leetcode84. 柱状图中最大的矩形 来源&#xff1a;力扣&#xff08;LeetCode&#xff09; 链接&#xff1a;https://leetcode.cn/prob…

01、Linux运维发展与学习路线图

目录 一、Linux运维行业前景二、运维相关岗位三、Linux运维岗位薪酬四、Linux运维岗知识框架4.1、常见站点系统架构演变1 单机2 多机3 缓存4 向外扩展5 Docker 4.2 知识体系框架图4.3 技术人员成长的阶段4.4 方法论 一、Linux运维行业前景 流程化、标准化的工作越来越依赖于信…

结构光三维测量几种比较成熟的方法

1.飞行时间发 原理:通过直接测量光传播的时间,确定物体的面型。发射脉冲信号,接受发射回的光,计算距离。 精度:毫米级 优点:原理简单,可避免阴影和遮挡等问题,且仪器便携化。 缺点:精度相对较低 2.莫尔条纹法 原理:采用两组光栅,一个主光栅,一个基准光栅,通过…

vue + element 笔记

1.安装nodejs&#xff0c;cmd中运行 node -v 验证是否成功 2.安装cnpm&#xff0c;cmd中运行 npm install -g cnpm --registryhttps://registry.npm.taobao.org&#xff0c;cmd中 cnpm -v 验证是否成功 3.安装vue-cli&#xff0c;cmd中运行 cnpm install --global vue-cli&…

【Spark】介绍,部署与快速入门

文章目录 介绍核心模块Spark CoreSpark SQLSpark StreamingSpark MLlibSpark GraphX 部署命令行Web UI提交应用Local 模式Standalone配置文件添加 JAVA_HOME 环境变量和集群对应的 master 节点启动集群配置历史服务添加日志存储路径添加日志配置webui 配置高可用 Yarn模式配置文…

老照片修复:模糊褪色有划痕的老旧照片如何修复?

在我们的生活中&#xff0c;照片是记录我们生活的重要方式之一。无论是在手机相册里还是在家中的相册里&#xff0c;我们都有很多珍贵的照片&#xff0c;但是随着时间的推移&#xff0c;照片也会老化&#xff0c;甚至出现褪色、划痕、折痕、破损、发霉等情况&#xff0c;这些情…

java多线程使用与踩坑

SpringBoot使用多线程简单方法&#xff1a;地址 线程安全查阅资料参考&#xff1a;地址 背景&#xff1a; 经过上述资料查看&#xff0c;我想写个方法&#xff08;依靠notify()唤醒&#xff0c;依靠wait()等待&#xff09;实现两个线程轮流打印。 实现&#xff1a; 1.线程池配…

HCIA复习二---7月4

路由&#xff1a; 按照路由条目&#xff0c;逻辑选址。 控制层面&#xff1a;路由条目的加表&#xff1a;AD metric&#xff08;华为 priority cost&#xff09;&#xff1b; 数据层面&#xff1a;按照路由条目转发数据包---与操作---最长匹配---递归查找&#xff1b; 静态…

第四十三周周报

学习目标&#xff1a; latent-diffusion 代码 学习时间&#xff1a; 2023.06.17 - 2023.06.30 学习产出&#xff1a; 一、代码 1、前置知识&#xff1a;PyTorch Lightning执行顺序 执行顺序&#xff1a; trainer.fit(model)&#xff1a;开始训练模型。 prepare_data()&a…

教你如何将纬地数据与实景三维模型进行叠加

概述&#xff1a; 纬地是公路设计的常用软件&#xff0c;在国内的普及率很高。传统的纬地数据文件以二维线条形式呈现在CAD中。本文提出了一种新思路、新方法&#xff0c;即将纬地的设计成果与无人机航拍的高精度倾斜摄影模型叠加在一起&#xff0c;辅助设计方案复核。 ​纬地…