用matlab搭建一个简单的图像分类网络

news2025/4/8 19:28:23

文章目录

  • 1、数据集准备
  • 2、网络搭建
  • 3、训练网络
  • 4、测试神经网络
  • 5、进行预测
  • 6、完整代码

1、数据集准备

首先准备一个包含十个数字文件夹的DigitsData,每个数字文件夹里包含1000张对应这个数字的图片,图片的尺寸都是 28×28×1 像素的,如下图所示

在这里插入图片描述

matlab 中imageDatastore 函数会根据文件夹名称自动为图像进行分类标注。该数据集包含 10 个类别。

% 创建一个图像数据存储对象 `imds`,用于从名为 "DigitsData" 的文件夹中加载图像数据
imds = imageDatastore("DigitsData", ...
    IncludeSubfolders=true, ...  % 指定在加载数据时包含子文件夹中的图像
    LabelSource="foldernames");  % 使用子文件夹的名称作为图像的标签(自动分类)

% 获取数据集中所有的类别名称(即文件夹名),并将其存储在变量 classNames 中
classNames = categories(imds.Labels);  % 将 imds.Labels

将数据划分为训练集、验证集和测试集。使用 70% 的图像作为训练数据,15% 作为验证数据,15% 作为测试数据。指定使用 “randomized”(随机化),以便从每个类别中按指定比例随机分配图像到新的数据集中。
splitEachLabel 函数用于将图像数据存储对象划分成三个新的数据存储对象。

% 使用 splitEachLabel 函数将原始图像数据集 imds 随机划分为训练集、验证集和测试集
[imdsTrain, imdsValidation, imdsTest] = splitEachLabel(imds, 0.7, 0.15, 0.15, "randomized");
  • splitEachLabel:MATLAB 中的函数,用于根据每个标签(类别)分别划分图像数据集。这样可以确保每个类别在训练集、验证集和测试集中都有代表性。
  • imds:原始的图像数据存储对象,包含所有图像和对应的标签。
  • 0.7:表示将每个类别中 70% 的图像用于训练集
  • 0.15:表示每个类别中 15% 的图像用于验证集
  • 0.15:表示每个类别中 15% 的图像用于测试集
  • "randomized":表示在划分数据集时使用随机抽样,避免按文件顺序导致划分不均衡。
  • [imdsTrain, imdsValidation, imdsTest]:返回三个新的 imageDatastore 对象,分别代表:
    • imdsTrain:训练数据集
    • imdsValidation:验证数据集
    • imdsTest:测试数据集

2、网络搭建

这里,我们需要借用到matlab工具栏里APPS里的Deep Network Designer,如下图所示

在这里插入图片描述

在Deep Network Designer, 我们创建一个空白Designer画布

在这里插入图片描述

然后我们可以拖动相应的层到Designer里,并连接各个层,如下图所示

在这里插入图片描述

这里,我们只需要改一下输入层的InputSize就行,如下图

在这里插入图片描述

然后,我们可以检查这个网络可行不可行,通过Analyze按钮,就会得到这个网络的分析结果,如下图

在这里插入图片描述

没有错误,就可以通过Export按钮输出这个网络到Matlab工作区,这个网络被自动被命名为net_1。
在这里插入图片描述

3、训练网络

指定训练选项。不同选项的选择需要依赖实验分析(即通过反复试验和比较来确定最优配置)。

% 设置用于网络训练的选项,这里使用的是随机梯度下降动量法(SGDM)
% 最大训练轮数(epoch):训练过程中将整个训练集完整迭代 4 次
% 指定验证数据集,用于在训练过程中评估模型的泛化能力
% 每训练 30 个 mini-batch 执行一次验证评估
% 在训练过程中显示实时图形界面,包括损失值和准确率的变化曲线
% 指定训练期间关注的评估指标为准确率(accuracy)
% 禁止在命令行窗口输出详细训练信息(安静模式)
options = trainingOptions("sgdm", ...  
    MaxEpochs = 4, ...  
    ValidationData = imdsValidation, ... 
    ValidationFrequency = 30, ...  
    Plots = "training-progress", ...  
    Metrics = "accuracy", ...  
    Verbose = false); 

trainingOptions 是 MATLAB 中用于设置神经网络训练参数的函数。

"sgdm" 是一种常用优化算法,适用于多数分类问题。

MaxEpochs=4 设置为 4 是为了快速试验,实际训练中可以设置更大,比如 10、20 甚至更多。

ValidationFrequency=30 表示每 30 次 mini-batch 后在验证集上评估一次性能,值越小越频繁,但也会增加验证的耗时。

Plots="training-progress" 是非常有用的调试和可视化工具,能帮助你观察训练是否收敛。

Verbose=false 适合在图形界面中查看结果时使用;如果希望看到文字日志,可以设置为 true

使用 trainnet 函数训练神经网络。由于目标是分类任务,因此使用交叉熵损失函数(cross-entropy loss)

% 使用 trainnet 函数对神经网络进行训练
net = trainnet(imdsTrain, net_1, "crossentropy", options);
  • imdsTrain:训练数据集,是一个图像数据存储对象(imageDatastore),包含用于训练的图像和对应标签。
  • net_1:要训练的神经网络结构(可由 layerGraphdlnetwork 等方式定义的网络)。
  • "crossentropy":指定损失函数为交叉熵损失函数(cross-entropy loss),这是分类任务中最常用的损失函数,特别适用于多类分类问题。
  • options:训练选项,由前面设置的 trainingOptions 定义,包含训练轮数、验证数据、优化器、可视化等信息。

返回值:

  • net:训练完成后的神经网络,包含了优化后的权重和结构,可用于后续的预测或评估。

在这里插入图片描述

4、测试神经网络

使用 testnet 函数对神经网络进行测试。对于单标签分类任务,评估指标为准确率(accuracy),即预测正确的百分比。默认情况下,testnet 函数会在可用时自动使用 GPU。如果希望手动选择执行环境,可以使用 testnet 函数的 ExecutionEnvironment 参数进行设置。

% 使用 testnet 函数对训练好的神经网络进行验证,并评估其准确率
accuracy = testnet(net, imdsTest, "accuracy");
  • net:已训练好的神经网络模型,是前面通过 trainnet 得到的结果。
  • imdsTest:测试数据集,是一个图像数据存储对象(imageDatastore),用于测试模型的性能。
  • "accuracy":评估指标,这里指定为准确率,即预测正确的样本数量占总样本数量的百分比。

返回值:

  • accuracy:一个介于 0 和 1 之间的小数,表示模型在测试集上的准确率。例如,accuracy = 0.93 表示模型在测试集中有 93% 的预测是正确的。

testnet 函数自动根据你的硬件情况选择在 CPU 还是 GPU 上运行。如果你想手动指定环境,比如使用 CPU,可以这样写:

accuracy = testnet(net, imdsTest, "accuracy", ExecutionEnvironment="cpu");

5、进行预测

使用 minibatchpredict 函数进行预测,并通过 scores2label 函数将预测得分转换为类别标签。默认情况下,如果有可用的 GPU,minibatchpredict 会自动使用 GPU 进行计算。

% 对测试集进行批量预测,输出每个图像对应的类别得分(概率)
scores = minibatchpredict(net, imdsValidation);

% 将得分(scores)转换为类别标签,使用 classNames 映射到原始类名
YValidation = scores2label(scores, classNames);

可视化部分预测结果:

% 获取测试集图像的总数量
numTestObservations = numel(imdsTest.Files);

% 从测试集中随机选取 9 个样本用于可视化
idx = randi(numTestObservations, 9, 1);

% 创建一个新的图形窗口
figure
tiledlayout("flow")  % 使用自动流式布局排列子图(tiled layout)

% 遍历 9 张图像,显示图像并在标题中标注预测类别
for i = 1:9
    nexttile  % 在下一个网格位置准备绘图
    img = readimage(imdsTest, idx(i));  % 读取第 idx(i) 张图像
    imshow(img)  % 显示图像
    title("Predicted Class: " + string(YTest(idx(i))))  % 设置标题,显示预测类别
end

在这里插入图片描述

6、完整代码

% 创建一个图像数据存储对象 `imds`,用于从名为 "DigitsData" 的文件夹中加载图像数据
imds = imageDatastore("DigitsData", ...
    IncludeSubfolders=true, ...  % 指定在加载数据时包含子文件夹中的图像
    LabelSource="foldernames");  % 使用子文件夹的名称作为图像的标签(自动分类)

% 获取数据集中所有的类别名称(即文件夹名),并将其存储在变量 classNames 中
classNames = categories(imds.Labels);  % 将 imds.Labels


%%
% 使用 splitEachLabel 函数将原始图像数据集 imds 随机划分为训练集、验证集和测试集
[imdsTrain, imdsValidation, imdsTest] = splitEachLabel(imds, 0.7, 0.15, 0.15, "randomized");

% 设置用于网络训练的选项,这里使用的是随机梯度下降动量法(SGDM)
% 最大训练轮数(epoch):训练过程中将整个训练集完整迭代 4 次
% 指定验证数据集,用于在训练过程中评估模型的泛化能力
% 每训练 30 个 mini-batch 执行一次验证评估
% 在训练过程中显示实时图形界面,包括损失值和准确率的变化曲线
% 指定训练期间关注的评估指标为准确率(accuracy)
% 禁止在命令行窗口输出详细训练信息(安静模式)
options = trainingOptions("sgdm", ...  
    MaxEpochs = 4, ...  
    ValidationData = imdsValidation, ... 
    ValidationFrequency = 30, ...  
    Plots = "training-progress", ...  
    Metrics = "accuracy", ...  
    Verbose = false); 



% 使用 trainnet 函数对神经网络进行训练
net = trainnet(imdsTrain, net_1, "crossentropy", options);

%%
% 使用 testnet 函数对训练好的神经网络进行验证,并评估其准确率
accuracy = testnet(net, imdsTest, "accuracy");


%%
% 对测试集进行批量预测,输出每个图像对应的类别得分(概率)
scores = minibatchpredict(net, imdsTest);

% 将得分(scores)转换为类别标签,使用 classNames 映射到原始类名
YTest = scores2label(scores, classNames);


% 获取测试集图像的总数量
numTestObservations = numel(imdsTest.Files);

% 从测试集中随机选取 9 个样本用于可视化
idx = randi(numTestObservations, 9, 1);

% 创建一个新的图形窗口
figure
tiledlayout("flow")  % 使用自动流式布局排列子图(tiled layout)

% 遍历 9 张图像,显示图像并在标题中标注预测类别
for i = 1:9
    nexttile  % 在下一个网格位置准备绘图
    img = readimage(imdsTest, idx(i));  % 读取第 idx(i) 张图像
    imshow(img)  % 显示图像
    title("Predicted Class: " + string(YTest(idx(i))))  % 设置标题,显示预测类别
end

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2328206.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【AI4CODE】5 Trae 锤一个基于百度Amis的Crud应用

【AI4CODE】目录 【AI4CODE】1 Trae CN 锥安装配置与迁移 【AI4CODE】2 Trae 锤一个 To-Do-List 【AI4CODE】3 Trae 锤一个贪吃蛇的小游戏 【AI4CODE】4 Trae 锤一个数据搬运工的小应用 1 百度 Amis 简介 百度 Amis 是一个低代码前端框架,由百度开源。它通过 J…

npm webpack打包缓存 导致css引用地址未更新

问题如下: 测试环境配置: publicPath: /chat/,生产环境配置: publicPath: /,css中引用背景图片 background-image: url(/assets/images/calendar/arrow-left.png);先打包测试环境,观察打包后的css文件引用的背景图片地址 可以全…

ollama导入huggingface下载的大模型并量化

1. 导入GGUF 类型的模型 1.1 先在huggingface 下载需要ollama部署的大模型 1.2 编写modelfile 在ollama 里面输入 ollama show --modelfile <你有的模型名称> eg: ollama show --modelfile qwen2.5:latest修改其中的from 路径为自己的模型下载路径 FROM /Users/lzx/A…

Java 集合 Map Stream流

目录 集合遍历for each map案例 ​编辑 这种数组的遍历是【index】​编辑map排序【对象里重写compareTo​编辑map排序【匿名内部类lambda​编辑 stream流​编辑 ​编辑获取&#xff1a; map的键是set集合&#xff0c;获取方法map.keySet() map的值是collection 集合&…

【网络安全实验】PKI(证书服务)配置实验

目录 一、PKI相关概念 1.1 定义与核心功能 1.2 PKI 系统的组成 1.证书颁发机构&#xff08;CA, Certificate Authority&#xff09; 2.注册机构&#xff08;RA, Registration Authority&#xff09; 3.数字证书 1.3 PKI 的功能 1.4 PKI认证体系&#xff1a; 工作流程 …

【数据集】多视图文本数据集

多视图文本数据集指的是包含多个不同类型或来源的信息的文本数据集。不同视图可以来源于不同的数据模式&#xff08;如原始文本、元数据、网络结构等&#xff09;&#xff0c;或者不同的文本表示方法&#xff08;如 TF-IDF、词嵌入、主题分布等&#xff09;。这些数据集常用于多…

学透Spring Boot — 007. 七种配置方式及优先级

Spring Boot 提供很多种方式来加载配置&#xff0c;本文我们会用Tomcat的端口号作为例子&#xff0c;演示Spring Boot 常见的配置方式。 几种配置方式 使用默认配置 新建一个项目什么都不配置&#xff0c;Spring Boot会自动配置Tomcat端口号。 启动日志 TomcatWebServer :…

【youcans论文精读】弱监督深度检测网络(Weakly Supervised Deep Detection Networks)

欢迎关注『youcans论文精读』系列 本专栏内容和资源同步到 GitHub/youcans 【youcans论文精读】弱监督深度检测网络 WSDDN 0. 弱监督检测的开山之作0.1 论文简介0.2 WSDNN 的步骤0.3 摘要 1. 引言2. 相关工作3. 方法3.1 预训练网络3.2 弱监督深度检测网络3.3 WSDDN训练3.4 空间…

【服务日志链路追踪】

MDCInheritableThreadLocal和spring cloud sleuth 在微服务架构中&#xff0c;日志链路追踪&#xff08;Logback Distributed Tracing&#xff09; 是一个关键需求&#xff0c;主要用于跟踪请求在不同服务间的调用链路&#xff0c;便于排查问题。常见的实现方案有两种&#x…

【行测】判断推理:图形推理

> 作者&#xff1a;დ旧言~ > 座右铭&#xff1a;读不在三更五鼓&#xff0c;功只怕一曝十寒。 > 目标&#xff1a;掌握 图形推理 基本题型&#xff0c;并能运用到例题中。 > 毒鸡汤&#xff1a;有些事情&#xff0c;总是不明白&#xff0c;所以我不会坚持。早安! …

3D模型给可视化大屏带来了哪些创新,都涉及到哪些技术栈。

一、3D 模型给可视化大屏带来的创新 更直观的视觉体验 传统的可视化大屏主要以二维图表和图形的形式展示数据&#xff0c;虽然能够传达一定的信息&#xff0c;但对于复杂的场景和数据关系&#xff0c;往往难以直观地呈现。而 3D 模型可以将数据以三维立体的形式展示出来&#…

Unity HDRP管线用ShaderGraph还原Lit,方便做拓展;

里面唯一的重点就是判断有无这张复合图&#xff0c;我用的是颜色判断&#xff1a; float Tex TexCol.r*TexCol.g*TexCol.b*TexCol.a; if(Tex 1) { IsOrNot 1; } else { IsOrNot 0; } 其他的正常解码就行&#xff0c;对了法线贴图孔位记得设置成normal&#xff0c;不然的话…

绝缘升级 安全无忧 金能电力环保绝缘胶垫打造电力安全防护新标杆

在电力安全领域&#xff0c;一块看似普通的胶垫&#xff0c;却是守护工作人员生命安全的“第一道防线”。近年来&#xff0c;随着电网设备升级和环保要求趋严&#xff0c;传统绝缘胶垫有异味、易老化、绝缘性能不足等问题逐渐暴露。为此&#xff0c;金能电力凭借技术创新推出新…

Linux命令-iotop

iotop 命令 iotop 是一个用于实时监控磁盘 I/O 活动的工具&#xff0c;可以显示哪些进程正在使用磁盘资源。 参数 描述 –version 显示程序版本号并退出 -h, --help 显示此帮助消息并退出 -o, --only 仅显示实际进行 I/O 操作的进程或线程 -b, --batch 非交互模式&#xff0c;适…

QTableWidget 中insertRow(0)(头插)和 insertRow(rowCount())(尾插)的性能差异

一、目的 在 Qt 的 QTableWidget 中&#xff0c;insertRow(0) &#xff08;头插&#xff09;和 insertRow(rowCount())&#xff08;尾插&#xff09;在性能上存在显著差异。 二、QAbstractItemModel:: insertRows 原文解释 QAbstractItemModel Class | Qt Core 5.15.18 AI 解…

【万字总结】前端全方位性能优化指南(完结篇)——自适应优化系统、遗传算法调参、Service Worker智能降级方案

前言 自适应进化宣言 当监控网络精准定位病灶&#xff0c;真正的挑战浮出水面&#xff1a;系统能否像生物般自主进化&#xff1f; 五维感知——通过设备传感器实时捕获环境指纹&#xff08;如地铁隧道弱光环境自动切换省电渲染&#xff09; 基因调参——150个性能参数在遗传算…

不绕弯地解决文件编码问题,锟斤拷烫烫烫

安装python对应库 pip install chardet 检测文件编码 import chardet# 检测文件编码 file_path rC:\Users\AA\Desktop\log.log # 这里放文件和文件绝对路径 with open(file_path, rb) as f:raw_data f.read(100000) # 读取前10000个字节result chardet.detect(raw_data)e…

高密度任务下的挑战与破局:数字样机助力火箭发射提效提质

2025年4月1日12时&#xff0c;在酒泉卫星发射中心&#xff0c;长征二号丁运载火箭顺利升空&#xff0c;成功将一颗卫星互联网技术试验卫星送入预定轨道&#xff0c;发射任务圆满完成。这是长征二号丁火箭的第97次发射&#xff0c;也是长征系列火箭的第567次发射。 执行本次任务…

QT Quick(C++)跨平台应用程序项目实战教程 6 — 弹出框

目录 1. Popup组件介绍 2. 使用 上一章内容完成了音乐播放器程序的基本界面框架设计。本小节完成一个简单的功能。单击该播放器顶部菜单栏的“关于”按钮&#xff0c;弹出该程序的相关版本信息。我们将使用Qt Quick的Popup组件来实现。 1. Popup组件介绍 Qt 中的 Popup 组件…

KisFlow-Golang流式实时计算案例(四)-KisFlow在消息队列MQ中的应用

Golang框架实战-KisFlow流式计算框架专栏 Golang框架实战-KisFlow流式计算框架(1)-概述 Golang框架实战-KisFlow流式计算框架(2)-项目构建/基础模块-(上) Golang框架实战-KisFlow流式计算框架(3)-项目构建/基础模块-(下) Golang框架实战-KisFlow流式计算框架(4)-数据流 Golang框…