1.30、基于卷积神经网络的手写数字旋转角度预测(matlab)

news2024/11/26 9:47:12

1、卷积神经网络的手写数字旋转角度预测原理及流程

基于卷积神经网络的手写数字旋转角度预测是一个常见的计算机视觉问题。在这种情况下,我们可以通过构建一个卷积神经网络(Convolutional Neural Network,CNN)来实现该任务。以下是基于MATLAB的手写数字旋转角度预测的原理和流程:

原理:

  1. 数据准备:首先,准备一个包含手写数字图像和其对应标签(即旋转角度)的数据集。这些图像可以是MNIST数据集的手写数字。

  2. 模型建立:构建一个CNN模型,包括卷积层、池化层、全连接层等,来学习手写数字图像的特征并预测它们的旋转角度。

  3. 训练模型:利用准备好的训练数据集对CNN模型进行训练,通过反向传播算法来调整模型参数以最小化预测与真实标签之间的误差。

  4. 模型评估:使用测试数据集对训练好的模型进行评估,计算模型的准确率或其他性能指标,以评估其在预测手写数字旋转角度方面的性能。

流程:

  1. 加载数据集:在MATLAB中加载手写数字图像数据集,并对图像进行预处理和标签处理,以便输入到CNN模型中。

  2. 构建CNN模型:使用MATLAB深度学习工具箱中的函数(如convolution2dLayermaxPooling2dLayerfullyConnectedLayerclassificationLayer)构建一个适合手写数字旋转角度预测的CNN模型。

  3. 定义训练选项:设置训练选项,包括优化器类型、学习率、最大训练轮数等。

  4. 训练模型:使用训练数据集对CNN模型进行训练,通过调用trainNetwork函数并传入训练数据和训练选项来完成训练过程。

  5. 评估模型:使用测试数据集对训练好的模型进行评估,计算准确率等性能指标。

  6. 预测手写数字的旋转角度:最后,使用训练好的模型对新的手写数字图像进行预测,得到其旋转角度的预测结果。

这是基于卷积神经网络的手写数字旋转角度预测的基本原理和流程。

2、卷积神经网络的手写数字旋转角度预测案例说明

1)解决问题

卷积神经网络来预测手写数字的旋转角度

2)技术方案

回归任务涉及预测连续数值而不是离散类标签,回归构造卷积神经网络架构,训练网络,并使用经过训练的网络来预测旋转手写数字的角度。

3、加载数据

1)数据说明

数据集包含手写数字的合成图像以及每个图像的旋转角度(以度为单位)。

2)加载数据代码

说明:变量 anglesTrain 和 anglesTest 是以度为单位的旋转角度。训练数据集和测试数据集各包含 5000 个图像。

load DigitsDataTrain
load DigitsDataTest

3)显示训练集代码

numObservations = size(XTrain,4);
idx = randperm(numObservations,49);
I = imtile(XTrain(:,:,:,idx));
figure
imshow(I);

 视图效果

06f6163700a5438b802835e39b5c0504.png

4)数据集划分代码

说明:使用 trainingPartitions 函数将 XTrain 和 anglesTrain 分区为训练分区和验证分区,留出 15% 的训练数据用于验证。

[idxTrain,idxValidation] = trainingPartitions(numObservations,[0.85 0.15]);

XValidation = XTrain(:,:,:,idxValidation);
anglesValidaiton = anglesTrain(idxValidation);

XTrain = XTrain(:,:,:,idxTrain);
anglesTrain = anglesTrain(idxTrain);

4、检查数据归一化

1)归一化说明

训练神经网络时,确保数据在网络的所有阶段均归一化

对于使用梯度下降的网络训练,归一化有助于训练的稳定和加速.

数据比例不佳,则损失可能会变为 NaN,并且网络参数在训练过程中可能发生偏离

归一化数据的常用方法包括重新缩放数据,使其范围变为 [0,1],或使其均值为 0 且标准差为 1

2)绘制响应的分布代码

说明:响应(以度为单位的旋转角度)大致均匀地分布在 -45 和 45 之间,效果很好,无需归一化。

figure
histogram(anglesTrain)
axis tight
ylabel("Counts")
xlabel("Rotation Angle")

视图效果 

b293b09f497e445ea96b6b4fd31045df.png

5、定义神经网络架构

1)神经网络架构说明

对于图像输入,指定一个图像输入层。

指定四个 convolution-batchnorm-ReLU 模块,并增加滤波器数量。

在每个模块之间指定一个具有池化区域的平均池化层,步幅大小为 2。

在网络末尾,包含一个全连接层,其输出大小与响应数量匹配。

2)神经网络架构代码

numResponses = 1;

layers = [
    imageInputLayer([28 28 1])
    convolution2dLayer(3,8,Padding="same")
    batchNormalizationLayer
    reluLayer
    averagePooling2dLayer(2,Stride=2)
    convolution2dLayer(3,16,Padding="same")
    batchNormalizationLayer
    reluLayer
    averagePooling2dLayer(2,Stride=2)
    convolution2dLayer(3,32,Padding="same")
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,32,Padding="same")
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(numResponses)];

6、指定训练选项

1)指定训练选项说明

使用Experiment Manager。

将初始学习率设置为 0.001,并在 20 轮训练后降低学习率。

通过指定验证数据和验证频率,监控训练过程中的网络准确度。软件基于训练数据训练网络,并在训练过程中按固定时间间隔计算基于验证数据的准确度。验证数据不用于更新网络权重。

在图中显示训练进度并监控均方根误差。

2)指定训练选项代码

miniBatchSize  = 128;
validationFrequency = floor(numel(anglesTrain)/miniBatchSize);

options = trainingOptions("sgdm", ...
    MiniBatchSize=miniBatchSize, ...
    InitialLearnRate=1e-3, ...
    LearnRateSchedule="piecewise", ...
    LearnRateDropFactor=0.1, ...
    LearnRateDropPeriod=20, ...
    Shuffle="every-epoch", ...
    ValidationData={XTest,anglesTest}, ...
    ValidationFrequency=validationFrequency, ...
    Plots="training-progress", ...
    Metrics="rmse", ...
    Verbose=false);

7、训练神经网络

1)训练神经网络说明

使用 trainnet 函数训练神经网络。

对于回归,请使用均方误差损失。默认情况下,trainnet 函数使用 GPU(如果有)。使用 GPU 需要 Parallel Computing Toolbox™ 许可证和受支持的 GPU 设备。要指定执行环境,请使用 ExecutionEnvironment 训练选项。

2)训练神经网络代码

net = trainnet(XTrain,anglesTrain,layers,"mse",options);

视图效果

 eaa6e5e1078e4fd8bd4d9c5fe0cf06f0.png

8、测试网络

1)测试网络说明

基于测试数据评估准确度来测试网络性能。

使用 minibatchpredict 函数进行预测。默认情况下,minibatchpredict 函数使用 GPU(如果有)。

2)测试网络代码

YTest = minibatchpredict(net,XTest);

3)计算均方根误差 (RMSE) 以衡量预测旋转角度和实际旋转角度之间的差异 

predictionError = anglesTest - YTest;
squares = predictionError.^2;
rmse = sqrt(mean(squares))

 4)散点图中可视化预测。绘制预测值对真实值的图。

figure
scatter(YTest,anglesTest,"+")
xlabel("Predicted Value")
ylabel("True Value")

hold on
plot([-60 60], [-60 60],"y--")

视图效果 

a35e3dd73632423a92451bfbd0ae7b66.png

9、使用新数据进行预测

1)测试说明

使用 predict 函数并使用神经网络对第一个测试图像进行预测

2)测试代码

X = XTest(:,:,:,1);
if canUseGPU
    X = gpuArray(X);
end
Y = predict(net,X)

10、总结

基于卷积神经网络的手写数字旋转角度预测是一个常见的计算机视觉问题,通过使用MATLAB深度学习工具箱可以比较方便地实现。下面是对这一任务的总结:

总结要点:

  1. 数据准备:准备包含手写数字图像和对应旋转角度标签的数据集,如MNIST数据集。

  2. 模型建立:构建卷积神经网络(CNN)模型,通过卷积层、池化层、全连接层等结构来学习手写数字图像的特征和预测旋转角度。

  3. 训练模型:使用训练数据集对CNN模型进行训练,通过反向传播算法来调整模型参数,最小化预测与真实标签的误差。

  4. 模型评估:使用测试数据集对训练好的模型进行评估,计算准确率或其他性能指标,评定模型在预测旋转角度上的性能。

实现流程:

  1. 数据加载和预处理:加载手写数字图像数据集,对图像进行预处理(如缩放、归一化)并提取对应的旋转角度标签。

  2. CNN模型构建:使用MATLAB深度学习工具箱中的函数构建CNN模型,包括卷积层、池化层、全连接层,并适当选择激活函数。

  3. 训练模型:定义训练选项,选择优化器和学习率等参数,使用训练数据集对CNN模型进行训练。

  4. 模型评估:使用测试数据集对训练好的模型进行评估,检验其在预测手写数字旋转角度的准确性。

  5. 预测和应用:最后,使用训练好的模型对新的手写数字图像进行预测,实现手写数字旋转角度的自动识别和预测。

通过以上流程和总结,您可以利用MATLAB深度学习工具箱来实现基于卷积神经网络的手写数字旋转角度预测任务。

11、源代码

代码

%% 基于卷积神经网络的手写数字旋转角度预测
%卷积神经网络来预测手写数字的旋转角度
%回归任务涉及预测连续数值而不是离散类标签
%回归构造卷积神经网络架构,训练网络,并使用经过训练的网络来预测旋转手写数字的角度。

%% 加载数据
%数据集包含手写数字的合成图像以及每个图像的旋转角度(以度为单位)。
%变量 anglesTrain 和 anglesTest 是以度为单位的旋转角度。训练数据集和测试数据集各包含 5000 个图像。

load DigitsDataTrain
load DigitsDataTest

%显示训练集
numObservations = size(XTrain,4);
idx = randperm(numObservations,49);
I = imtile(XTrain(:,:,:,idx));
figure
imshow(I);

%数据集划分
%使用 trainingPartitions 函数将 XTrain 和 anglesTrain 分区为训练分区和验证分区,留出 15% 的训练数据用于验证。
[idxTrain,idxValidation] = trainingPartitions(numObservations,[0.85 0.15]);

XValidation = XTrain(:,:,:,idxValidation);
anglesValidaiton = anglesTrain(idxValidation);

XTrain = XTrain(:,:,:,idxTrain);
anglesTrain = anglesTrain(idxTrain);

%% 检查数据归一化
%训练神经网络时,确保数据在网络的所有阶段均归一化。
%对于使用梯度下降的网络训练,归一化有助于训练的稳定和加速.
%数据比例不佳,则损失可能会变为 NaN,并且网络参数在训练过程中可能发生偏离
%归一化数据的常用方法包括重新缩放数据,使其范围变为 [0,1],或使其均值为 0 且标准差为 1

%绘制响应的分布。
% 响应(以度为单位的旋转角度)大致均匀地分布在 -45 和 45 之间,效果很好,无需归一化。
figure
histogram(anglesTrain)
axis tight
ylabel("Counts")
xlabel("Rotation Angle")

%%  定义神经网络架构
%对于图像输入,指定一个图像输入层。
%指定四个 convolution-batchnorm-ReLU 模块,并增加滤波器数量。
%在每个模块之间指定一个具有池化区域的平均池化层,步幅大小为 2。
%在网络末尾,包含一个全连接层,其输出大小与响应数量匹配。
numResponses = 1;

layers = [
    imageInputLayer([28 28 1])
    convolution2dLayer(3,8,Padding="same")
    batchNormalizationLayer
    reluLayer
    averagePooling2dLayer(2,Stride=2)
    convolution2dLayer(3,16,Padding="same")
    batchNormalizationLayer
    reluLayer
    averagePooling2dLayer(2,Stride=2)
    convolution2dLayer(3,32,Padding="same")
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,32,Padding="same")
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(numResponses)];
%% 指定训练选项
%使用Experiment Manager。
%将初始学习率设置为 0.001,并在 20 轮训练后降低学习率。
%通过指定验证数据和验证频率,监控训练过程中的网络准确度。软件基于训练数据训练网络,并在训练过程中按固定时间间隔计算基于验证数据的准确度。验证数据不用于更新网络权重。
%在图中显示训练进度并监控均方根误差。

miniBatchSize  = 128;
validationFrequency = floor(numel(anglesTrain)/miniBatchSize);

options = trainingOptions("sgdm", ...
    MiniBatchSize=miniBatchSize, ...
    InitialLearnRate=1e-3, ...
    LearnRateSchedule="piecewise", ...
    LearnRateDropFactor=0.1, ...
    LearnRateDropPeriod=20, ...
    Shuffle="every-epoch", ...
    ValidationData={XTest,anglesTest}, ...
    ValidationFrequency=validationFrequency, ...
    Plots="training-progress", ...
    Metrics="rmse", ...
    Verbose=false);
%% 训练神经网络
%使用 trainnet 函数训练神经网络。
%对于回归,请使用均方误差损失。默认情况下,trainnet 函数使用 GPU(如果有)。使用 GPU 需要 Parallel Computing Toolbox™ 许可证和受支持的 GPU 设备。要指定执行环境,请使用 ExecutionEnvironment 训练选项。
net = trainnet(XTrain,anglesTrain,layers,"mse",options);
%% 测试网络
%基于测试数据评估准确度来测试网络性能。
%使用 minibatchpredict 函数进行预测。默认情况下,minibatchpredict 函数使用 GPU(如果有)。
YTest = minibatchpredict(net,XTest);
%计算均方根误差 (RMSE) 以衡量预测旋转角度和实际旋转角度之间的差异。
predictionError = anglesTest - YTest;
squares = predictionError.^2;
rmse = sqrt(mean(squares))
%散点图中可视化预测。绘制预测值对真实值的图。
figure
scatter(YTest,anglesTest,"+")
xlabel("Predicted Value")
ylabel("True Value")

hold on
plot([-60 60], [-60 60],"y--")

%% 使用新数据进行预测
%使用 predict 函数并使用神经网络对第一个测试图像进行预测
X = XTest(:,:,:,1);
if canUseGPU
    X = gpuArray(X);
end
Y = predict(net,X)

工程文件

https://download.csdn.net/download/XU157303764/89494539

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1941372.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

操作线程的方法

文章目录 前言一、线程的生命周期二、线程的操作方法 1.休眠2.加入3.中断4.礼让总结 前言 将线程看作一个生命的开始和结束,更好理解它各个状态的变化。同时该文会介绍操作线程的主要方法来控制线程的生命周期。这些方法的使用和线程生命周期的变化是密切相关的。 一…

甄选范文“论面向方面的编程技术及其应”,软考高级论文,系统架构设计师论文

论文真题 针对应用开发所面临的规模不断扩大、复杂度不断提升的问题,面向方面的编程(Aspect Oriented Programming,AOP)技术提供了一种有效的程序开发方法。为了理解和完成一个复杂的程序,通常要把程序进行功能划分和封装。一般系统中的某些通用功能,如安全性、持续性、日…

Intellij IDEA 的Plugins加载不出来的解决方法

一、点开插件---右上角设置---HTTP代理设置 二、勾选自动检测代理设置 输入url: https://plugins.jetbrains.com/ 配置完成后,点击确定。 然后点击检查连接,再一次输入那个URL,一般来说可以连接成功了 然后 重启IDEA以刷新缓…

详解数据结构之二叉树(堆)

详解数据结构之二叉树(堆) 树 树的概念 树是一个非线性结构的数据结构,它是由 n(n>0)个有限节点组成的一个具有层次关系的集合,它的外观形似一颗倒挂着的树,根朝上,叶朝下,所以称呼为树。每颗子树的根节点有且只…

7. 聚类算法 KMeans

聚类算法 KMeans 1. 应用:大数据杀熟2. 迭代法3. 代码 1. 应用:大数据杀熟 618、双十一,平台要对用户进行分类:用户: 脑残粉(不降价,或者涨点价)墙头草(给点小优惠券&am…

二叉树基础及实现(一)

目录: 一. 树的基本概念 二. 二叉树概念及特性 三. 二叉树的基本操作 一. 树的基本概念: 1 概念 : 树是一种非线性的数据结构,它是由n(n>0 )个有限结点组成一个具有层次关系的集合。 把它叫做树是因…

数据结构之初始二叉树(4)

找往期文章包括但不限于本期文章中不懂的知识点: 个人主页:我要学编程(ಥ_ಥ)-CSDN博客 所属专栏:数据结构(Java版) 二叉树的基本操作 二叉树的相关刷题(上)通过上篇文章的学习,我们…

基于密钥的身份验证(Linux-Linux)

A主机: 1、生成密钥对 [rootservera ~]# ssh-keygen查看公钥 注:id_rsa为私钥(证书),id_rsa.pub为公钥 2、注册公钥到服务器 [rootservera ~]# ssh-copy-id root172.25.250.106 查看.ssh 3、使用密钥连接服务器 #…

ViT(Vision Transformer)网络结构详解

本文在transformer的基础上对ViT进行讲解,transformer相关部分可以看我另一篇博客(transformer中对于QKV的个人理解-CSDN博客)。 一、网络结构概览 上图展示了Vision Transformer (ViT) 的基本架构,我按照运行顺序分为三个板块进…

配置web服务器

当访问网站www.haha.com时显示:haha;当访问网站www.xixi.com/secret/显示:this is secret 第一步,配置一个新的IP 确认后 esc返回 第二步:重启ens160 第三步:创建目录,并且在文件内写入内容 第…

英福康INFICON UL1000检漏仪介绍PPT

英福康INFICON UL1000检漏仪介绍PPT

【周记】2024暑期集训第二周(未完待续)

文章目录 日常刷题记录合并果子题目解析算法思路代码实现 中位数题目解析算法思路代码实现 C学习笔记队列queue双端队列 deque优先队列 priority_queue定义常见操作 upper_bound 日常刷题记录 合并果子 题目解析 有一堆果子,每次可以将两小堆合并,耗费…

verilog行为建模(四):过程赋值

目录 1.两类过程赋值2.阻塞与非阻塞赋值语句行为差别举例13.阻塞与非阻塞赋值语句行为差别举例24.阻塞与非阻塞赋值语句行为差别举例35.举例4:非阻塞赋值语句中延时在左边和右边的差别 微信公众号获取更多FPGA相关源码: 1.两类过程赋值 阻塞过程赋值执…

漫威争锋Marvel Rivals测试搜不到 漫威争锋Marvel Rivals怎么搜

漫威争锋,一款今年即将上线的6v6的fps游戏,漫威争锋Marvel Rivals一经公布就吸引了广大玩家的兴趣。玩家将在游戏中扮演一名名经典且有趣的漫威英雄,与敌人展开对决。而且该游戏中有着很多的漫威英雄供我们挑选使用,有着很多英雄的…

【数据结构】排序算法——Lessen1

Hi~!这里是奋斗的小羊,很荣幸您能阅读我的文章,诚请评论指点,欢迎欢迎 ~~ 💥💥个人主页:奋斗的小羊 💥💥所属专栏:C语言 🚀本系列文章为个人学习…

音乐播放器的优雅之选,黑金ONIX Overture XM5,更悦耳的音质体验

如今想要随时沉浸式的体验高品质的数字音乐资源,一款简单好用的音乐播放器必不可少,多年来在音乐爱好者的心中的经典品牌屈指可数,英国品牌ONIX算是一个,其Horizon系列以优雅的设计以及出众的品质,收获了很多忠实粉丝。…

OpenAI发布迷你AI模型GPT-4o mini

本心、输入输出、结果 文章目录 OpenAI发布迷你AI模型GPT-4o mini前言OpenAI发布迷你AI模型GPT-4o mini英伟达联合发布 Mistral-NeMo AI 模型:120 亿参数、上下文窗口 12.8 万个 tokenOpenAI发布迷你AI模型GPT-4o mini 编辑 | 简简单单 Online zuozuo 地址 | https://blog.csd…

【ADRC笔记】LESO-Wb

公式推导(bilibili) 一阶ESO 二阶ESO 二阶自抗扰控制器基本原理 选取状态变量 观测器收敛性推导 wo 观测器带宽

C语言·函数(超详细系列·全面总结)

前言:Hello大家好😘,我是心跳sy,为了更好地形成一个学习c语言的体系,最近将会更新关于c语言语法基础的知识,今天更新一下函数的知识点,我们一起来看看吧! 目录 一、函数是什么 &a…

HTTPServer改进思路1

Nginx源码思考项目改进 架构模式 事件驱动架构(EDA)用于处理大量并发连接和IO操作 优点:高效处理大量并发请求,减少线程切换和阻塞调用技术实现:直接使用EPOLL,参考Node.js的http服务器 网络通信 协议:HTT…