改进粒子群算法优化BP神经网络---回归+分类两种案例

news2025/2/26 23:06:31

今天采用改进的粒子群算法(LPSO)优化算法优化BP神经网络。本文选用的LPSO算法是之前作者写过的一篇文章:基于改进莱维飞行和混沌映射(10种混沌映射随意切换)的粒子群优化算法,附matlab代码

文章一次性讲解两种案例,回归分类。回归案例中,作者选用了一个经典的股票数据。分类案例中,选用的是公用的UCI数据集。

BP神经网络初始的权值阈值都是随机生成的,因此不一定是最佳的。采用智能算法优化BP神经网络的权值阈值,使得输入与输出有更加完美的映射关系,以此来提升BP神经网络模型的精度。本文采用LPSO算法对BP神经网络的权值阈值进行优化,并应用于实际的回归和分类案例中。

01 股票预测案例

案例虽然介绍的是股票预测,但是LPSO-BP预测模型是通用的,大家根据自己的数据直接替换即可。数据替换十分简单,代码注释中都写的非常清楚了。

股票数据特征有:开盘价,盘中最高价,盘中最低价,收盘价等。预测值为股票价格。股票数据整理代码已写好,想换成自己数据的童鞋不需要理解此代码,替换数据即可。下面直接上标准BP的预测结果和LPSO-BP的预测结果。

标准BP模型预测结果

6fa1b3ebafadb7a170810c6e03c3d895.png

可以看到标准BP神经网络的预测效果不是很理想,无法跟踪真实值偏差较大

LPSO-BP预测结果

可以看到LPSO-BP神经网络的预测值可以紧密跟随真实值,效果很好。

e503e6e1248bea10c3cd534fb48745aa.png

将真实值,BP预测值和LPSO-BP预测值放在一起,效果更加明显。99ac421b072fe2ccf4e81d5e133f2e10.png

接下来是一个LPSO优化前后的BP神经网络误差对比图。

9a394682607be3c2c7e8388e3c43ead1.png

LPSO-BP的迭代曲线,以预测值和真实值的MSE为目标函数。

abc69dabc42ef04f3a7b641c1df133dd.png

LPSO-BP预测模型的评价:可以看到,LPSO-BP方法在股票预测案例中可以很好地进行股票价格预测。

02 分类案例

接下来是LPSO-BP的分类案例,采用的数据是UCI数据集中的Balancescale.mat数据,该数据一共分为三类。接下来看结果。

标准BP模型分类结果

混淆矩阵结果图:

简单说一下这个图该怎么理解。请大家横着看,每行的数据加起来是100%,每行的数据个数加起来就是测试集中第一类数据的真实个数。以第一行为例,测试集中一共有12个数据是属于第一类的,而12个数据中,有8个预测正确,有1个预测成了第2类,3个预测成了第三类。其他行均这样理解。

d2415a67adf17ab9082f2c41cf426f80.png

下面这个图是另一种结果展现方式,在一些论文中会用这种方式展示结果。

245bb09052c5d026ed7185f1e664db60.png

LPSO-BP分类结果:

44cf9f279aba97781481e7ba9bd33835.png

9718d0ae9c0e49f7340a21a99fcac9c1.png

242ad698e408893c39a6b19f0a2c34f7.png

03 代码展示

%% 初始化
clear
close all
clc
warning off
addpath(genpath(pwd));
% rng(0)
load Balancescale.mat 
data = Balancescale;
data=data(randperm(size(data,1)),:);    %此行代码用于打乱原始样本,使训练集测试集随机被抽取,有助于更新预测结果。
input=data(:,2:end);
output1 =data(:,1);
for i=1:size(data,1)
    switch output1(i)
        case 1
            output(i,1)=1;
        case 2
            output(i,2)=1;
        case 3
            output(i,3)=1;
        case 4
            output(i,4)=1;
        case 5
            output(i,5)=1;
        case 6
            output(i,6)=1;
        case 7
            output(i,7)=1;
    end
end
%% 划分训练集和测试集
m=fix(size(data,1)*0.7);    %训练的样本数目
%训练集
input_train=input(1:m,:)';
output_train=output(1:m,:)';
% 测试集
input_test=input(m+1:end,:)';
output_test=output(m+1:end,:)';


%% 数据归一化
[inputn,inputps]=mapminmax(input_train,0,1);
inputn_test=mapminmax('apply',input_test,inputps);
dam = fix(size(inputn,2)*0.3);%选30%的训练集作为验证集
idx = randperm(size(inputn,2),dam);
XValidation = inputn(:,idx);
inputn(:,idx) = [];
YValidation = output_train(:,idx);
output_train(:,idx) = [];


%% 获取输入层节点、输出层节点个数
inputnum=size(input_train,1);
outputnum=size(output_train,1);
disp('/')
disp('神经网络结构...')
disp(['输入层的节点数为:',num2str(inputnum)])
disp(['输出层的节点数为:',num2str(outputnum)])
disp(' ')
disp('隐含层节点的确定过程...')


%确定隐含层节点个数
%采用经验公式hiddennum=sqrt(m+n)+a,m为输入层节点个数,n为输出层节点个数,a一般取为1-10之间的整数
acc = 0;
for hiddennum=fix(sqrt(inputnum+outputnum))+1:fix(sqrt(inputnum+outputnum))+10
    net0=newff(inputn,output_train,hiddennum);
    % 网络参数
    net0.trainParam.epochs=1000;            % 训练次数,这里设置为1000次
    net0.trainParam.lr=0.01;                % 学习速率,这里设置为0.01
    net0.trainParam.goal=0.0001;           % 训练目标最小误差,这里设置为0.0001
    net0.trainParam.show=25;                % 显示频率,这里设置为每训练25次显示一次
    net0.trainParam.mc=0.001;                % 动量因子
    net0.trainParam.min_grad=1e-8;          % 最小性能梯度
    net0.trainParam.max_fail=6;             % 最高失败次数
    net0.trainParam.showWindow = false;
    net0.trainParam.showCommandLine = false; 
    % 网络训练
    [net0,tr]=train(net0,inputn,output_train);
    an0=sim(net0,XValidation);  %验证集的仿真结果
    predict_label=zeros(1,size(an0,2));
    for i=1:size(an0,2)
        predict_label(i)=find(an0(:,i)==max(an0(:,i)));
    end
    outputt=zeros(1,size(YValidation,2));
    for i=1:size(YValidation,2)
        outputt(i)=find(YValidation(:,i)==max(YValidation(:,i)));
    end
    accuracy=sum(outputt==predict_label)/length(outputt);   %计算预测的确率
    disp(['隐含层节点数为',num2str(hiddennum),'时,验证集的准确率为:',num2str(accuracy)])
    
    %更新最佳的隐含层节点
    if acc<accuracy
        acc=accuracy;
        hiddennum_best=hiddennum;
    end
end
disp(['最佳的隐含层节点数为:',num2str(hiddennum_best),',验证集相应的训练集的准确率为:',num2str(acc)])


%% 构建最佳隐含层节点的BP神经网络
disp(' ')
disp('标准的BP神经网络:')
net0=newff(inputn,output_train,hiddennum_best,{'tansig','purelin'},'trainlm');% 建立模型
%网络参数配置
net0.trainParam.epochs=1000;         % 训练次数,这里设置为1000次
net0.trainParam.lr=0.01;                   % 学习速率,这里设置为0.01
net0.trainParam.goal=0.00001;                    % 训练目标最小误差,这里设置为0.0001
net0.trainParam.show=25;                % 显示频率,这里设置为每训练25次显示一次
net0.trainParam.mc=0.01;                 % 动量因子
net0.trainParam.min_grad=1e-6;       % 最小性能梯度
net0.trainParam.max_fail=6;               % 最高失败次数
% net0.trainParam.showWindow = false;
% net0.trainParam.showCommandLine = false;            %隐藏仿真界面
%开始训练
net0=train(net0,inputn,output_train);


%预测
an0=sim(net0,inputn_test); %用训练好的模型进行仿真
predict_label=zeros(1,size(an0,2));
    for i=1:size(an0,2)
        predict_label(i)=find(an0(:,i)==max(an0(:,i)));
    end
    outputt=zeros(1,size(output_test,2));
    for i=1:size(output_test,2)
        outputt(i)=find(output_test(:,i)==max(output_test(:,i)));
    end
    accuracy=sum(outputt==predict_label)/length(outputt);   %计算预测的确率  
    disp(['准确率为:',num2str(accuracy)])
%% 标准BP神经网络作图
% 画方框图
figure
confMat = confusionmat(outputt,predict_label);  %output_test是真实值标签
zjyanseplotConfMat(confMat.');  
xlabel('Predicted label')
ylabel('Real label')
% 作图
figure
scatter(1:length(predict_label),predict_label,'r*')
hold on
scatter(1:length(predict_label),outputt,'g^')
legend('预测类别','真实类别','NorthWest')
title({'BP神经网络的预测效果',['测试集正确率 = ',num2str(accuracy*100),' %']})
xlabel('预测样本编号')
ylabel('分类结果')
box on
set(gca,'fontsize',12)
%% LPSO优化算法寻最优权值阈值
disp(' ')
disp('LPSO优化BP神经网络:')


net=newff(inputn,output_train,hiddennum_best,{'tansig','purelin'},'trainlm');% 建立模型


%网络参数配置
net.trainParam.epochs=1000;         % 训练次数,这里设置为1000次
net.trainParam.lr=0.0001;                   % 学习速率,这里设置为0.01
net.trainParam.goal=0.000001;                    % 训练目标最小误差,这里设置为0.0001
net.trainParam.show=25;                % 显示频率,这里设置为每训练25次显示一次
net.trainParam.mc=0.01;                 % 动量因子
net.trainParam.min_grad=1e-6;       % 最小性能梯度
net.trainParam.max_fail=6;               % 最高失败次数
%% 初始化LPSO参数
popsize=20;   %初始种群规模
maxgen=100;   %最大进化代数
lb = -1;  %神经网络权值阈值的上下限
ub = 1;
numm = 2; %混沌系数
dim=inputnum*hiddennum_best+hiddennum_best+hiddennum_best*outputnum+outputnum;    %自变量个数
[Best_score,Best_pos,LPSO_curve]=LPSOforBP(numm,popsize,maxgen,lb,ub,dim,inputnum,hiddennum_best,outputnum,net,inputn,output_train,inputn_test,output_test);

代码中注释非常详细,有对神经网络构建的注释,有对LPSO-BP代码的注释,简单易懂。

代码附带UCI常用的数据集及其解释。大家可以自行尝试别的数据进行分类。附带LPSO在CEC2005函数的测试代码。

一次性获取两种案例代码。完整代码获取方式,后台回复关键词。

关键词 :

LPSOBP

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/829076.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Mr. Cappuccino的第58杯咖啡——MacOS配置Maven和Java环境

MacOS配置Maven和Java环境 查看Mac使用的是哪个shell下载并准备Maven下载Maven配置前准备 下载并安装JDK下载JDK安装JDK 配置Maven和Java环境添加配置加载配置 验证环境 查看Mac使用的是哪个shell echo $SHELL如果使用的是bash&#xff0c;则使用以下命令 open ~/.bash_profi…

Java课题笔记~ MyBatis接口开发(代理开发)

使用XML文件进行开发&#xff0c;在调用SqlSession进行操作时&#xff0c;需要指定MyBatis映射文件中的方法&#xff0c;这种调用方式过于烦琐。为解决此问题&#xff0c;MyBatis提供了接口开发的方式。 接口开发的目的&#xff1a; 解决原生方式中的硬编码 简化后期执行SQL …

PHP语言基础知识(超详细)

文章目录 前言第一章 PHP语言学习介绍 1.1 PHP部署安装环境1.2 PHP代码工具选择 第二章 PHP代码基本语法 2.1 PHP函数知识介绍2.2 PHP常量变量介绍 2.2.1 PHP变量知识&#xff1a;2.2.2 PHP常量知识&#xff1a; 2.3 PHP注释信息介绍2.4 PHP数据类型介绍 2.4.1 整形数据类型2.4…

【elementui】解决el-select组件失去焦点blur事件每次获取的是上一次选中值的问题

目录 【问题描述】 【问题摘要】 【分析问题】 【完整Test代码】 【封装自定义指令】 ↑↑↑↑↑↑↑↑↑↑↑↑ 不想看解决问题过程的可点击上方【封装自定义指令】目录直接跳转获取结果即可~~~ 【问题描述】 一位朋友遇到这么一个开发场景&#xff1a;在表格里面嵌入el-…

Packet Tracer - 配置初始交换机设置

Packet Tracer - 配置初始交换机设置 拓扑 目标 第 1 部分&#xff1a;检验默认交换机配置 第 2 部分&#xff1a;配置基本交换机配置 第 3 部分&#xff1a;配置 MOTD 标语 第 4 部分&#xff1a;将配置文件保存到 NVRAM 第 5 部分&#xff1a;配置 S2 拓扑图 背景信息…

【Mybatis】XML映射文件

目录 11.3XML映射文件 1.select 2.insert、update、delete 3.Sql 4.parameters(参数) 5.resultMap 6.resultMap 使用示例 (1)在先前创建的数据库stu中创建表student 2&#xff0c;并插入若干条数据&#xff0c;代码如下&#xff1a; (2)创建工程mybatis_ResultMap_demo。 (…

Qt项目---简单的计算器

在这篇技术博客中&#xff0c;我们将介绍如何使用Qt框架实现一个简单的计算器应用。我们将使用C编程语言和Qt的图形用户界面库来开发这个应用&#xff0c;并展示如何实现基本的算术操作。 项目设置 首先&#xff0c;我们需要在Qt Creator中创建一个新的Qt Widgets应用程序项目…

7.物联网操作系统互斥信号量

优先级翻转问题 优先级翻转功能需求 优先级翻转功能实现 一。实验&#xff1a;优先级翻转问题 1.优先级翻转的解释 &#xff08;1&#xff09;有三个任务&#xff0c;一个任务L优先级最低&#xff0c;一个任务M优先级为中间&#xff0c;一个任务H优先级为最高。 &#xff08…

SpringBoot集成企业微信群聊机器人消息

目录 参考文档概述一、功能作用二、应用场景三、 群机器人发送限制四、创建机器人1、添加2、群机器人Webhook地址 五、发送消息1、文本 text请求体 图文连接 news 参考文档 官方文档 企业微信群机器人应用 概述 现在很多企业都在使用企业微信进行工作交流&#xff0c;自从企…

静态路由下一跳地址怎么确定(静态路由配置及讲解)

一、用到的所有命令及功能 ①ip route-static 到达网络地址 子网掩码 下一跳 // 配置静态路由下一跳指的是和当前网络直接连接的路由器的接口地址非直连网段必须全部做路由路径是手工指定的&#xff0c;在大规模网络上不能用&#xff0c;效率低&#xff0c;路径是固定的稳定的…

寻找旋转排序数组中的最小值——力扣153

文章目录 题目描述解法 二分法 题目描述 解法 二分法 int findMin(vector<int>& nums){int l0, rnums.size()-1;while(l<r){int mid (lr)/2;if(nums[mid]<nums[r]) rmid;else lmid1;}return nums[l];}

基于freeRTOS的垃圾桶(cubeMX)

前言&#xff1a;最近学习了freertos的任务、队列、互斥量、任务标志位等理论知识&#xff0c;看着都会就怕一练就废&#xff0c;于是打算做些项目巩固一下&#xff0c;加深一下对freertos知识的理解。 一、项目介绍 项目简单需求&#xff1a; 检测靠近时&#xff0c;垃圾桶自…

APUE学习62章终端(二): stty命令特殊字符终端标志

1. stty命令 stty命令的英文解释: 很明显stty有一个-F参数 所以准确的说: stty命令是设置当前终端驱动程序(也有可能直接配置了硬件&#xff0c;这点目前不清楚)的属性&#xff0c;使当前终端的驱动程序能够使能/去使能一些特殊字符的识别与处理等等 2. stty命令的结构 3. 终端…

Python web实战之 Django 的 ORM 框架详解

本文关键词&#xff1a;Python、Django、ORM。 概要 在 Python Web 开发中&#xff0c;ORM&#xff08;Object-Relational Mapping&#xff0c;对象关系映射&#xff09;是一个非常重要的概念。ORM 框架可以让我们不用编写 SQL 语句&#xff0c;就能够使用对象的方式来操作数据…

总结946

6:40起床 7&#xff1a;15~8:00早读&#xff0c;07年tex1,2 8:10~10:12 880第二章选填&#xff0c;题目有些综合&#xff0c;错的有些多呀&#xff0c;不要紧&#xff0c;拿下它&#xff0c;就有进步了。 10:28~11:27重做强化18讲6道题 12&#xff1a;10~2:15吃饭睡觉&…

MySQL 三大日志日志:undo log、redo log、binlog

目录 一条SQL的执行流程 为什么需要 undo log&#xff1f; undo log 是如何刷盘&#xff08;持久化到磁盘&#xff09;的&#xff1f; 为什么需要 Buffer Pool&#xff1f; Buffer Pool 缓存什么&#xff1f; Undo 页是记录什么&#xff1f; 查询一条记录&#xff0c;就只需…

代码随想录算法训练营第三十二天 | Leetcode随机抽题检测

Leetcode随机抽题检测 46 全排列未看解答自己编写的青春版重点题解的代码日后复习重新编写 78 子集未看解答自己编写的青春版重点题解的代码日后复习重新编写 17 电话号码的字母组合未看解答自己编写的青春版重点题解的代码日后复习重新编写 39 组合总和未看解答自己编写的青春…

SpringBoot项目增加logback日志文件

一、简介 在开发和调试过程中&#xff0c;日志是一项非常重要的工具。它不仅可以帮助我们快速定位和解决问题&#xff0c;还可以记录和监控系统的运行状态。Spring Boot默认提供了一套简单易用且功能强大的日志框架logback&#xff0c;本文将介绍如何在Spring Boot项目中配置和…

使用AIGC工具提升安全工作效率

新钛云服已累计为您分享760篇技术干货 在日常工作中&#xff0c;安全人员可能会涉及各种各样的安全任务&#xff0c;包括但不限于&#xff1a; 开发某些安全工具的插件&#xff0c;满足自己特定的安全需求&#xff1b;自定义github搜索工具&#xff0c;快速查找所需的安全资料、…

HTML基础介绍2

表单格式化 ctrld&#xff1a;复制选中行数的所有代码 ctrlx&#xff1a;删除代码 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>表单综合案例</title> </head> <body> <!--…