分类预测 | MATLAB实现基于Attention-LSTM的数据分类预测多特征分类预测(长短期记忆网络融合注意力机制分类预测,含混淆矩阵图、分类图)
目录
- 分类预测 | MATLAB实现基于Attention-LSTM的数据分类预测多特征分类预测(长短期记忆网络融合注意力机制分类预测,含混淆矩阵图、分类图)
- 效果一览
- 基本介绍
- 程序设计
- 参考资料
效果一览
基本介绍
分类预测 | MATLAB实现基于Attention-LSTM的数据分类预测多特征分类预测(长短期记忆网络融合注意力机制分类预测,含混淆矩阵图、分类图)
程序设计
- 完整程序和数据私信博主回复:Attention-LSTM的数据分类预测多特征分类预测
% 需要学习的参数
lstmweight = params.lstm.weights;
lstmrecurrentWeights = params.lstm.recurrentWeights;
lstmbias = params.lstm.bias;
% 不同批次间传递的参数(这里假设每一轮epoch中,不同Batch间的state是传递的,但不学习;
h0 = state.lstm.h0;
c0 = state.lstm.c0;
[Lstm_Y,h0,c0] = lstm(Train_X,h0,c0,lstmweight,lstmrecurrentWeights,lstmbias);
Htt = dlarray(Lstm_Y(:,:,1:end-1),'SBSC'); %转变成CNN输入格式,’SS为
%% Attention
Attentionweight = params.attention.weight; % 计算得分权重
Att = dlarray(squeeze(sum(CnnHttAtt .* dlarray(Attentionweight,'SC'),2)),'SBC'); %'C'维度为cnn卷积后的每一行
Ht = Lstm_Y(:,:,end); % 参考向量
HtAfter = dlarray(repmat(Ht,[1,1,50]),'SBC');
f = squeeze(sum(HtAfter.*Att,1));
socre = sigmoid(f); % 计算得分'CB'
socre = dlarray(repmat(socre,[1,1,6]),'CBS');
% 组成Vt
CnnAfterRow = dlarray(squeeze(CnnHtt),'CSB'); % 满足与socre维度一致
Vt = sum(CnnAfterRow .*socre,2);
Vt = squeeze(Vt);
%% Attention输出
weight1 = params.attenout.weight1;
bias1 = params.attenout.bias1;
weight2 = params.attenout.weight2;
bias2 = params.attenout.bias2;
Hthat = fullyconnect(Vt,weight1,bias1) + fullyconnect(Ht,weight2,bias2);
%% 全连接层前置层(降维)
LastWeight = params.fullyconnect.weight1;
LastBias = params.fullyconnect.bias1 ;
FullyconnectInput = fullyconnect(Hthat,LastWeight,LastBias);
FullyconnectInput = relu(FullyconnectInput);
参考资料
[1] https://blog.csdn.net/kjm13182345320/article/details/128163536?spm=1001.2014.3001.5502
[2] https://blog.csdn.net/kjm13182345320/article/details/128151206?spm=1001.2014.3001.5502