2-1 MATLAB鮣鱼优化算法ROA优化LSTM超参数回归预测

news2025/4/1 12:47:37

本博客来源于CSDN机器鱼,未同意任何人转载。

更多内容,欢迎点击本专栏目录,查看更多内容。

目录

0.ROA原理

1.LSTM程序

2.ROA优化LSTM

3.主程序

4.结语


0.ROA原理

具体原理看原文,但是今天咱不用知道具体原理,只需要找到源码,然后改成优化LSTM的即可。下面是我从网上找到的源码。ROA是主要的代码,Cost是适应度函数,这个代码的是找Cost的最小值。

function [Fbest, Rbest,Convergence_curve]= ROA()
sizepop=30; % Number of search agents
maxgen=500; % Maximum number of iterations
lb=-100;
ub=100;
D=30;

%maxgen为最大迭代次数,
%sizepop为种群规模
%记D为维度,lb、 ub分别为搜索上、下限
R=ones(sizepop,D);%预设种群
for i= 1:D
    R(:,i)=lb + (ub-lb)*rand(sizepop,1);
end
for k= 1:sizepop
    Fitness(k)=Cost(R(k,:));%个体适应度
end
[Fbest,elite]= min(Fitness);%Fbest为最优适应度值
Rbest= R(elite,:);%最优个体位置
H=zeros(1,sizepop);%控制因子
ub=ones(1,D)*ub;
lb=ones(1,D)*lb;

%主循环
for iter= 1:maxgen
    Rpre= R;%记录上一代的位置
    V=2*(1-iter/maxgen);
    B= 2*V*rand-V;
    a=-(1 + iter/maxgen);
    alpha=rand*(a-1)+ 1;
    for i= 1:sizepop
        if H(i)==0
            dis = abs(Rbest-R(i,:));
            R(i,:)= R(i,:)+ dis* exp(alpha)*cos(2*pi* alpha);
        else
            RAND= ceil(rand*sizepop);%随机选择一个个体
            R(i,:)= Rbest -(rand*0.5*(Rbest + R(RAND,:))- R(RAND,:));
        end
        Ratt= R(i,:)+ (R(i,:)- Rpre(i,:))*randn;%作出小幅度移动
        %边界吸收
        for k=1:D
            Flag4ub= R(i,k)>ub(k);
            Flag4lb= R(i,k)<lb(k);
            R(i,k)=(R(i,k).*(~(Flag4ub + Flag4lb))) + ub(k).*Flag4ub + lb(k).*Flag4lb;
            Flag4ub= Ratt(1,k)> ub(k);
            Flag4lb= Ratt(1,k)<lb(k);
            Ratt(1,k)=(Ratt(1,k).*(~(Flag4ub + Flag4lb)))+ ub(k).*Flag4ub + lb(k).*Flag4lb;
        end
        Fitness(i)=Cost(R(i,:));
        Fitness_Ratt= Cost(Ratt);
        if Fitness_Ratt < Fitness(i)%改变寄主
            if H(i)==1
                H(i)=0;
            else
                H(i)=1;
            end
        else %不改变寄主
            A= B*(R(i,:)-rand*0.3*Rbest);
            R(i,:)=R(i,:)+A;
        end
        %边界吸收
        for k=1:D
            Flag4ub= R(i,k)>ub(k);
            Flag4lb= R(i,k)<lb(k);
            R(i,k)=(R(i,k).*(~(Flag4ub+ Flag4lb))) + ub(k).*Flag4ub + lb(k).*Flag4lb;
        end
    end
    %更新适应度值、位置
    [fbest,elite] = min(Fitness);
    %更新最优个体
    if fbest< Fbest
        Fbest= fbest;
        Rbest= R(elite,:);
    end
    Convergence_curve(iter)= Fbest;
end
end


function o = Cost(x)
o=sum(x.^2);
end

调用这个代码的主程序如下:

clear ;close all;clc;format compact
%鮣鱼优化算法(Remora Optimization Algorithm)
[BestF,BestP,Convergence_curve1]=ROA();
figure
semilogy(Convergence_curve1)

1.LSTM程序

首先建立一个LSTM网络,这次我们是做回归任务,数据是3输入1输出,构建一个含2个lstmlayer的LSTM网络,代码如下:

%% LSTM时间序列预测
clc;clear;close all
%%
load data
XTrain;%3*97
XTest;%3*68
YTrain;%1*97
YTest;%1*68
%% 参数设置
train=0;%为1就重新训练,否则加载训练好的模型进行预测
if train==1
    rng(0)
    numFeatures = size(XTrain,1);%输入节点数
    numResponses = size(YTrain,1);%输出节点数
    miniBatchSize = 16; %batchsize
    numHiddenUnits1 = 20;
    numHiddenUnits2 = 20;
    maxEpochs=100;
    learning_rate=0.005;
    layers = [ ...
        sequenceInputLayer(numFeatures)
        lstmLayer(numHiddenUnits1)
        lstmLayer(numHiddenUnits2)
        fullyConnectedLayer(numResponses)
        regressionLayer];
    options = trainingOptions('adam', ...
        'ExecutionEnvironment', 'cpu',...
        'MaxEpochs',maxEpochs, ...
        'MiniBatchSize',miniBatchSize, ...
        'InitialLearnRate',learning_rate, ...
        'GradientThreshold',1, ...
        'Shuffle','every-epoch', ...
        'Verbose',true,...
        'Plots','training-progress');

    net = trainNetwork(XTrain,YTrain,layers,options);
    save model/lstm net
else
    load model/lstm
end
YPred = predict(net,XTest,'ExecutionEnvironment', 'cpu');YPred=double(YPred);

从构建网络那里,我们发现,构建一个超级简单的LSTM,依旧有miniBatchSize ,numHiddenUnits1 ,numHiddenUnits2 ,maxEpochs,learning_rate共5个超参数需要设置,而网络越复杂,那需要优化的超参数也就更多,手动选择就算了,一般是选不出来,为此这篇博客采用ROA进行优化。

2.ROA优化LSTM

任意一个优化网路超参数的步骤都是通用的,步骤如下:

步骤1:知道要优化的参数的优化范围。显然就是上面提到的5个参数。代码如下,首先改写lb与ub,然后初始化的时候注意除了学习率,其他的都是整数。并将原来里面的边界判断,改成了Bounds函数,方便在计算适应度函数前转化成整数与小数。

function [Rbest,Convergence_curve,process]= ROAforlstm(X1,y1,Xt,yt)
D=5;
sizepop=5;%种群数量
maxgen=10;%寻优代数
%范围
lb=[1 1   1   1  0.001];%分别对batchsize、两个lstm隐含层节点 训练次数与学习率寻优
ub=[64 100 100 50  0.01];%这个分别代表5个参数的上下界,比如第一个参数的范围就是1-64

%
%maxgen为最大迭代次数,
%sizepop为种群规模
%记D为维度,lb、 ub分别为搜索上、下限
R=ones(sizepop,D);%预设种群
for i=1:sizepop%随机初始化速度,随机初始化位置
    for j=1:D
        if j==D%除了学习率 其他的都是整数
            R( i, j ) = (ub(j)-lb(j))*rand+lb(j);
        else
            R( i, j ) = round((ub(j)-lb(j))*rand+lb(j));
        end
    end
end

for k= 1:sizepop
    Fitness(k)=fitness(R(k,:),X1,y1,Xt,yt);%个体适应度
end
[Fbest,elite]= min(Fitness);%Fbest为最优适应度值
Rbest= R(elite,:);%最优个体位置
H=zeros(1,sizepop);%控制因子


%主循环
for iter= 1:maxgen
    Rpre= R;%记录上一代的位置
    V=2*(1-iter/maxgen);
    B= 2*V*rand-V;
    a=-(1 + iter/maxgen);
    alpha=rand*(a-1)+ 1;
    for i= 1:sizepop
        if H(i)==0
            dis = abs(Rbest-R(i,:));
            R(i,:)= R(i,:)+ dis* exp(alpha)*cos(2*pi* alpha);
        else
            RAND= ceil(rand*sizepop);%随机选择一个个体
            R(i,:)= Rbest -(rand*0.5*(Rbest + R(RAND,:))- R(RAND,:));
        end
        Ratt= R(i,:)+ (R(i,:)- Rpre(i,:))*randn;%作出小幅度移动
        %边界吸收
        R(i, : ) = Bounds( R(i, : ), lb, ub );%对超过边界的变量进行去除
        Ratt = Bounds( Ratt, lb, ub );%对超过边界的变量进行去除
        Fitness(i)=fitness(R(i,:),X1,y1,Xt,yt);
        Fitness_Ratt= fitness(Ratt,X1,y1,Xt,yt);
        if Fitness_Ratt < Fitness(i)%改变寄主
            if H(i)==1
                H(i)=0;
            else
                H(i)=1;
            end
        else %不改变寄主
            A= B*(R(i,:)-rand*0.3*Rbest);
            R(i,:)=R(i,:)+A;
        end
        
        R(i, : ) = Bounds( R(i, : ), lb, ub );%对超过边界的变量进行去除
        
    end
    %更新适应度值、位置
    [fbest,elite] = min(Fitness);
    %更新最优个体
    if fbest< Fbest
        Fbest= fbest;
        Rbest= R(elite,:);
    end
    process(iter,:)=Rbest;
    Convergence_curve(iter)= Fbest;
    iter,Fbest,Rbest
end

end

function s = Bounds( s, Lb, Ub)
temp = s;
dim=length(Lb);
for i=1:length(s)
    if i==dim%除了学习率 其他的都是整数
        temp(:,i) =temp(:,i);
    else
        temp(:,i) =round(temp(:,i));
    end
end

% 判断参数是否超出设定的范围

for i=1:length(s)
    if temp(:,i)>Ub(i) | temp(:,i)<Lb(i) 
        if i==dim%除了学习率 其他的都是整数
            temp(:,i) =rand*(Ub(i)-Lb(i))+Lb(i);
        else
            temp(:,i) =round(rand*(Ub(i)-Lb(i))+Lb(i));
        end
    end
end
s = temp;
end

步骤2:知道优化的目标。优化的目标是提高的网络的准确率,而ROA代码我们这个代码是最小值优化的,所以我们的目标可以是最小化LSTM的预测误差。预测误差具体是,测试集(或验证集)的预测值与真实值之间的均方差。

步骤3:构建适应度函数。通过步骤2我们已经知道目标,即采用ROA去找到5个值,用这5个值构建的网络,误差最小化。观察下面的代码,首先我们将ROA的值传进来,然后转成需要的5个值,然后构建网络,训练集训练、测试集预测,计算预测值与真实值的mse,将mse作为结果传出去作为适应度值。

function y=fitness(x,p,t,pt,tt)
rng(0)
numFeatures = size(p,1);%输入节点数
numResponses = size(t,1);%输出节点数
miniBatchSize = x(1); %batchsize
numHiddenUnits1 = x(2);
numHiddenUnits2 = x(3);
maxEpochs=x(4);
learning_rate=x(5);
layers = [ ...
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits1)
    lstmLayer(numHiddenUnits2)
    fullyConnectedLayer(numResponses)
    regressionLayer];
options = trainingOptions('adam', ...
    'ExecutionEnvironment', 'cpu',...
    'MaxEpochs',maxEpochs, ...
    'MiniBatchSize',miniBatchSize, ...
    'InitialLearnRate',learning_rate, ...
    'GradientThreshold',1, ...
    'Shuffle','every-epoch', ...
    'Verbose',false);


net = trainNetwork(p,t,layers,options);

YPred = predict(net,pt,'ExecutionEnvironment', 'cpu');YPred=double(YPred);
[m,n]=size(YPred);
YPred=reshape(YPred,[1,m*n]);
tt=reshape(tt,[1,m*n]);

y =mse(YPred-tt);
% 以mse为适应度函数,优化算法目的就是找到一组超参数 使网络的mse最低
rng((100*sum(clock)))

3.主程序

%% ROA优化LSTM时间序列预测
clc;clear;close all;format compact
%%
load data

%% 采用ROA优化
optimization=1;%是否重新优化
if optimization==1
    [x ,fit_gen,process]=ROAforlstm(XTrain,YTrain,XTest,YTest);%分别对batchsize 隐含层节点 训练次数与学习率寻优
    save result/ROA_para_result x fit_gen process
else
    load result/ROA_para_result
end
%% 利用优化得到的参数重新训练,得到预测值
train=1;%是否重新训练
if train==1
    rng(0)
    numFeatures = size(XTrain,1);%输入节点数
    numResponses = size(YTrain,1);%输出节点数
    miniBatchSize = x(1); %batchsize
    numHiddenUnits1 = x(2);
    numHiddenUnits2 = x(3);
    maxEpochs=x(4);
    learning_rate=x(5);
    layers = [ ...
        sequenceInputLayer(numFeatures)
        lstmLayer(numHiddenUnits1)
        lstmLayer(numHiddenUnits2)
        fullyConnectedLayer(numResponses)
        regressionLayer];
    options = trainingOptions('adam', ...
        'ExecutionEnvironment', 'cpu',...
        'MaxEpochs',maxEpochs, ...
        'MiniBatchSize',miniBatchSize, ...
        'InitialLearnRate',learning_rate, ...
        'GradientThreshold',1, ...
        'Shuffle','every-epoch', ...
        'Verbose',true,...
        'Plots','training-progress');

    net = trainNetwork(XTrain,YTrain,layers,options);
    save model/ROAlstm net
else
    load model/ROAlstm
end
% 预测
YPred = predict(net,XTest,'ExecutionEnvironment', 'cpu');
YPred=double(YPred);


4.结语

优化网络超参数的格式都是这样的!只要会改一种,那么随便拿一份能跑通的优化算法,在不管原理的情况下,都能用来优化网络的超参数。晚一点我们再来写一个简单的CNN,并用这个算法来优化。更多内容【点击专栏】目录。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2324334.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Golang】第八弹----面向对象编程

&#x1f525; 个人主页&#xff1a;星云爱编程 &#x1f525; 所属专栏&#xff1a;Golang &#x1f337;追光的人&#xff0c;终会万丈光芒 &#x1f389;欢迎大家点赞&#x1f44d;评论&#x1f4dd;收藏⭐文章 前言&#xff1a;Go语言面向对象编程说明 Golang也支持面向对…

java基础以及内存图

java基础 命名&#xff1a; 大驼峰&#xff1a;类名 小驼峰&#xff1a;变量名方法名等其他的 全部大写&#xff1a;常量名字.. // 单行注释 /**/ 多行注释 变量类型 变量名 一、基本类型&#xff08;8个&#xff09; 整数&#xff1a;byte-8bit short-16bit int 32-b…

【嵌入式学习3】TCP服务器客户端 - UDP发送端接收端

目录 1、TCP TCP特点 TCP三次握手&#xff08;建立TCP连接&#xff09;&#xff1a; TCP四次握手【TCP断开链接的时候需要经过4次确认】&#xff1a; TCP网络程序开发流程 客户端开发&#xff1a;用户设备上的程序 服务器开发&#xff1a;服务器设备上的程序 2、UDP 为…

Linux之基础知识

目录 一、环境准备 1.1、常规登录 1.2、免密登录 二、Linux基本指令 2.1、ls命令 2.2、pwd命令 2.3、cd命令 2.4、touch命令 2.5、mkdir命令 2.6、rmdir和rm命令 2.7man命令 2.8、cp命令 2.9、mv命令 2.10、cat命令 2.11、echo命令 2.11.1、Ctrl r 快捷键 2…

llamafactory微调效果与vllm部署效果不一致如何解决

在llamafactory框架训练好模型之后&#xff0c;自测chat时模型效果不错&#xff0c;但是部署到vllm模型上效果却很差 这实际上是因为llamafactory微调时与vllm部署时的对话模板不一致导致的。 对应的llamafactory的代码为 而vllm启动时会采用大模型自己本身设置的对话模板信息…

WebSocket通信的握手阶段

1. 客户端建立连接时&#xff0c;通过 http 发起请求报文&#xff0c;报文表示请求服务器端升级协议为 WebSocket&#xff0c;与普通的 http 请求协议略有区别的部分在于如下的这些协议头&#xff1a; 上述两个字段表示请求服务器端升级协议为 websocket 协议。 2. 服务器端响…

分布式ID服务实现全面解析

分布式ID生成器是分布式系统中的关键基础设施&#xff0c;用于在分布式环境下生成全局唯一的标识符。以下是各种实现方案的深度解析和最佳实践。 一、核心需求与设计考量 1. 核心需求矩阵 需求 重要性 实现难点 全局唯一 必须保证 时钟回拨/节点冲突 高性能 高并发场景…

dom0运行android_kernel: do_serror of panic----failed to stop secondary CPUs 0

问题描述&#xff1a; 从日志看出,dom0运行android_kernel&#xff0c;刚开始运行就会crash,引发panic 解决及其原因分析&#xff1a; 最终问题得到解决&#xff0c;发现是前期在调试汇编阶段代码时&#xff0c;增加了汇编打印的指令&#xff0c;注释掉这些指令,问题得到解决。…

HarmonyOS NEXT——【鸿蒙原生应用加载Web页面】

鸿蒙客户端加载Web页面&#xff1a; 在鸿蒙原生应用中&#xff0c;我们需要使用前端页面做混合开发&#xff0c;方法之一是使用Web组件直接加载前端页面&#xff0c;其中WebView提供了一系列相关的方法适配鸿蒙原生与web之间的使用。 效果 web页面展示&#xff1a; Column()…

优选算法的慧根之翼:位运算专题

专栏&#xff1a;算法的魔法世界 个人主页&#xff1a;手握风云 一、位运算 基础位运算 共包含6种&(按位与&#xff0c;有0就是0)、|(按位或有1就是1)、^(按位异或&#xff0c;相同为0&#xff0c;相异为1)、~(按位取反&#xff0c;0变成1&#xff0c;1变成0)、<<(左…

图论问题集合

图论问题集合 寻找特殊有向图&#xff08;一个节点最多有一个出边&#xff09;中最大环路问题特殊有向图解析算法解析步骤 1 &#xff1a;举例分析如何在一个连通块中找到环并使用时间戳计算大小步骤 2 &#xff1a;抽象成算法注意 实现 寻找特殊有向图&#xff08;一个节点最多…

【数据结构】栈 与【LeetCode】20.有效的括号详解

目录 一、栈1、栈的概念及结构2、栈的实现3、初始化栈和销毁栈4、打印栈的数据5、入栈操作---栈顶6、出栈---栈顶6.1栈是否为空6.2出栈---栈顶 7、取栈顶元素8、获取栈中有效的元素个数 二、栈的相关练习1、练习2、AC代码 个人主页&#xff0c;点这里~ 数据结构专栏&#xff0c…

Redis设计与实现-哨兵

哨兵模式 1、启动并初始化sentinel1.1 初始化服务器1.2 使用Sentinel代码1.3 初始化sentinel状态1.4 初始化sentinel状态的master属性1.5 创建连向主服务器的网络连接 2、获取主服务器信息3、获取从服务器的信息4、向主从服务器发送信息5、接受主从服务器的频道信息6、检测主观…

C++进阶——封装哈希表实现unordered_map/set

与红黑树封装map/set基本相似&#xff0c;只是unordered_map/set是单向迭代器&#xff0c;模板多传一个HashFunc。 目录 1、源码及框架分析 2、模拟实现unordered_map/set 2.1 复用的哈希表框架及Insert 2.2 iterator的实现 2.2.1 iteartor的核心源码 2.2.2 iterator的实…

【算法day25】 最长有效括号——给你一个只包含 ‘(‘ 和 ‘)‘ 的字符串,找出最长有效(格式正确且连续)括号子串的长度。

32. 最长有效括号 给你一个只包含 ‘(’ 和 ‘)’ 的字符串&#xff0c;找出最长有效&#xff08;格式正确且连续&#xff09;括号子串的长度。 https://leetcode.cn/problems/longest-valid-parentheses/ 2.方法二&#xff1a;栈 class Solution { public:int longestValid…

Jenkins + CICD流程一键自动部署Vue前端项目(保姆级)

git仓库地址&#xff1a;参考以下代码完成,或者采用自己的代码。 南泽/cicd-test 拉取项目代码到本地 使用云服务器或虚拟机采用docker部署jenkins 安装docker过程省略 采用docker部署jenkins&#xff0c;注意这里的命令&#xff0c;一定要映射docker路径&#xff0c;否则无…

一款超级好用且开源免费的数据可视化工具——Superset

认识Superset 数字经济、数字化转型、大数据等等依旧是如今火热的领域&#xff0c;数据工作有一个重要的环节就是数据可视化。 看得见的数据才更有价值&#xff01; 现如今依旧有多数企业号称有多少多少数据&#xff0c;然而如果这些数据只是呆在冷冰冰的数据库或文件内则毫无…

RedHatLinux(2025.3.22)

1、创建/www目录&#xff0c;在/www目录下新建name和https目录&#xff0c;在name和https目录下分别创建一个index.htm1文件&#xff0c;name下面的index.html 文件中包含当前主机的主机名&#xff0c;https目录下的index.htm1文件中包含当前主机的ip地址。 &#xff08;1&…

【C++篇】类与对象(上篇):从面向过程到面向对象的跨越

&#x1f4ac; 欢迎讨论&#xff1a;在阅读过程中有任何疑问&#xff0c;欢迎在评论区留言&#xff0c;我们一起交流学习&#xff01; &#x1f44d; 点赞、收藏与分享&#xff1a;如果你觉得这篇文章对你有帮助&#xff0c;记得点赞、收藏&#xff0c;并分享给更多对C感兴趣的…

智慧运维平台:赋能未来,开启高效运维新时代

在当今数字化浪潮下&#xff0c;企业IT基础设施、工业设备及智慧城市系统的复杂度与日俱增&#xff0c;传统人工运维方式已难以满足高效、精准、智能的管理需求。停机故障、低效响应、数据孤岛等问题直接影响企业运营效率和成本控制。大型智慧运维平台&#xff08;AIOps, Smart…