回归预测 | MATLAB实现SSA-CNN麻雀算法优化卷积神经网络多输入单输出
目录
- 回归预测 | MATLAB实现SSA-CNN麻雀算法优化卷积神经网络多输入单输出
- 预测效果
- 基本介绍
- 模型描述
- 程序设计
- 参考文献
预测效果
基本介绍
MATLAB实现SSA-CNN麻雀算法优化卷积神经网络多输入单输出。
1 .data为数据集,输入7个特征,输出一个变量。
2.MainSSA_CNN.m为程序主文件,其余为函数文件无需运行。
3.命令窗口输出MAPE、RMSE和R2,可在下载区获取数据和程序内容。
4.麻雀算法优化卷积神经网络,优化学习率、迭代次数、批处理、卷积核大小 和数量、 全连接层单元数量。
模型描述
卷积神经网络(CNN)中超参数众多,人工选择比较困难,利用麻雀搜索算法(SSA)对卷积神经网络中的参数进行优化,消除人工操作的不确定性。本模型共优化8 个超参数,分别是迭代次数、学习率、第1 层卷积核大小和数量、第2 层卷积核大小和数量,以及2 个全连接层的神经元数量(conv 表示卷积层,fc 表示全连接层)。本文建立的模型组成包括输入层、2 层卷积层、2 层激活层、2 层全连接层和输出层。SSA CNN 模型预测具体实现步骤如下。
第1 步:对数据进行归一化处理。
第2 步:设定初始参数,包括种群中的个体总数、子群体数、每个子群体中的麻雀数、最大迭代次数、发现者的数量及SSA 其他参数等。
第3 步:初始化种群并定义适应度函数,以CNN的预测值与实际值的均方误差最小化作为适应度函数,SSA 的目的就是找到一组超参数,用这组超参数训练得到的CNN 的误差能够最小化。
第4 步:计算适应度函数值并排序。
第5 步:确定每个子群体中的最优解、最差解和全局最优解。
第6 步:更新麻雀位置,获取当前的新位置,如果新位置比以前的位置更好就更新它,若达到设定的最大迭代次数,则将其输出,否则返回继续寻优,直到得到最好的麻雀坐标。
第7 步:将寻优得到的麻雀坐标代入CNN 模型中,得到预测模型的输出。
程序设计
- 完整程序和数据私信博主。
%% 参数设置
pop=5; % 种群数
M=10; % 最大迭代次数
dim=9;%一共有9个参数需要优化,分别是学习率、迭代次数、batchsize、第一层卷积层的核大小、和数量、第2层卷积层的核大小、和数量,以及两个全连接层的神经元数量
lb= [0.001 10 16 1 1 1 1 1 1]; % 下边界
ub= [0.01 50 256 16 20 16 20 50 50]; % 上边界
% 学习率的范围是0.001-0.01 迭代次数的范围是10-50 batchsize的范围是16-256 核大小的范围是1-16 核数量的范围是1-20 全连接层的范围是1-50
P_percent = 0.2; %producers 在全部种群的占比
pNum = round( pop * P_percent ); % producers的数量
%初始化种群
for i = 1 : pop
for j=1:dim
if j==1%除了学习率 其他的都是整数
x( i, j ) = (ub(j)-lb(j))*rand+lb(j);
else
x( i, j ) = round((ub(j)-lb(j))*rand+lb(j));
end
end
fit( i )=fitness(x(i,:),P_train,T_train,P_test,T_test);
end
pFit = fit;
pX = x;
fMin=fit(1);
bestX = x( i, : );
for t = 1 : M
[ ~, sortIndex ] = sort( pFit );% Sort.从小到大
[fmax,B]=max( pFit );
worse= x(B,:);
r2=rand(1);
%%%%%%%%%%%%%5%%%%%%这一部位为发现者(探索者)的位置更新%%%%%%%%%%%%%%%%%%%%%%%%%
if(r2<0.8)%预警值较小,说明没有捕食者出现
for i = 1 : pNum %r2小于0.8的发现者的改变(1-20) % Equation (3)
r1=rand(1);
x( sortIndex( i ), : ) = pX( sortIndex( i ), : )*exp(-(i)/(r1*M));%对自变量做一个随机变换
x( sortIndex( i ), : ) = Bounds( x( sortIndex( i ), : ), lb, ub );%对超过边界的变量进行去除
fit( sortIndex( i ) )=fitness(x(sortIndex( i ),:),P_train,T_train,P_test,T_test);
end
else %预警值较大,说明有捕食者出现威胁到了种群的安全,需要去其它地方觅食
for i = 1 : pNum %r2大于0.8的发现者的改变
x( sortIndex( i ), : ) = pX( sortIndex( i ), : )+randn(1)*ones(1,dim);
x( sortIndex( i ), : ) = Bounds( x( sortIndex( i ), : ), lb, ub );
fit( sortIndex( i ) )=fitness(x(sortIndex( i ),:),P_train,T_train,P_test,T_test);
end
end
[ ~, bestII ] = min( fit );
bestXX = x( bestII, : );
%%%%%%%%%%%%%5%%%%%%这一部位为加入者(追随者)的位置更新%%%%%%%%%%%%%%%%%%%%%%%%%
for i = ( pNum + 1 ) : pop %剩下20-100的个体的变换 % Equation (4)
% i
% sortIndex( i )
A=floor(rand(1,dim)*2)*2-1;
if( i>(pop/2))%这个代表这部分麻雀处于十分饥饿的状态(因为它们的能量很低,也是是适应度值很差),需要到其它地方觅食
x( sortIndex(i ), : )=randn(1,dim).*exp((worse-pX( sortIndex( i ), : ))/(i)^2);
else%这一部分追随者是围绕最好的发现者周围进行觅食,其间也有可能发生食物的争夺,使其自己变成生产者
x( sortIndex( i ), : )=bestXX+(abs(( pX( sortIndex( i ), : )-bestXX)))*(A'*(A*A')^(-1))*ones(1,dim);
end
x( sortIndex( i ), : ) = Bounds( x( sortIndex( i ), : ), lb, ub );%判断边界是否超出
fit( sortIndex( i ) )=fitness(x(sortIndex( i ),:),P_train,T_train,P_test,T_test);
end
%%%%%%%%%%%%%5%%%%%%这一部位为意识到危险(注意这里只是意识到了危险,不代表出现了真正的捕食者)的麻雀的位置更新%%%%%%%%%%%%%%%%%%%%%%%%%
c=randperm(numel(sortIndex));%%%%%%%%%这个的作用是在种群中随机产生其位置(也就是这部分的麻雀位置一开始是随机的,意识到危险了要进行位置移动,
%处于种群外围的麻雀向安全区域靠拢,处在种群中心的麻雀则随机行走以靠近别的麻雀)
b=sortIndex(c(1:pop));
for j = 1 : length(b) % Equation (5)
if( pFit( sortIndex( b(j) ) )>(fMin) ) %处于种群外围的麻雀的位置改变
x( sortIndex( b(j) ), : )=bestX+(randn(1,dim)).*(abs(( pX( sortIndex( b(j) ), : ) -bestX)));
else
%处于种群中心的麻雀的位置改变
x( sortIndex( b(j) ), : ) =pX( sortIndex( b(j) ), : )+(2*rand(1)-1)*(abs(pX( sortIndex( b(j) ), : )-worse))/ ( pFit( sortIndex( b(j) ) )-fmax+1e-50);
end
x( sortIndex(b(j) ), : ) = Bounds( x( sortIndex(b(j) ), : ), lb, ub );
fit( sortIndex( b(j) ) )=fitness(x(sortIndex( b(j) ),:),P_train,T_train,P_test,T_test);
end
for i = 1 : pop
if ( fit( i ) < pFit( i ) )
pFit( i ) = fit( i );
pX( i, : ) = x( i, : );
end
if( pFit( i ) < fMin )
fMin= pFit( i );
bestX = pX( i, : );
end
end
t,fMin
Convergence_curve(t)=fMin;
参考文献
[1] https://blog.csdn.net/kjm13182345320/article/details/128713044?spm=1001.2014.3001.5501
[2] https://blog.csdn.net/kjm13182345320/article/details/128700127?spm=1001.2014.3001.5501
[3] https://blog.csdn.net/kjm13182345320/article/details/128688474?spm=1001.2014.3001.5501