66、基于长短期记忆 (LSTM) 网络对序列数据进行分类

news2025/1/19 3:08:32

1、基于长短期记忆 (LSTM) 网络对序列数据进行分类的原理及流程

基于长短期记忆(LSTM)网络对序列数据进行分类是一种常见的深度学习任务,适用于处理具有时间或序列关系的数据。下面是在Matlab中使用LSTM网络对序列数据进行分类的基本原理和流程:

  1. 准备数据

    • 确保数据集中包含带有标签的序列数据,例如时间序列数据、文本数据等。
    • 将数据进行预处理和归一化,以便输入到LSTM网络中。
  2. 构建LSTM网络

    • 在Matlab中,可以使用内置函数 lstmLayer 来构建LSTM层。
    • 指定输入数据维度、LSTM单元数量、输出层大小等参数。
    • 通过 layers = [sequenceInputLayer(inputSize), lstmLayer(numHiddenUnits), fullyConnectedLayer(numClasses), classificationLayer()] 构建完整的LSTM分类网络。
  3. 定义训练选项

    • 设置训练选项,例如学习率、最大迭代次数、小批量大小等。
    • 使用 trainingOptions 函数来定义训练选项。
  4. 训练网络

    • 使用 trainNetwork 函数来训练构建好的LSTM网络。
    • 输入训练数据和标签,并使用定义好的训练选项进行训练。
  5. 评估网络性能

    • 使用测试数据评估训练好的网络的性能,可以计算准确率、混淆矩阵等。
    • 通过 classify 函数对新数据进行分类预测。
  6. 模型调优

    • 可以通过调整LSTM网络结构、训练参数等进行进一步优化模型性能。

在实际的应用中,可以根据具体数据和任务需求对LSTM网络进行调整和优化,以获得更好的分类性能。Matlab提供了丰富的工具和函数来支持LSTM网络的构建、训练和评估,利用这些工具可以更高效地完成序列数据分类任务。

2、基于长短期记忆 (LSTM) 网络对序列数据进行分类说明

使用 LSTM 神经网络对序列数据进行分类,LSTM 神经网络将序列数据输入网络,并根据序列数据的各个时间步进行预测。

 

3、加载序列数据

1)说明

使用 Waveform 数据集,训练数据包含四种波形的时间序列数据。每个序列有三个通道,且长度不同。

从 WaveformData 加载示例数据。

序列数据是序列的 numObservations×1 元胞数组,其中 numObservations 是序列数。每个序列都是一个 numTimeSteps×-numChannels 数值数组,其中 numTimeSteps 是序列的时间步,numChannels 是序列的通道数。标签数据是 numObservations×1 分类向量。

2)加载数据代码

load WaveformData 

3)绘制部分序列

代码

numChannels = size(data{1},2);

idx = [3 4 5 12];
figure
tiledlayout(2,2)
for i = 1:4
    nexttile
    stackedplot(data{idx(i)},DisplayLabels="Channel "+string(1:numChannels))
    
    xlabel("Time Step")
    title("Class: " + string(labels(idx(i))))
end

视图效果

5e726cb5c6544a79826ce1ec2d56f905.png

4)查看分类

实现代码

classNames = categories(labels)

classNames = 4×1 cell
    {'Sawtooth'}
    {'Sine'    }
    {'Square'  }
    {'Triangle'}

5)划分数据

说明

使用 trainingPartitions 函数将数据划分为训练集(包含 90% 数据)和测试集(包含其余 10% 数据)

实现代码 

numObservations = numel(data);
[idxTrain,idxTest] = trainingPartitions(numObservations,[0.9 0.1]);
XTrain = data(idxTrain);
TTrain = labels(idxTrain);

XTest = data(idxTest);
TTest = labels(idxTest);

4、准备要填充的数据

1)说明

默认情况下,软件将训练数据拆分成小批量并填充序列,使它们具有相同的长度

2)获取观测值序列长度代码

numObservations = numel(XTrain);
for i=1:numObservations
    sequence = XTrain{i};
    sequenceLengths(i) = size(sequence,1);
end

3)序列长度排序代码

[sequenceLengths,idx] = sort(sequenceLengths);
XTrain = XTrain(idx);
TTrain = TTrain(idx);

4)查看序列长度

代码

figure
bar(sequenceLengths)
xlabel("Sequence")
ylabel("Length")
title("Sorted Data")

视图效果

28fa6a44a83743e8b30660fcbb63a929.png

 

5、定义 LSTM 神经网络架构

1)说明

将输入大小指定为输入数据的通道数。

指定一个具有 120 个隐藏单元的双向 LSTM 层,并输出序列的最后一个元素。

最后,包括一个输出大小与类的数量匹配的全连接层,后跟一个 softmax 层。

2)实现代码

numHiddenUnits = 120;
numClasses = 4;

layers = [
    sequenceInputLayer(numChannels)
    bilstmLayer(numHiddenUnits,OutputMode="last")
    fullyConnectedLayer(numClasses)
    softmaxLayer]

layers = 
  4×1 Layer array with layers:

     1   ''   Sequence Input    Sequence input with 3 dimensions
     2   ''   BiLSTM            BiLSTM with 120 hidden units
     3   ''   Fully Connected   4 fully connected layer
     4   ''   Softmax           softmax

6、指定训练选项

1)说明

使用 Adam 求解器进行训练。

进行 200 轮训练。

指定学习率为 0.002。

使用阈值 1 裁剪梯度。

为了保持序列按长度排序,禁用乱序。

在图中显示训练进度并监控准确度。

2)实现代码

options = trainingOptions("adam", ...
    MaxEpochs=200, ...
    InitialLearnRate=0.002,...
    GradientThreshold=1, ...
    Shuffle="never", ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=false);

7、训练 LSTM 神经网络

1)说明

使用 trainnet 函数训练神经网络

2)实现代码

net = trainnet(XTrain,TTrain,layers,"crossentropy",options);

3)视图效果 

cf4856b632c144a58ea5c904d194e96d.png

8、测试 LSTM 神经网络

1)对测试数据进行分类,并计算预测的分类准确度。

numObservationsTest = numel(XTest);
for i=1:numObservationsTest
    sequence = XTest{i};
    sequenceLengthsTest(i) = size(sequence,1);
end

[sequenceLengthsTest,idx] = sort(sequenceLengthsTest);
XTest = XTest(idx);
TTest = TTest(idx);

2)对测试数据进行分类,并计算预测的分类准确度。

scores = minibatchpredict(net,XTest);
YTest = scores2label(scores,classNames);

3)计算分类准确度

acc = mean(YTest == TTest)

acc = 0.8700

4)混淆图中显示分类结果

figure
confusionchart(TTest,YTest)

0b999fc6591a4438bf3efd7c41517798.png 

9、总结

基于长短期记忆(LSTM)网络对序列数据进行分类是一种重要的深度学习任务,适用于处理具有序列关系的数据,如时间序列数据、自然语言处理等。以下是对使用LSTM网络进行序列数据分类的总结:

  1. LSTM网络结构

    • LSTM是一种适用于处理长期依赖问题的循环神经网络(RNN)变种,能够有效地捕捉序列数据中的长期依赖关系。
    • LSTM网络包含输入门、遗忘门、输出门等核心部分,通过这些门控机制来控制信息的输入、遗忘和输出。
  2. 数据准备

    • 准备带有标签的序列数据,确保数据格式正确且包含标签信息。
    • 进行数据预处理和归一化操作,以便于网络训练。
  3. 网络构建

    • 使用深度学习框架(如TensorFlow、Pytorch或Matlab)构建LSTM网络,定义输入层、LSTM层、全连接层和输出层。
    • 设置网络参数,包括输入维度、LSTM单元个数、输出类别数等。
  4. 模型训练

    • 使用标记好的数据集对构建好的LSTM网络进行训练。
    • 设置优化器、损失函数和训练参数,如学习率、迭代次数等。
    • 调整网络参数以提高模型性能,避免过拟合。
  5. 模型评估

    • 使用验证集或测试集对训练好的模型进行评估,计算准确率、精确率、召回率等指标。
    • 分析模型在不同类别上的表现,进行结果可视化分析。
  6. 模型应用和优化

    • 将训练好的模型用于实际应用中,对新数据进行分类预测。
    • 根据实际需求对模型进行调优和优化,如调整网络结构、训练参数或使用模型集成等方法。

综合来看,基于LSTM网络对序列数据进行分类是一种强大的方法,可在许多领域中发挥作用。通过合理设计网络结构、优化数据准备和训练过程,可以有效地构建出具有良好泛化能力的序列数据分类模型。

10、源代码

代码

%% 基于长短期记忆 (LSTM) 网络对序列数据进行分类
%使用 LSTM 神经网络对序列数据进行分类,LSTM 神经网络将序列数据输入网络,并根据序列数据的各个时间步进行预测。

%% 加载序列数据
%使用 Waveform 数据集,训练数据包含四种波形的时间序列数据。每个序列有三个通道,且长度不同。
%从 WaveformData 加载示例数据。
%序列数据是序列的 numObservations×1 元胞数组,其中 numObservations 是序列数。每个序列都是一个 numTimeSteps×-numChannels 数值数组,其中 numTimeSteps 是序列的时间步,numChannels 是序列的通道数。标签数据是 numObservations×1 分类向量。
load WaveformData 
%绘制部分序列
numChannels = size(data{1},2);

idx = [3 4 5 12];
figure
tiledlayout(2,2)
for i = 1:4
    nexttile
    stackedplot(data{idx(i)},DisplayLabels="Channel "+string(1:numChannels))
    
    xlabel("Time Step")
    title("Class: " + string(labels(idx(i))))
end
%查看类名称
classNames = categories(labels)
%划分数据
%使用 trainingPartitions 函数将数据划分为训练集(包含 90% 数据)和测试集(包含其余 10% 数据),
numObservations = numel(data);
[idxTrain,idxTest] = trainingPartitions(numObservations,[0.9 0.1]);
XTrain = data(idxTrain);
TTrain = labels(idxTrain);

XTest = data(idxTest);
TTest = labels(idxTest);
%% 准备要填充的数据
%默认情况下,软件将训练数据拆分成小批量并填充序列,使它们具有相同的长度
%获取观测值序列长度
numObservations = numel(XTrain);
for i=1:numObservations
    sequence = XTrain{i};
    sequenceLengths(i) = size(sequence,1);
end
%序列长度排序
[sequenceLengths,idx] = sort(sequenceLengths);
XTrain = XTrain(idx);
TTrain = TTrain(idx);
%查看序列长度
figure
bar(sequenceLengths)
xlabel("Sequence")
ylabel("Length")
title("Sorted Data")

%% 定义 LSTM 神经网络架构
%将输入大小指定为输入数据的通道数。
%指定一个具有 120 个隐藏单元的双向 LSTM 层,并输出序列的最后一个元素。
%最后,包括一个输出大小与类的数量匹配的全连接层,后跟一个 softmax 层。
numHiddenUnits = 120;
numClasses = 4;

layers = [
    sequenceInputLayer(numChannels)
    bilstmLayer(numHiddenUnits,OutputMode="last")
    fullyConnectedLayer(numClasses)
    softmaxLayer]
%% 指定训练选项
%使用 Adam 求解器进行训练。
%进行 200 轮训练。
%指定学习率为 0.002。
%使用阈值 1 裁剪梯度。
%为了保持序列按长度排序,禁用乱序。
%在图中显示训练进度并监控准确度。
options = trainingOptions("adam", ...
    MaxEpochs=200, ...
    InitialLearnRate=0.002,...
    GradientThreshold=1, ...
    Shuffle="never", ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=false);
%%  训练 LSTM 神经网络
%使用 trainnet 函数训练神经网络
net = trainnet(XTrain,TTrain,layers,"crossentropy",options);
%% 测试 LSTM 神经网络
%对测试数据进行分类,并计算预测的分类准确度。
numObservationsTest = numel(XTest);
for i=1:numObservationsTest
    sequence = XTest{i};
    sequenceLengthsTest(i) = size(sequence,1);
end

[sequenceLengthsTest,idx] = sort(sequenceLengthsTest);
XTest = XTest(idx);
TTest = TTest(idx);
%对测试数据进行分类,并计算预测的分类准确度。
scores = minibatchpredict(net,XTest);
YTest = scores2label(scores,classNames);
%计算分类准确度
acc = mean(YTest == TTest)
%混淆图中显示分类结果
figure
confusionchart(TTest,YTest)

工程文件

https://download.csdn.net/download/XU157303764/89499744

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1881135.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

TomCat小型服务器安装

一、安装步骤 Tomcat官方站点: http://tomcat.apache.org 1、进入官方网站后获取安装包: 🤠tar.gz文件是Linux操作系统下的安装版本 🤠zip文件是Windows系统下的压缩版本 2、解压安装 解压到自己的文件夹中 3、安装JDK 设置环…

Go源码--context包

简介 Context 是go语言比较重要的且也是比较复杂的一个结构体,Context主要有两种功能: 取消信号:包括直接取消(涉及的结构体:cancelCtx ; 涉及函数:WithCancel)和携带截止日期的取消(涉及结构…

功能强大的声音模拟合成软件Togu Audio Line TAL-Mod 1.9.7

Togu Audio Line TAL一个虚拟模拟合成器,具有卓越的声音和几乎无限的调制能力。其特殊的振荡器模型能够创建广泛的声音,从经典的单声道到丰富的立体声引线、效果器和焊盘。路由可以使用虚拟跳线电缆来完成。只需连接调制输出以达到调制的目的。之后,您可以调整调制强度。您不…

dB分贝入门

主要参考资料: dB(分贝)定义及其应用: https://blog.csdn.net/u014162133/article/details/110388145 目录 dB的应用一、声音的大小二、信号强度三、增益 dB的应用 一、声音的大小 在日常生活中,住宅小区告知牌上面标示噪音要低…

Excel表格转Tex工具推荐

为了制作符合 SCI 论文要求的表格,直接用 LaTeX 编写通常比较复杂。我们可以先在 Excel 中绘制好所需的表格(最好加上边框)。最近我发现了一个非常好用的 Excel 转 LaTeX 工具,能够让 LaTeX 表格的编写变得非常方便。 工具&#…

数据资产治理的智能化探索:结合云计算、大数据、人工智能等先进技术,探讨数据资产治理的智能化方法,为企业提供可靠、高效的数据资产解决方案,助力企业提升竞争力

一、引言 在信息化时代,数据已成为企业最重要的资产之一。随着云计算、大数据、人工智能等先进技术的飞速发展,数据资产治理面临着前所未有的机遇与挑战。本文旨在探讨如何结合这些先进技术,实现数据资产治理的智能化,为企业提供…

X科网js逆向分析

登录抓包之后发现pwd字眼,直接搜索即可 通过$.md5(pwd)之后得到的加密结果就是我们的pwd参数 他说是md5我们不妨测试一下: 1)测试使用$.md5(1)加密数字1 得到c4ca4,说明就是$.md5(),md5加密 2)测试$.md5…

神经网络在机器学习中的应用:手写数字识别

机器学习是人工智能的一个分支,它使计算机能够从数据中学习并做出决策或预测。神经网络作为机器学习的核心算法之一,因其强大的非线性拟合能力而广泛应用于各种领域,包括图像识别、自然语言处理和游戏等。本文将介绍如何使用神经网络对MNIST数…

独一无二的设计模式——单例模式(Java实现)

1. 引言 亲爱的读者们,欢迎来到我们的设计模式专题,今天的讲解的设计模式,还是单例模式哦!上次讲解的单例模式是基于Python实现(独一无二的设计模式——单例模式(python实现))的&am…

【数据结构】C语言实现二叉树的基本操作——二叉树的层次遍历、求深度、求结点数……

C语言实现二叉树的基本操作 导读一、层次遍历1.1 算法思路1.2 算法实现1.2.1 存储结构的选择1.2.2 函数的三要素1.2.3 函数的实现 1.3 小结 二、求二叉树的深度2.1 层序遍历2.2 分治思想——递归 三、 求二叉树的结点数3.1 求二叉树的结点总数3.1.1 层序遍历3.1.2 分治思想——…

SpringBoot | 使用jwt令牌实现登录认证,使用Md5加密实现注册

对于登录认证中的令牌,其实就是一段字符串,那为什么要那么麻烦去用jwt令牌?其实对于登录这个业务,在平常我们实现这个功能时,可能大部分都是通过比对用户名和密码,只要正确,就登录成功&#xff…

【Python实战因果推断】9_元学习器4

目录 Double/Debiased Machine Learning Double/Debiased Machine Learning Double/Debiased ML 或 R-learner 可以看作是 FrischWaugh-Lovell 定理的改进版。其思路非常简单--在构建结果和治疗残差时使用 ML 模型 结果和干预残差: , 预估,预估 由于 …

Python pdfkit wkhtmltopdf html转换pdf 黑体字体乱码

wkhtmltopdf 黑体在html转换pdf时&#xff0c;黑体乱码&#xff0c;分析可能wkhtmltopdf对黑体字体不太兼容&#xff1b; 1.html内容如下 <html> <head> <meta http-equiv"content-type" content"text/html;charsetutf-8"> </head&…

springboot使用测试类报空指针异常

检查了Service注解&#xff0c;还有Autowired注解&#xff0c;还有其他注解&#xff0c;后面放心没能解决问题&#xff0c;最后使用RunWith(SpringRunner.class)解决了问题&#xff01;&#xff01; 真的是✓8了&#xff0c;烦死了这个✓8报错&#xff01;

Android Focused Window的更新

启动App时更新inputInfo/请求焦点窗口流程&#xff1a; App主线程调ViewRootImpl.java的relayoutWindow()&#xff1b;然后调用到Wms的relayoutWindow()&#xff0c;窗口布局流程。焦点窗口的更新&#xff0c;通过WMS#updateFocusedWindowLocked()方法开始&#xff0c;下面从这…

【Spring】DAO 和 Repository 的区别

DAO 和 Repository 的区别 1.概述2.DAO 模式2.1 User2.2 UserDao2.3 UserDaoImpl 3.Repository 模式3.1 UserRepository3.2 UserRepositoryImpl 4.具有多个 DAO 的 Repository 模式4.1 Tweet4.2 TweetDao 和 TweetDaoImpl4.3 增强 User 域4.4 UserRepositoryImpl 5.比较两种模式…

以太网交换机原理

没有配置&#xff0c;比较枯燥&#xff0c;二可以认识线缆&#xff0c; 三比较重要&#xff0c;慢慢理解&#xff0c;事半功倍。 各位老少爷们&#xff0c;在下给大家说段以太网交换机原理&#xff0c;说得不好大家多多包涵&#xff0c;说得好呢&#xff0c;大家叫个好&#x…

【每日一练】python运算符

1. 算术运算符 编写一个Python程序&#xff0c;要求用户输入两个数&#xff0c;并执行以下运算&#xff1a;加法、减法、乘法、求余、除法、以及第一个数的第二个数次方。将结果打印出来。 a input("请输入第一个数&#xff1a;") b input("请输入第二个数&…

诊断知识:UnconfirmedDTCLimit的使用

文章目录 前言UnconfirmedDTCLimit的含义UnconfirmedDTCLimit的使用UnconfirmedDTCLimit和Failed limit相等UnconfirmedDTCLimit小于Failed limit 总结 前言 在某OEM基础技术规范中&#xff0c;诊断需求经常会出现UnconfirmedDTCLimit这个词汇&#xff0c;但基础技术规范中并没…

菲尔兹奖得主测试GPT-4o,经典过河难题未能破解!最强Claude 3.5回答离谱!

目录 01 大言模型能否解决「狼-山羊-卷心菜」经典过河难题&#xff1f; 02 加大难度&#xff1a;100只鸡、1000只鸡如何&#xff1f; 01 大言模型能否解决「狼-山羊-卷心菜」经典过河难题&#xff1f; 最近&#xff0c;菲尔兹奖得主Timothy Gowers分享了他测试GPT-4o的经历&a…