如何用MATLAB搭建ResNet网络(复现论文)

news2024/11/24 16:34:14

文章目录

  • 前言
  • 基础工具
  • 网络搭建
  • ResNet网络代码
  • 完整代码
  • 总结
  • 参考文献


前言

之前暑假实习了3个月,后来又赶上开学一堆事,现在终于有时间记录一下学习经历。新的学期开始了,要继续努力。

因为最近要做一个无人机航迹分类的项目,所以找了很多论文,但是不知道具体用起来效果如何,所以想用Matlab复现一下,这里主要是复现的其中的ResNet。

论文:基于CNN的雷达航迹分类方法

基础工具

Matlab深度网络设计器
在这里插入图片描述
打开工具箱后,左侧是网络的模块,在搭建网络的时候只需要拖拽这些模块就能组成网络的主体,右侧的属性用来调节网络模块的参数,例如卷积核,步长等等。
在这里插入图片描述

网络搭建

从论文的图中我们可以得到网络的结构如下,但是其中用的是一维卷积,所以我们选取的是Convd_1d
在这里插入图片描述
按照论文中的网络结构所示,我们搭建出了网络的主体架构,但值得注意的是,在第一次相加那里,输入的数据是5维的特征向量,但是经过卷积的向量维度是64,那么在这里该如何处理呢,为了让2者顺利相加,在这里我采用的是全连接层把5维扩展到64维,但是具体能不能这么做还有待进一步验证,因为网上的相关资料比较少,所以在这里先挖一个坑。
在这里插入图片描述
网络搭建好之后,先点击分析,看看网络的参数有没有问题
在这里插入图片描述
如果没有报错的话直接点击左上角的生成代码就好了,工具箱会生成一个mlx脚本,脚本如下。

ResNet网络代码

创建层次图
创建层次图变量以包含网络层。

lgraph = layerGraph();

添加层分支
将网络分支添加到层次图中。每个分支均为一个线性层组。

tempLayers = sequenceInputLayer(5,"Name","sequence");
lgraph = addLayers(lgraph,tempLayers);


tempLayers = fullyConnectedLayer(64,"Name","fc");
lgraph = addLayers(lgraph,tempLayers);


tempLayers = [
    convolution1dLayer(3,64,"Name","conv1d_1","Padding","same")
    batchNormalizationLayer("Name","batchnorm_1")
    reluLayer("Name","relu_1")
    convolution1dLayer(3,64,"Name","conv1d_2","Padding","same")
    batchNormalizationLayer("Name","batchnorm_2")
    reluLayer("Name","relu_2")
    convolution1dLayer(3,64,"Name","conv1d_3","Padding","same")
    batchNormalizationLayer("Name","batchnorm_3")
    reluLayer("Name","relu_3")];
lgraph = addLayers(lgraph,tempLayers);


tempLayers = [
    additionLayer(2,"Name","addition_1")
    reluLayer("Name","relu_10")];
lgraph = addLayers(lgraph,tempLayers);


tempLayers = [
    convolution1dLayer(3,128,"Name","conv1d_4","Padding","same")
    batchNormalizationLayer("Name","batchnorm_4")
    reluLayer("Name","relu_4")
    convolution1dLayer(3,128,"Name","conv1d_5","Padding","same")
    batchNormalizationLayer("Name","batchnorm_5")
    reluLayer("Name","relu_5")
    convolution1dLayer(3,128,"Name","conv1d_6","Padding","same")
    batchNormalizationLayer("Name","batchnorm_6")
    reluLayer("Name","relu_6")
    convolution1dLayer(1,64,"Name","conv1d_10","Padding","same")];
lgraph = addLayers(lgraph,tempLayers);


tempLayers = [
    additionLayer(2,"Name","addition_2")
    reluLayer("Name","relu_11")];
lgraph = addLayers(lgraph,tempLayers);


tempLayers = [
    convolution1dLayer(3,128,"Name","conv1d_7","Padding","same")
    batchNormalizationLayer("Name","batchnorm_7")
    reluLayer("Name","relu_7")
    convolution1dLayer(3,128,"Name","conv1d_8","Padding","same")
    batchNormalizationLayer("Name","batchnorm_8")
    reluLayer("Name","relu_8")
    convolution1dLayer(3,128,"Name","conv1d_9","Padding","same")
    batchNormalizationLayer("Name","batchnorm_9")
    reluLayer("Name","relu_9")
    convolution1dLayer(1,64,"Name","conv1d_11","Padding","same")];
lgraph = addLayers(lgraph,tempLayers);


tempLayers = [
    additionLayer(2,"Name","addition_3")
    reluLayer("Name","relu_12")
    globalAveragePooling1dLayer("Name","gapool1d")
    fullyConnectedLayer(2,"Name","fc_final")  % 将输出调整为2类
    softmaxLayer("Name","softmax")
    classificationLayer("Name","classoutput")];
lgraph = addLayers(lgraph,tempLayers);


% 清理辅助变量
clear tempLayers;

连接层分支
连接网络的所有分支以创建网络图。

lgraph = connectLayers(lgraph,"sequence","fc");
lgraph = connectLayers(lgraph,"sequence","conv1d_1");
lgraph = connectLayers(lgraph,"fc","addition_1/in1");
lgraph = connectLayers(lgraph,"relu_3","addition_1/in2");
lgraph = connectLayers(lgraph,"relu_10","conv1d_4");
lgraph = connectLayers(lgraph,"relu_10","addition_2/in1");
lgraph = connectLayers(lgraph,"conv1d_10","addition_2/in2");
lgraph = connectLayers(lgraph,"relu_11","conv1d_7");
lgraph = connectLayers(lgraph,"relu_11","addition_3/in1");
lgraph = connectLayers(lgraph,"conv1d_11","addition_3/in2");

这就是网络的整体架构,但是因为真实的数据还没有,所以还需要生成数据,因为论文的要求,生成的是5维特征向量组,为了计算方便,就选取了5x1000的矩阵代表无人机和飞鸟的特征向量。
完整代码如下

完整代码

clear;
clc;
% Step 1: 生成5x1000的矩阵数据
numSamples = 1000;
numFeatures = 5;  % 原始每个样本的特征数为5
sequenceLength = 1000;

% 类别 1:生成5x1000的矩阵代表标签 1
X1 = randn(numFeatures, numSamples);  % 模拟5x1000数据,类别1

% 类别 2:生成5x1000的矩阵代表标签 2
X2 = randn(numFeatures, numSamples);  % 模拟5x1000数据,类别2

% 合并数据
X = [X1, X2];  % 5x2000的矩阵,总共有2000个样本
labels = [ones(1, numSamples), 2 * ones(1, numSamples)];  % 类别标签 1 和 2

% Step 2: 数据预处理
XTrain = num2cell(X, 1);  % 将数据转换为元胞数组,每个元胞代表一个5x1的特征向量
YTrain = categorical(labels');  % 将标签转化为分类格式

% Step 3: 将数据分为90%训练集和10%验证集
idx = randperm(2 * numSamples);
numTrain = round(0.9 * numel(labels));

XTrainSet = XTrain(idx(1:numTrain));  % 选择90%的数据作为训练集
YTrainSet = YTrain(idx(1:numTrain));

XValidationSet = XTrain(idx(numTrain+1:end));  % 剩余10%作为验证集
YValidationSet = YTrain(idx(numTrain+1:end));

% Step 4: 定义网络结构
lgraph = layerGraph();


tempLayers = sequenceInputLayer(5,"Name","sequence");
lgraph = addLayers(lgraph,tempLayers);

tempLayers = [
    convolution1dLayer(1,64,"Name","conv_projection","Padding","same")
    batchNormalizationLayer("Name","batchnorm_projection")
    reluLayer("Name","relu_projection")
    convolution1dLayer(3,64,"Name","conv1d_1","Padding","same")
    batchNormalizationLayer("Name","batchnorm_1")
    reluLayer("Name","relu_1")
    convolution1dLayer(3,64,"Name","conv1d_2","Padding","same")
    batchNormalizationLayer("Name","batchnorm_2")
    reluLayer("Name","relu_2")
    convolution1dLayer(3,64,"Name","conv1d_3","Padding","same")
    batchNormalizationLayer("Name","batchnorm_3")
    reluLayer("Name","relu_3")];
lgraph = addLayers(lgraph,tempLayers);

tempLayers = fullyConnectedLayer(64,"Name","fc_2");
lgraph = addLayers(lgraph,tempLayers);

tempLayers = additionLayer(2,"Name","addition_1");
lgraph = addLayers(lgraph,tempLayers);

tempLayers = [
    convolution1dLayer(3,128,"Name","conv1d_4","Padding","same")
    batchNormalizationLayer("Name","batchnorm_4")
    reluLayer("Name","relu_4")
    convolution1dLayer(3,128,"Name","conv1d_5","Padding","same")
    batchNormalizationLayer("Name","batchnorm_5")
    reluLayer("Name","relu_5")
    convolution1dLayer(3,128,"Name","conv1d_6","Padding","same")
    batchNormalizationLayer("Name","batchnorm_6")
    reluLayer("Name","relu_6")
    convolution1dLayer(1,64,"Name","conv1d_10","Padding","same")];
lgraph = addLayers(lgraph,tempLayers);

tempLayers = additionLayer(2,"Name","addition_2");
lgraph = addLayers(lgraph,tempLayers);

tempLayers = [
    convolution1dLayer(3,128,"Name","conv1d_7","Padding","same")
    batchNormalizationLayer("Name","batchnorm_7")
    reluLayer("Name","relu_7")
    convolution1dLayer(3,128,"Name","conv1d_8","Padding","same")
    batchNormalizationLayer("Name","batchnorm_8")
    reluLayer("Name","relu_8")
    convolution1dLayer(3,128,"Name","conv1d_9","Padding","same")
    batchNormalizationLayer("Name","batchnorm_9")
    reluLayer("Name","relu_9")
    convolution1dLayer(1,64,"Name","conv1d_11","Padding","same")];
lgraph = addLayers(lgraph,tempLayers);

tempLayers = [
    additionLayer(2,"Name","addition_3")
    globalAveragePooling1dLayer("Name","gapool1d")
    fullyConnectedLayer(2,"Name","fc_1")
    softmaxLayer("Name","softmax")
    classificationLayer("Name","classoutput")];
lgraph = addLayers(lgraph,tempLayers);

% 清理辅助变量
clear tempLayers;

% Step 5: 设置网络连接
lgraph = connectLayers(lgraph,"sequence","conv_projection");
lgraph = connectLayers(lgraph,"sequence","fc_2");
lgraph = connectLayers(lgraph,"fc_2","addition_1/in1");
lgraph = connectLayers(lgraph,"relu_3","addition_1/in2");
lgraph = connectLayers(lgraph,"addition_1","conv1d_4");
lgraph = connectLayers(lgraph,"addition_1","addition_2/in1");
lgraph = connectLayers(lgraph,"conv1d_10","addition_2/in2");
lgraph = connectLayers(lgraph,"addition_2","conv1d_7");
lgraph = connectLayers(lgraph,"addition_2","addition_3/in1");
lgraph = connectLayers(lgraph,"conv1d_11","addition_3/in2");

% Step 6: 设置训练选项
options = trainingOptions('adam', ...
    'MaxEpochs', 600, ...
    'MiniBatchSize', 32, ...
    'InitialLearnRate', 1e-5, ...
    'L2Regularization', 1e-6, ...
    'ValidationData', {XValidationSet, YValidationSet}, ...
    'ValidationFrequency', 30, ...
    'Shuffle', 'every-epoch', ...
    'Plots', 'training-progress', ...
    'Verbose', false);

% Step 7: 训练网络
net = trainNetwork(XTrainSet, YTrainSet, lgraph, options);

% Step 8: 测试网络
YPred = classify(net, XTrain);
accuracy = sum(YPred == YTrain) / numel(YTrain);
disp(['Training Accuracy: ', num2str(accuracy * 100), '%']);

在这里插入图片描述
用的随机矩阵进行训练,所以结果也是在预料之中了

总结

在复现的过程中,遇到最多的问题就是网络结构运算之间的大小对不上,如果出现了报错,第一时间就应该去检查一下网络之间传递的参数是否合理。

参考文献

[1] 汪浩,窦贤豪,田开严,等.基于CNN的雷达航迹分类方法[J].舰船电子对抗,2023,46(05):70-74.DOI:10.16426/j.cnki.jcdzdk.2023.05.014.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2129387.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

一周完成计算机毕业设计论文:高效写作技巧与方法(纯干货/总结与提炼)

🍊作者:计算机编程-吉哥 🍊简介:专业从事JavaWeb程序开发,微信小程序开发,定制化项目、 源码、代码讲解、文档撰写、ppt制作。做自己喜欢的事,生活就是快乐的。 🍊心愿:点…

Trie字符串统计(每周一类)

这节课我们学习Trie字符串。这个算法的主要应用就是字符串的快速存储和查找。我们通过下面这个题来讲 Tire字符串统计 ,另外说个题外话,本人是从ACwing里学习的算法知识,希望大家支持一下y总(ACwing大佬),如果觉得我这里的知识讲得…

Unity Apple Vision Pro 开发(十):通过图像识别锚定空间

XR 开发者社区链接: SpatialXR社区:完整课程、项目下载、项目孵化宣发、答疑、投融资、专属圈子 课程试看:https://www.bilibili.com/video/BV1mpH9eVErW 课程完整版,答疑仅社区成员可见,可以通过文章开头的链接加入…

另类动态规划

前言&#xff1a;一开始我根本想不到这个题目是一个动态规划的题目&#xff0c;而且我一开始的初始状态还写错了 我还忘记了写算法题的基本步骤&#xff0c;先看数据范围&#xff0c;再考虑能不能用动态规划写 题目地址 #include <bits/stdc.h> using namespace std; #de…

RTR_Chapter_4_上半部分

第四章 Transform 变换 变换&#xff08;transform&#xff09;是指以点、向量、颜色等实体作为输入&#xff0c;并以某种方式对其进行转换的一种操作。对于计算机图形学从业者而言&#xff0c;熟练掌握变换相关的知识是非常重要的。通过各种变换操作可以对物体、光源和相机进…

开源网安斩获CCIA中国网络安全创新创业大赛总决赛三等奖

近日&#xff0c;由中央网信办指导&#xff0c;中国网络安全产业联盟&#xff08;CCIA&#xff09;主办的2024年中国网络安全创新创业大赛总决赛及颁奖典礼在国家网络安全宣传周落下帷幕。开源网安“AI代码审核平台CodeSec V4.0” 凭借在AI方向的技术创新、技术突破及功能应用创…

数据库——MySQL概述

一、数据库 存储数据的仓库&#xff0c;数据是有组织的存储&#xff0c;简称database&#xff08;DB&#xff09; 二、数据库管理系统 操控和管理数据库的大型软件&#xff08;DBMS&#xff09; 三、SQL 操作关系型数据库的编程语言&#xff0c;定义了一套操作关系型数据库…

【2024】前端学习笔记1-HTML主体框架-文本标签

学习笔记 HTML主体框架标题标签:h段落标签:p加粗标签:strong、b斜体文本标签:em、i下划线标签:u上标、下标:sup、sub内联容器:span换行标签:brHTML主体框架 HTML主体框架 <!DOCTYPE html> <html lang="en"><head><meta charset="…

【Linux 19】线程概念

文章目录 &#x1f308; 一、线程的概念⭐ 1. 线程是什么⭐ 2. 线程的优点⭐ 3. 线程的缺点⭐ 4. 线程的异常⭐ 5. 线程的用途 &#x1f308; 二、进程和线程⭐ 1. 进程和线程的区别⭐ 2. 进程的多线程共享⭐ 3. 进程和线程的关系⭐ 4. 线程私有的资源 (重要&#xff1a;面试) …

Map--08--CurrentHashMap 与 Hashtable的异同?

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 Map方法computeIfAbsent1.computeIfAbsent 方法的简介2.案例computeIfAbsent() Map方法computeIfAbsent computeIfAbsent方法是Java 8中引入的一种简化操作Map的方…

探索 AI 代理驱动的汽车保险索赔 RAG 管道。

这篇文章中&#xff0c;我探讨了最近的一项实验&#xff0c;旨在创建一个针对保险行业量身定制的 RAG 管道&#xff0c;专门用于处理汽车保险索赔&#xff0c;目的是尽可能减少处理时间。 我还展示了 Autogen AI Agents 的实施&#xff0c;通过代理交互和对样本汽车保险索赔文件…

李宏毅结构化学习 01

文章目录 一、结构化学校介绍二、线性模型 一、结构化学校介绍 训练时&#xff0c;F(x,y)是评估X与Y有多匹配&#xff0c;越匹配&#xff0c;R的值就越大。 测试时&#xff0c;确定F(x,y)后&#xff0c;给定一个x后&#xff0c;穷举所有y&#xff0c;使得F最大的那个就是 y ~ \…

k8s的配置

k8s的配置 拉取镜像&#xff0c;创建pod&#xff1a;从阿里云拉取 [rootk8s-master ~]# kubectl run nginx --imagenginx:latest [rootk8s-master ~]# kubectl get po -Aowide|grep nginx default nginx 0/1 ImagePullBackO…

【Linux取经之路】用户权限管理

目录 shell命令以及运行原理 Linux权限的概念 1、用户的概念 2、切换用户 Linux权限管理 1、文件访问者的分类 2、文件类型和访问权限 3、文件访问权限的修改 4、文件所有权的修改 5、设置权限掩码 6、用户提权 7、目录的权限 8、粘滞位 shell命令以及运行原理 Linu…

D - 88888888

设N有K位 则&#xff1a; p998244353&#xff0c;是质数&#xff0c;vn%p只需要求一下分母的逆元即可。 分母于p互质&#xff0c;满足飞马小定理&#xff0c;故可以用其求逆元。 再用一下这个结论就OK了 #include<bits/stdc.h> using namespace std; #define int long…

《JavaEE进阶》----16.<Mybatis简介、操作步骤、相关配置>

本篇博客讲记录&#xff1a; 1.回顾MySQL的JDBC操作 2..Mybatis简介、Mybatis操作数据库的步骤 3.Mybatis 相关日志的配置&#xff08;日志的配置、驼峰自动转换的配置&#xff09; 前言 之前学习应用分层时我们知道Web应用程序一般分为三层&#xff0c;Controller、Service、D…

使用Python从头开始创建PowerPoint演示文稿

目录 一、环境搭建与基础知识 1.1 环境搭建 1.2 基础知识 二、创建演示文稿对象 三、添加幻灯片 3.1 选择幻灯片布局 3.2 设置幻灯片内容 3.2.1 设置标题和副标题 3.2.2 添加文本内容 3.2.3 插入图片 3.2.4 插入图表 四、高级应用&#xff1a;批量生成演示文稿 4.…

太惨了!许家印前妻每个月只能花18万

文&#xff5c;琥珀食酒社 作者 | 积溪 许家印前妻被判了 我跟你说啊她真是太惨了 一个月只能取18万啊 你说这日子怎么过啊 买个包包都不够啊&#xff01; 大家都知道许皮带爆雷&#xff08;BL&#xff09;后 丁玉梅虽然和许皮带战略性离婚&#xff0c;逃到了英国 还把…

【计算机组成原理】浮点数的表示及IEEE 754规格化

&#x1f4e2;博客主页&#xff1a;https://blog.csdn.net/2301_779549673 &#x1f4e2;欢迎点赞 &#x1f44d; 收藏 ⭐留言 &#x1f4dd; 如有错误敬请指正&#xff01; &#x1f4e2;本文由 JohnKi 原创&#xff0c;首发于 CSDN&#x1f649; &#x1f4e2;未来很长&#…

软件工程测试

1. 软件测试概述 通俗地说&#xff0c;软件测试是为了发现错误而执行程序的过程。 软件测试&#xff1a;根据软件开发各阶段的规格说明和程序的内部结构而精心设计一批测试用例&#xff08;即输入数据及其预期的输出结果&#xff09;&#xff0c;并利用这些测试用例去运行程序…