Matlab多输入单输出之倾斜手写数字识别

news2024/11/22 23:06:38

本文主要介绍使用matlab构建多输入单输出的网络架构,来实现倾斜的手写数字识别,使用concatenationLayer来拼接特征,实现网络输入多个特征。

1.加载训练数据

加载数据:手写数字的图像、真实数字标签和数字顺时针旋转的角度。

load DigitsDataTrain

网络的输入数据类型需要是datastore,使用 arrayDatastore 将三个普通矩阵变为datastore,最后再使用combine合并。

dsX1Train = arrayDatastore(XTrain,IterationDimension=4);
dsX2Train = arrayDatastore(anglesTrain);
dsTTrain = arrayDatastore(labelsTrain);
dsTrain = combine(dsX1Train,dsX2Train,dsTTrain);

显示20个随机训练图像:

numObservationsTrain = numel(labelsTrain);
idx = randperm(numObservationsTrain,20);

figure
tiledlayout("flow");
for i = 1:numel(idx)
    nexttile
    imshow(XTrain(:,:,:,idx(i)))
    title("Angle: " + anglesTrain(idx(i)))
end

图片

2.设计网络架构

设计如下的网络结构:

图片

  • 对于图像输入,指定一个大小与输入数据匹配的图像输入层。

  • 对于特征输入,指定一个大小与输入特征数量匹配的特征输入层。

  • 对于图像输入分支,进行卷积、批归一化和ReLU层块,其中卷积层有16个5×5滤波器。

  • 为了将批归一化层的输出转换为特征向量,需要用一个大小为50的全连接层。

  • 要将第一个全连接层的输出与特征输入连接起来,使用flatten layer将全连接层中的 "SSCB" (空间、空间、通道、批处理)输出展平,使其具有 "CB" 格式。

  • 沿第一维度(channel维度)将平坦层的输出与特征输入连接起来

  • 对于分类输出,包括一个输出大小与类数匹配的全连接层,然后是softmax层。

创建一个空的神经网络:

net = dlnetwork;

创建一个网络主分支,并将其添加到网络中:

[h,w,numChannels,numObservations] = size(XTrain);
numFeatures = 1;
classNames = categories(labelsTrain);
numClasses = numel(classNames);

imageInputSize = [h w numChannels];
filterSize = 5;
numFilters = 16;

layers = [
    imageInputLayer(imageInputSize,Normalization="none")
    convolution2dLayer(filterSize,numFilters)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(50)
    flattenLayer
    concatenationLayer(1,2,Name="cat")
    fullyConnectedLayer(numClasses)
    softmaxLayer];

net = addLayers(net,layers);

将feature input layer添加到网络中,并将其连接到 concatenation layer的第二个输入:

featInput = featureInputLayer(numFeatures,Name="features");
net = addLayers(net,featInput);
net = connectLayers(net,"features","cat/in2");

在绘图中可视化网络:

figure
plot(net)

3.训练网络

使用SGDM优化器进行训练,训练15个epochs,以0.01的学习率进行训练,在图表中显示训练进度并监控accuracy指标,不显示详细输出。

options = trainingOptions("sgdm", ...
    MaxEpochs=15, ...
    InitialLearnRate=0.01, ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=0);

使用 trainnet 函数训练神经网络,对于分类使用交叉熵损失。

net = trainnet(dsTrain,net,"crossentropy",options);

图片

4.测试网络

通过将测试集上的预测与真实标签进行比较来测试网络的分类准确性,加载测试数据:

load DigitsDataTest

使用 minibatchpredict 函数进行预测,并使用 scores2label 函数将分数转换为标签。

scores = minibatchpredict(net,XTest,anglesTest);
YTest = scores2label(scores,classNames);

在混淆图中可视化预测:

figure
confusionchart(labelsTest,YTest)

图片

评估分类准确性:

accuracy = mean(YTest == labelsTest)

accuracy = 0.9878

查看一些预测的图像:

idx = randperm(size(XTest,4),9);
figure
tiledlayout(3,3)
for i = 1:9
    nexttile
    I = XTest(:,:,:,idx(i));
    imshow(I)

    label = string(YTest(idx(i)));
    title("Predicted Label: " + label)
end

图片

5.不用角度特征训练和测试网络

% 网络设计
net_without_feature = dlnetwork;
layers = [
    imageInputLayer(imageInputSize,Normalization="none")
    convolution2dLayer(filterSize,numFilters)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer];

net_without_feature = addLayers(net_without_feature,layers);
% 网络训练
options = trainingOptions("sgdm", ...
    MaxEpochs=15, ...
    InitialLearnRate=0.01, ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=0);

dsTrain_without_feature = combine(dsX1Train,dsTTrain);

net_without_feature = trainnet(dsTrain_without_feature,net_without_feature,"crossentropy",options);

图片

% 在混淆矩阵中可视化预测。
scores = minibatchpredict(net_without_feature,XTest);
YTest = scores2label(scores,classNames);
figure
confusionchart(labelsTest,YTest)

图片

% 评估分类准确性。
accuracy = mean(YTest == labelsTest)

accuracy = 0.9858

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2245603.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Nacos实现IP动态黑白名单过滤

一些恶意用户(可能是黑客、爬虫、DDoS 攻击者)可能频繁请求服务器资源,导致资源占用过高。因此我们需要一定的手段实时阻止可疑或恶意的用户,减少攻击风险。 本次练习使用到的是Nacos配合布隆过滤器实现动态IP黑白名单过滤 文章…

如何在Word文件中设置水印以及如何禁止修改水印

在日常办公和学习中,我们经常需要在Word文档中设置水印,以保护文件的版权或标明文件的机密性。水印可以是文字形式,也可以是图片形式,能够灵活地适应不同的需求。但仅仅设置水印是不够的,有时我们还需要确保水印不被随…

测试工程师如何在面试中脱颖而出

目录 1.平时工作中是怎么去测的? 2.B/S架构和C/S架构区别 3.B/S架构的系统从哪些点去测? 4.你为什么能够做测试这一行?(根据个人情况分析理解) 5.你认为测试的目的是什么? 6.软件测试的流程&#xff…

jenkins的安装(War包安装)

‌Jenkins是一个开源的持续集成工具,基于Java开发,主要用于监控持续的软件版本发布和测试项目。‌ 它提供了一个开放易用的平台,使软件项目能够实现持续集成。Jenkins的功能包括持续的软件版本发布和测试项目,以及监控外部调用执行…

无线感知会议系列【15】DPSense-2

接: 无线感知会议系列【15】DPSense-1 目录: 实验 讨论 结论 附录 一 实验 在本节中,我们通过全面的实验验证了所提出的DPSense系统的有效性。首先,我们将我们的方法与三种最先进的技术进行了比较。然后&#xff0c…

AI编程入门指南002:API、数据库和应用部署

进阶概念教程:API、数据库和应用部署 在学习了编程的基础概念后,我们将进入更高级的内容。本文将详细介绍API、数据库和应用部署三个进阶概念,并通过丰富的示例和形象的说明帮助你更好地理解这些内容。 1. API(应用程序接口&#…

Docker3:docker基础1

欢迎来到“雪碧聊技术”CSDN博客! 在这里,您将踏入一个专注于Java开发技术的知识殿堂。无论您是Java编程的初学者,还是具有一定经验的开发者,相信我的博客都能为您提供宝贵的学习资源和实用技巧。作为您的技术向导,我将…

《Django 5 By Example》阅读笔记:p645-p650

《Django 5 By Example》学习第8天,p645-p650总结,总计6页。 一、技术总结 1.django-rest-framework (1)serializer p648, Serializer: Provides serialization for normal Python class instances。Serializer又细分为Serializer, ModelSerializer,…

【机器学习】回归模型(线性回归+逻辑回归)原理详解

线性回归 Linear Regression 1 概述 线性回归类似高中的线性规划题目。线性回归要做的是就是找到一个数学公式能相对较完美地把所有自变量组合(加减乘除)起来,得到的结果和目标接近。 线性回归分为一元线性回归和多元线性回归。 2 一元线…

【大模型推理】vLLM 源码学习

强烈推荐 https://zhuanlan.zhihu.com/p/680153425 sequnceGroup 存储了相同的prompt对应的不同的sequence, 所以用字典存储 同一个Sequence可能占据多个逻辑Block, 所以在Sequence 中用列表存储 同一个block 要维护tokens_id 列表, 需要添加操作。 还需要判断blo…

FIFO和LRU算法实现操作系统中主存管理

FIFO&#xff0c;用数组实现 1和2都是使用nextReplace实现新页面位置的更新 1、不精确时间&#xff1a;用ctime输出运行时间都是0.00秒 #include <iostream> #include <iomanip> #include<ctime>//用于计算时间 using namespace std;// 页访问顺序 int pa…

【Ubuntu24.04】VirtualBox安装ubuntu-live-server24.04

目录 0 背景1 下载镜像2 安装虚拟机3 安装UbuntuServer24.044 配置基本环境5 总结0 背景 有了远程连接工具之后,似乎作为服务器的Ubuntu24.04桌面版有点备受冷落了,桌面版的Ubuntu24.04的优势是图形化桌面,是作为一个日常工作的系统来用的,就像Windows,如果要作为服务器来…

《SpringBoot、Vue 组装exe与套壳保姆级教学》

&#x1f4e2; 大家好&#xff0c;我是 【战神刘玉栋】&#xff0c;有10多年的研发经验&#xff0c;致力于前后端技术栈的知识沉淀和传播。 &#x1f497; &#x1f33b; CSDN入驻不久&#xff0c;希望大家多多支持&#xff0c;后续会继续提升文章质量&#xff0c;绝不滥竽充数…

Flowable第一篇、快速上手(Flowable安装、配置、集成)

目录 Flowable 概述Flowable的安装与配置 2.1. FlowableUI安装 2.2. Flowable BPMN插件下载 2.3 集成Spring Boot流程审核操作 3.3 简单流程部署 3.4 启动流程实例 3.5 流程审批 一、Flowable 概述 Flowable是一个轻量级、高效可扩展的工作流和业务流程管理&#xff08;BPM&…

Docker搭建有UI的私有镜像仓库

Docker搭建有UI的私有镜像仓库 一、使用这个docker-compose.yml文件&#xff1a; version: 3services:registry-ui:image: joxit/docker-registry-ui:2.5.7-debianrestart: alwaysports:- 81:80environment:- SINGLE_REGISTRYtrue- REGISTRY_TITLEAtt Docker Registry UI- DE…

容器安全检测和渗透测试工具

《Java代码审计》http://mp.weixin.qq.com/s?__bizMzkwNjY1Mzc0Nw&mid2247484219&idx1&sn73564e316a4c9794019f15dd6b3ba9f6&chksmc0e47a67f793f371e9f6a4fbc06e7929cb1480b7320fae34c32563307df3a28aca49d1a4addd&scene21#wechat_redirect Docker-bench-…

Day10_CSS过度动画

Day10_CSS过度动画 背景 : PC和APP项目我们已经开发完毕, 但是再真正开发的时候有些有些简易的动态效果我们可以使用CSS完成 ; 本节课我们来使用CSS完成基础的动画效果 今日学习目标 CSS3过度CSS3平面动态效果CSS3动画效果案例 1. CSS3过渡 ​ 含义 :过渡指的是元素从一种…

iOS应用网络安全之HTTPS

移动互联网开发中iOS应用的网络安全问题往往被大部分开发者忽略, iOS9和OS X 10.11开始Apple也默认提高了安全配置和要求. 本文以iOS平台App开发中对后台数据接口的安全通信进行解析和加固方法的分析. 1. HTTPS/SSL的基本原理 安全套接字层 (Secure Socket Layer, SSL) 是用来…

excel版数独游戏(已完成)

前段时间一个朋友帮那小孩解数独游戏&#xff0c;让我帮解&#xff0c;我看他用电子表格做&#xff0c;只能显示&#xff0c;不能显示重复&#xff0c;也没有协助解题功能&#xff0c;于是我说帮你做个电子表格版的“解题助手”吧&#xff0c;不能直接解题&#xff0c;但该有的…

金融数据中心容灾“大咖说” | 美创科技赋能“灾备一体化”建设

中国人民银行发布的《金融数据中心容灾建设指引》&#xff08;JR/T 0264—2024&#xff09;已于2024年7月29日正式实施。这一金融行业标准对金融数据中心容灾建设中的“组织保障、需求分析、体系规划、建设要求、运维管理”进行了规范和指导。面对不断增加的各类网络、业务、应…