python神经网络实现手写数字识别实验

news2024/11/25 8:18:11

     手写数字识别实验是机器学习中最常见的一个示例,可以有很多种办法实现,最基础的其实就是利用knn算法,根据数字图片对应矩阵与经过训练的数字进行距离计算,最后这个距离最短,那么就认为它是哪个数字。

     这里直接通过神经网络的办法来进行手写数字识别实验。不借助其他框架,编写网络,然后进行测试。这个代码其实网上有很多,并不是原创。

     这里有必要说明一下手写数字的数据集,这里采用的是mnist_dataset/mnist_train.csv数据集,数据地址: https://www.kaggle.com/datasets/oddrationale/mnist-in-csv。下载之后是一个压缩包,里面包含mnist_train.csv,mnist_test.csv。

    我们可以看看mnist_train.csv的部分数据:

 

    上图中,①处表示 第一行内容 其实是标题,我们在数据处理的时候需要过滤这一行。② 表示的是label内容,也就是真实数字,它由0-9组成,也就是10个分类。③ 处表示的28 * 28矩阵,这个数字由784个数字组成。 

    实验过程,先使用mnist_train.csv数据训练网络,然后利用我们自己手写的数字进行测试。这里没有使用mnist_test.csv进行测试,主要是它本身就是人家进行测试的数据,我们这里自己测试。

    我自己准备的数字图片如下所示:

    这些图片都是根据这里测试数据mnist_train.csv数据格式的要求进行绘制的28*28像素的图片,这个图片很小,但是可以借助windows系统paint绘图工具,选择28*28像素画布,然后进行放大,最后可以在编辑区域画出这些数字。

     

     下面给出代码:

import os
import numpy as np
import scipy.special
import imageio

image_path = 'number_images'


# 加载图片
def load_img_number(root_dir):
    files = os.listdir(root_dir)
    file_list = []
    for file in files:
        file_path = os.path.join(root_dir, file)
        file_list.append(file_path)
    return file_list


class neuralnetwork:
    def __init__(self, inputnodes, hiddennodes, outputnodes, learningrate):
        # 输入层
        self.inodes = inputnodes
        # 隐藏层
        self.hnodes = hiddennodes
        # 输出层
        self.onodes = outputnodes
        # 学习率
        self.lr = learningrate
        # 输入层-隐藏层权重
        self.wih = (np.random.normal(0.0, pow(self.hnodes, -0.5), (self.hnodes, self.inodes)))
        # 隐藏层-输出层权重
        self.who = (np.random.normal(0.0, pow(self.onodes, -0.5), (self.onodes, self.hnodes)))
        # 激活函数
        self.activation_function = lambda x: scipy.special.expit(x)

    def train(self, inputs_list, targets_list):
        inputs = np.array(inputs_list, ndmin=2).T
        targets = np.array(targets_list, ndmin=2).T
        hidden_inputs = np.dot(self.wih, inputs)
        hidden_outputs = self.activation_function(hidden_inputs)
        final_inputs = np.dot(self.who, hidden_outputs)
        final_outputs = self.activation_function(final_inputs)
        output_errors = targets - final_outputs
        hidden_errors = np.dot(self.who.T, output_errors)
        self.who += self.lr * np.dot((output_errors * final_outputs * (1.0 - final_outputs)),
                                     np.transpose(hidden_outputs))
        self.wih += self.lr * np.dot((hidden_errors * hidden_outputs * (1.0 - hidden_outputs)), np.transpose(inputs))

    def query(self, inputs_list):
        inputs = np.array(inputs_list, ndmin=2).T
        hidden_inputs = np.dot(self.wih, inputs)
        hidden_outputs = self.activation_function(hidden_inputs)
        final_inputs = np.dot(self.who, hidden_outputs)
        final_outputs = self.activation_function(final_inputs)
        return final_outputs


input_nodes = 784
hidden_nodes = 200
output_nodes = 10
learning_rate = 0.2
# 构建模型
model = neuralnetwork(input_nodes, hidden_nodes, output_nodes, learning_rate)
# 准备训练数据
training_data_file = open('mnist/mnist_train.csv', 'r')
training_data_list = training_data_file.readlines()
# 去掉第一行标题
training_data_list = training_data_list[1:]
training_data_file.close()

# 训练
for record in training_data_list:
    all_values = record.split(',')
    inputs = (np.asfarray(all_values[1:]) / 255.0 * 0.99) + 0.01
    targets = np.zeros(output_nodes) + 0.01
    targets[int(all_values[0])] = 0.99
    model.train(inputs, targets)
    pass

img_list = load_img_number(image_path)

for i in range(len(img_list)):
    img_name = img_list[i]
    img_arr = imageio.v2.imread(img_name, mode='L')
    img_data = 255.0 - img_arr.reshape(784)
    inputs = (img_data / 255.0 * 0.99) + 0.01
    outputs = model.query(inputs)
    label = np.argmax(outputs)
    print(f'{img_name} 识别结果是 {label}')

    运行代码,打印结果:

 

    1、识别率很感人,其实很多都识别错误。 

    2、多次运行,结果也不一样。

    3、识别不正确的基本会认为6或者8。不知道怎么会有这种奇怪的结果。

   最后,给出本示例的代码和资源:https://gitee.com/buejee/aitutorial

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/564652.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

项目的延伸

目录 推送模块 1.表 1.1 表字段 1.2 字段类型 1.3 索引 1.4 关联查询 2.参数的含义 3.以技术流的维度讲业务逻辑 4.redis 4.1基础知识 5.设计模式 5.1策略模式 5.2工厂模式 6.遇到的问题 6.1稳定性 7.锁 即时通讯模块 1.表 1.1 表字段 1.2 字段类型 1.3 索…

关于队头阻塞的一些笔记

一、队头阻塞(Head-of-Line Blocking,HOL) 看到队头,联想到了数据结构课程中学到的队列,队列的一个特点就是FIFO(First In First Out),即先进入队列的数据先出队列。所以&#xff0…

【Linux高级 I/O(6)】存储映射 I/O进阶应用(附代码示例)

mprotect()函数 使用系统调用 mprotect()可以更改一个现有映射区的保护要求&#xff0c;其函数原型如下所示&#xff1a; #include <sys/mman.h>int mprotect(void *addr, size_t len, int prot);参数 prot 的取值与 mmap()函数的 prot 参数的一样&#xff0c;m…

Pycharm 配置jupyter notebook 且Windos 安装vim编辑器

请记住要想让你的python成功安装jupyter notebook &#xff0c;你的python最好使用p大于等于python3.7 最好不要在python2大版本中安装jupyternotebook 这个会报错&#xff0c;需要你改一些配置文件&#xff0c;除非你想挑战一下自己&#xff0c;不过后面我会尝试在python2大版…

NeRF-VAE:将场景看作一个分布【ICML‘2021】

文章目录 GQN网络介绍Amortized InferenceNeRF-VAE GQN网络介绍 论文标题&#xff1a;Neural scene representation and rendering 作者&#xff1a;S. M. Ali Eslami, Danilo Jimenez Rezende, et al. 期刊&#xff1a;Science 发表时间&#xff1a;2018/06/15 该文章提出…

单视觉L2市场「鲶鱼」来了,掀起数据反哺高阶新打法

作者 | 张祥威编辑 | 德新 智驾方案的降本行动仍在推进。 早年&#xff0c;单视觉L2市场的玩家以Mobileye、博世为主&#xff0c;后来国内智驾公司加入&#xff0c;共同推动 1V、1R1V、nR1V等不同的方案兴起&#xff0c;L2近乎成为车辆的必备功能。 当下&#xff0c;在行业降低…

认识linux文件系统/文件夹名字解释

linux系统因为其高效、直接的底层操作而被很多代码开发者使用&#xff0c;谈及linux&#xff0c;大家普遍的印象就是黑乎乎的终端命令行&#xff0c;后来基于linux系统开发出来的具有可视化桌面的ubuntu版本&#xff0c;让大家的使用体验兼顾了windows系统的直观性可linux系统代…

YOLOV5使用(一): docker跑通,详解TensorRT下plugin的onnx

yolov5的工程使用(以人员检测为案例) 使用ubuntu为案例 docker run --gpus all -it -p 6007:6006 -p 8889:8888 --name my_torch -v $(pwd):/app easonbob/my_torch1-pytorch:22.03-py3-yolov5-6.0使用端口映射功能也就是说打开jupyter lab的指令是 http://localhost:8889/l…

windows先的conda环境复制到linux环境

如果是迁移的环境一致&#xff1a;同是windows或同是linux直接用这个命令即可&#xff1a; conda create -n new_env_name --clone old_env_path 如果是window的环境迁移到linux这种跨环境就不能用上面的方法&#xff0c;网上这方面的资料也很多&#xff0c;记录一下我的…

小蝌蚪找妈妈:Python之作用域链与 LEGB 原则

文章目录 参考描述作用域对象全局作用域globals() 局部作用域locals() 包含作用域内置作用域builtins 模块builtins 模块与 \_\_builtins__builtins is \_\_builtins__??? \_\_builtins__ 与内置作用域赶不走的 \_\_builtins__ 作用域链作用域链 与 LEGB 原则狗急跳墙之法 参…

【Go语言从入门到实战】基础篇

Go语言从入门到实战 — 基础篇 First Go Program 编译 & 运行 基本程序结构 应用程序入口 package mainimport "fmt"func main() {fmt.Println("Hello World") }退出返回值 package mainimport ("fmt""os" )func main() {fmt.Pr…

哪个产品功能重要?KANO模型帮你

哪个产品功能重要&#xff1f;KANO模型来帮你 模型工具可以协助思考和系统化改进 KANO模型是小日本一个教授提出 趣讲大白话&#xff1a;往往&#xff0c;怎么思考&#xff0c;比思考什么重要 【趣讲信息科技175期】 **************************** 东京理工大学教授狩野纪昭(No…

【医学图像】图像分割系列.2 (diffusion)

介绍几篇使用diffusion来实现医学图像分割的论文&#xff1a;DARL&#xff08;ICLR2023&#xff09;&#xff0c;MedSegDiff&#xff08;MIDL2023&#xff09;& MedSegDiff-V2&#xff08;arXiv2023&#xff09;&#xff0c;ImgX-DiffSeg&#xff08;arXiv2023&#xff09;…

CTF 2015: Search Engine-fastbin_dup_into_stack

参考&#xff1a; [1]https://gsgx.me/posts/9447-ctf-2015-search-engine-writeup/ [2]https://blog.csdn.net/weixin_38419913/article/details/103238963(掌握利用点&#xff0c;省略各种逆向细节) [3]https://bbs.kanxue.com/thread-267876.htm&#xff08;逆向调试详解&am…

web功能测试方法大全—完整!全面!(纯干货,建议收藏哦~)

本文通过六个部分为大家梳理了web功能测试过程中&#xff0c;容易出现的遗漏的部分&#xff0c;用以发掘自己工作中的疏漏。&#xff08;纯干货&#xff0c;建议收藏哦~&#xff09; 一、输入框 1、字符型输入框 2、数值型输入框 3、日期型输入框 4、信息重复 在一些需要命…

GPT-4版Windows炸场,整个系统就是一个对话机器人,微软开建AI全宇宙

原创 智东西编辑部 智东西 Windows的GPT时刻到来&#xff0c;变革PC行业。 作者 | 智东西编辑部 今日凌晨&#xff0c;Windows迎来了GPT-4时刻&#xff01; 在2023微软Build大会上&#xff0c;微软总裁萨蒂亚纳德拉&#xff08;Satya Nadella&#xff09;宣布推出Windows Co…

实现免杀:Shellcode的AES和XOR加密策略(vt查杀率:4/70)

前言 什么是私钥和公钥 私钥和公钥是密码学中用于实现加密、解密和数字签名等功能的关键组件。 私钥是一种加密算法中的秘密密钥&#xff0c;只有密钥的拥有者可以访问和使用它。私钥通常用于数字签名和数据加密等场景中&#xff0c;它可以用于对数据进行加密&#xff0c;同…

头部效应凸显,消金行业迈入“巨头赛”?

回顾已经过去的2022年&#xff0c;消金行业面临着来自多方面的考验&#xff0c;承压前行&#xff0c;而随着进入2023年&#xff0c;相关企业也陆续展示出过去一年的发展成果&#xff0c;以此为后续发展做出指引。 当前&#xff0c;30家已开业的消金公司中&#xff0c;29家的20…

《消息队列高手课》课程笔记(三)

如何利用事务消息实现分布式事务&#xff1f; 什么是分布式事务&#xff1f; 消息队列中的“事务”&#xff0c;主要解决的是消息生产者和消息消费者的数据一致性问题。如果我们需要对若干数据进行更新操作&#xff0c;为了保证这些数据的完整性和一致性&#xff0c;我们希望…

独立站怎么搭建?搭建一个独立站的10个建议和步骤

要搭建一个独立站&#xff08;也称为个人网站或博客&#xff09;&#xff0c;以下是一些建议和步骤&#xff1a; 选择一个合适的域名&#xff1a;选择一个简洁、易记且与您网站内容相关的域名。确保域名可用&#xff0c;并注册该域名。 寻找一个合适的主机服务提供商&#xff…