[matlab]MATLAB实现MLP多层感知机minist手写识别预测

news2024/9/20 10:31:30

【测试环境】

matlab2023a

【源码文件截图】

【实现部分代码】

mlp_test.m




%% MLP 2-layer to test XOR
clear;
clc;

Mode = 'MNIST'
%Mode = 'XOR'

if (strcmp(Mode,'MNIST'))
    % Load the digits into workspace (MNIST Test, from
    % http://yann.lecun.com/exdb/mnist/)
    num_train = 1000;
    [train_IMG,train_labels,test_IMG,test_labels] = readMNIST(num_train);
    input =cell(num_train,1);
    output =cell(num_train,1);
    
    test_input=cell(length(test_IMG),1);
    test_output=cell(length(test_labels),1);
    
    for i=1:num_train
        %input_img = double(train_IMG{i});
        %Pre processing - prewitt
        input_img = edge(train_IMG{i},'prewitt');
        
             
        [width height] = size(input_img);
        img_vec = reshape(input_img,1,width*height);
        input{i}=double(img_vec);
        
        labels_arr = zeros(1,10);
        labels_arr(train_labels(i)+1)=1;
        output{i} = labels_arr;
        
    end
    
    for i=1:length(test_input)
        %input_img = double(test_IMG{i});
        %Pre processing - prewitt
        input_img = edge(test_IMG{i},'prewitt');
        
        
        
        [width height] = size(input_img);
        img_vec = reshape(input_img,1,width*height);
        test_input{i} = double(img_vec);
        
        labels_arr = zeros(1,10);
        labels_arr(test_labels(i)+1)=1;
        test_output{i} = labels_arr;
        
        
        
    end
elseif (strcmp(Mode,'XOR'))
    
    num_train = 4;
    input =cell(num_train,1);
    output =cell(num_train,1);
    input{1} = [0 0];
    input{2} = [0 1];
    input{3} = [1 0];
    input{4} = [1 1];
    output{1} = [0];
    output{2} = [1];
    output{3} = [1];
    output{4} = [0];
    
end

if (length(input)~=length(output))
    error('len_input does not equal to len_output');
end



%% Determine # of nodes in hidden layer & output layer
num_node_il = length(input{1});
%num_node_hl = [num_node_il*2];
num_node_hl = [num_node_il];
num_node_ol = length(output{1});

set_node =[num_node_il num_node_hl num_node_ol];

%% Init. template (random)
rand('state',sum(100*clock));

num_layer = length(set_node);
W=cell(num_layer-1,1);
B=cell(num_layer-1);

for i=1:num_layer-1

	%% [Xavier10] shows that the interval ~ from https://deeplearning.net/tutorial/mlp.html
		
    min_W = -4*sqrt(6/(set_node(i)+set_node(i+1)));
    max_W = 4*sqrt(6/(set_node(i)+set_node(i+1)));
	
	W{i} = min_W+(2*max_W).*rand(set_node(i),set_node(i+1));
    B{i} = rand(1,set_node(i+1));
end

%% Learning coeff = 0.7 & Iteration = 10

% 141108, Success rate = 0.725
%lrn_rate = 0.3;
%max_iter = 100;


lrn_rate = 0.3;
max_iter = 50;%最大迭代次数,越大训练时间越长精度越高



tic


Act=cell(num_layer,1);
Err=cell(num_layer-1,1);


err_trace=[];




for index_inter= 1:max_iter
    
    if mod(index_inter,50) ==0
        index_inter
    end
       
    Act_trace=[];
    Train_trace = [];
    for j= 1:num_train
        
        P = randperm(num_train);
        train_input = input{P(j)};
        train_output = output{P(j)};
        
        
        % Forward Propagation
        [Act]   		=   FP(train_input,Act,W,B,num_layer);
        % Backward Propagation & Template update
        [W,B,Err]       =   BP(train_output,Act,W,B,num_layer,lrn_rate,Err);
        
        % Debug
        
        
        
        [row,col]=find(Act{end}==max(Act{end}));
        Act_trace(end+1)=col-1;
        
        [row,col]=find(train_output==max(train_output));
        Train_trace(end+1)=col-1;
        
    end   

    All_arr(index_inter).act = Act_trace;
    All_arr(index_inter).train=Train_trace;
    All_arr(index_inter).err = Act_trace-Train_trace;
end

toc

save
disp('Training Ends')


if (strcmp(Mode,'XOR'))
    grid = [0:0.01:1];
    Z=-1*ones(length(grid),length(grid));
    for i=1:length(grid)
        for j=1:length(grid)
            test = [grid(i) grid(j)];
            Act_new = FP(test,Act,W,B,num_layer);
            Z(i,j) = Act_new{3};
            
        end
    end
	[X,Y] = meshgrid(grid);
	mesh(X,Y,Z)

elseif (strcmp(Mode,'MNIST'))
    Guess_arr = [];
    for i=1:length(test_input)
        
        [guess_result] = FP(test_input{i},Act,W,B,num_layer);
        
        [row,col]=find(guess_result{end}==max(guess_result{end}));
        Guess_arr(end+1)=col-1;
        
    end
end

Z=zeros(10,10);

for i =1:length(test_labels)
    Z(Guess_arr(i)+1,test_labels(i)+1)=Z(Guess_arr(i)+1,test_labels(i)+1)+1;
end



Abs_err = Guess_arr-double(test_labels)';
success_rate = sum(Guess_arr-double(test_labels)'==0)/1000
%plot(abs(Guess_arr-double(test_labels)'))

%figure(1);scatter((Guess_arr*10),test_labels)
%figure(2);plot(err_trace);

 运行结果:

【源码下载地址】 https://download.csdn.net/download/FL1623863129/88600486

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2072916.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

el-date-picker 设置值输出格式

el-date-picker 设置值输出格式 现象 在请求后端的时候因为日期格式不对导致后端请求报错 看到时间默认的格式为:2024-08-13T16:00:00.000Z 这个时间如果需要转换成时间格式还是比较费劲的 解决方案 方式1-对字符串进行处理 formatDate(date) {// 格式化为 YY…

Java:常用API:Math类,System类

文章目录 Math常用方法代码 System类常用方法代码 RunTime类常用方法代码 黑马学习笔记 alt回车抛出异常 Math 常用方法 这是static方法,直接Math打点调方法 代码 package com.zhang.math;/*** Author: ggdpzhk* CreateTime: 2024-08-25* Math工具类的基本用法…

build.grade.kts 如何定义插件及插件扩展

定义插件和应用插件 在build.gradle.kts文件内 这里要注意的是&#xff0c;最后一行的Project扩展函数名必须要和上面apply方法里面create的参数一致&#xff0c;然后project扩展函数定义之前必须先apply<>()也就是先使用apply让plugin apply方法运行起来&#xff0c;才…

C++函数调用栈从何而来

竹杖芒鞋轻胜马,谁怕?一蓑烟雨任平生~ 个人主页&#xff1a; rainInSunny | 个人专栏&#xff1a; C那些事儿、 Qt那些事儿 文章目录 写在前面原理综述x86架构函数调用栈分析如何获取rbp寄存器的值总结 写在前面 程序员对函数调用栈是再熟悉不过了&#xff0c;无论是使用IDE…

printk的原理及使用

内核驱动调试的方法&#xff0c;先从我最常用的printk的使用方法开始讲起, printk在内核源码中用来记录日志信息的函数&#xff0c;方便我们调试追踪代码&#xff0c;只能在内核源码范围内使用。 本篇内核采用5.10版本。 很多内核开发者最喜欢的调试工具之一是printk(),printk(…

分享一个基于python新闻订阅与分享平台flask新闻发布系统(源码、调试、LW、开题、PPT)

&#x1f495;&#x1f495;作者&#xff1a;计算机源码社 &#x1f495;&#x1f495;个人简介&#xff1a;本人 八年开发经验&#xff0c;擅长Java、Python、PHP、.NET、Node.js、Android、微信小程序、爬虫、大数据、机器学习等&#xff0c;大家有这一块的问题可以一起交流&…

【目标检测】AGMF-Net:遥感目标检测的无注意力全局多尺度融合网络

《Attention-Free Global Multiscale Fusion Network for Remote Sensing Object Detection》 遥感目标检测的无注意力全局多尺度融合网络 原文&#xff1a;https://ieeexplore.ieee.org/document/10371366 摘要 遥感目标检测&#xff08;RSOD&#xff09;在复杂背景和小目标…

设计模式篇(DesignPattern - 前置知识 七大原则)(持续更新调整)

目录 前置知识 一、什么是设计模式 二、设计模式的目的 七大原则 原则一&#xff1a;单一职责原则 一、案例一&#xff1a;交通工具问题 1. 问题分析 2. 解决思路 2.1 类级别单一职责 2.2 方法级别单一职责 3. 知识小结 二、案例二&#xff1a;待更新 原则二&…

本·阿弗莱克在与詹妮弗·洛佩兹离婚期间与孩子塞拉菲娜共度时光

在詹妮弗洛佩兹提出离婚申请期间&#xff0c;本阿弗莱克被发现与塞拉菲娜阿弗莱克一起在加州观看电影。 本阿弗莱克似乎将重心放在家庭时间上&#xff0c;最近有人拍到他带着孩子塞拉菲娜阿弗莱克在一起。此前&#xff0c;他的妻子詹妮弗洛佩兹 于 8 月 20 日星期二提出离婚。 …

小黄鸟九宫格切图丨教你如何将图片九宫格切图_照片分割成9张工具

图片九宫格怎么弄&#xff1f;怎么把1张图片切割称九宫图&#xff1f;如何将一张照片切成九宫格 微博九宫图怎么做&#xff1f;你还不知道电脑上如何做微博九宫格图片? 今天用小黄鸟九宫格切割工具&#xff0c;手把手教你,搞定九宫格切图 小黄九宫格切图丨小黄鸟教你如何九宫…

如何使用ssm实现基于web的药品管理系统+vue

TOC ssm175基于web的药品管理系统vue 第1章 绪论 1.1 课题背景 互联网发展至今&#xff0c;无论是其理论还是技术都已经成熟&#xff0c;而且它广泛参与在社会中的方方面面。它让信息都可以通过网络传播&#xff0c;搭配信息管理工具可以很好地为人们提供服务。所以各行业&…

五、Centos7-安装Jenkins--这篇废了

克隆了一个base的虚拟机&#xff0c;用来安装Jenkins 2023年11月&#xff0c;Jenkins不支持centos7了。我们只是学习用&#xff0c;先看看吧。 &#xff08; 另一个人用别的操作系统安装的jenkins&#xff0c;可以参考 版权声明&#xff1a;本文为博主原创文章&#xff0c;…

js第五天-对象

object let obj {uname: pink,age: 18,gender: w} 增 对象名.属性新值 这个和cpp不一样&#xff0c;可以在大括号外面新增属性 <script>let obj {uname: pink,age: 18,gender: w}obj.hobby footballconsole.log(obj);</script>删 delete delete obj.gender …

Spring Boot整合MyBatis-Plus的详细讲解

MyBatis Plus&#xff08;简称MP&#xff09;是一个在MyBatis基础上进行增强的工具&#xff0c;它保留了MyBatis的所有特性&#xff0c;并通过提供额外的功能和简化操作来提高开发效率。以下是对MyBatis Plus的详细介绍&#xff1a; 一、基本概述 定义&#xff1a;MyBatis Plu…

【MATLAB学习笔记】绘图——设置次刻度线的数量、设置刻度线的宽度(粗细)和长度

目录 前言设置次刻度线数量函数示例基本绘图设置次刻度线数量函数的使用 设置刻度线的长度设置刻度线和轴线的宽度总代码总结 前言 在MATLAB中&#xff0c;将XMinorTicktrue或者YMinorTicktrue设置为true可以很方便地设置X轴或者Y轴次刻度线&#xff0c;但是次刻度线的数量是MA…

代码随想录DAY25 - 回溯算法 - 08/24

目录 非递减子序列 题干 思路和代码 递归法 递归优化 全排列 题干 思路和代码 递归法 全排列Ⅱ 题干 思路和代码 方法一&#xff1a;用集合 set 去重 方法二&#xff1a;先排序&#xff0c;再用数组去重 非递减子序列 题干 题目&#xff1a;给你一个整数数组 nu…

python动画:manim中的目标位置移动,线条末端和两条线相切的位置处理

一&#xff0c;Manim中目标的位置移动 在 Manim 中&#xff0c;shift 函数用于在三维空间或二维平面上对对象进行平移。通过 shift 方法&#xff0c;用户可以快速移动场景中的物体&#xff0c;指定移动的方向和距离。方向通常由预定义的常量&#xff08;如 UP, DOWN, LEFT, RI…

opencv-python图像增强十五:高级滤镜实现

文章目录 前言二、鲜食滤镜三、巧克力滤镜三&#xff0c;冷艳滤镜&#xff1a; 前言 在之前两个滤镜文章中介绍了六种简单的滤镜实现&#xff0c;它们大多都是由一个单独函数实现的接下来介绍五种结合了之前图像增强文章提的的算法的复合滤镜。本案例中的算法来自于文章一&…

【数学建模】TOPSIS法(优劣解距离法)

TOPSIS法&#xff08;Technique for Order Preference by Similarity to Ideal Solution&#xff0c;优劣解距离法&#xff09;是一种多准则决策分析方法&#xff0c;它基于这样一个概念&#xff1a;最理想的方案应该是距离理想解最近而距离负理想解最远的方案。以下是使用TOPS…

【React原理 - 任务调度和时间分片详解】

概述 在React15的时候&#xff0c;React使用的是从根节点往下递归的方式同步创建虚拟Dom&#xff0c;由于递归具有同步不可中断的特性&#xff0c;所以当执行长任务时(通常以60帧为标准&#xff0c;即16.6ms)就会长时间占用主线程长时间无响应&#xff0c;导致页面卡顿&#x…