机器学习11-前馈神经网络识别手写数字1.0

news2024/11/13 15:07:56

在这个示例中,使用的神经网络是一个简单的全连接前馈神经网络,也称为多层感知器(Multilayer Perceptron,MLP)。这个神经网络由几个关键组件构成:

1. 输入层
输入层接收输入数据,这里是一个 28x28 的灰度图像,每个像素值表示图像中的亮度值。

2. Flatten 层
Flatten 层用于将输入数据展平为一维向量,以便传递给后续的全连接层。在这里,我们将 28x28 的图像展平为一个长度为 784 的向量。

3. 全连接层(Dense 层)
全连接层是神经网络中最常见的层之一,每个神经元与上一层的每个神经元都连接。在这里,我们有一个包含 128 个神经元的隐藏层,以及一个包含 10 个神经元的输出层。隐藏层使用 ReLU(Rectified Linear Unit)激活函数,输出层使用 softmax 激活函数。

4. 输出层
输出层产生神经网络的输出,这里是一个包含 10 个元素的向量,每个元素表示对应类别的概率。softmax 函数用于将网络的原始输出转换为概率分布。

5. 编译模型
在编译模型时,我们指定了优化器(optimizer)和损失函数(loss function)。在这里,我们使用 Adam 优化器和稀疏分类交叉熵损失函数。

6. 训练模型
使用训练数据集对模型进行训练,以学习如何将输入映射到正确的输出。在训练过程中,模型通过优化损失函数来调整权重和偏置,使其尽可能准确地预测输出。

总的来说,这个神经网络是一个经典的多层感知器(MLP),它在输入层和输出层之间包含一个或多个隐藏层,通过学习逐步提取和组合特征来进行分类或回归任务。

代码:

import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten

# 加载 MNIST 数据集
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# 数据预处理
train_images = train_images / 255.0
test_images = test_images / 255.0

# 构建神经网络模型
model = Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(128, activation='relu'),
    Dense(10, activation='softmax')
])

# 编译模型
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# 训练模型
model.fit(train_images, train_labels, epochs=5)

# 评估模型
test_loss, test_acc = model.evaluate(test_images, test_labels)
print('Test accuracy:', test_acc)

# 保存模型
model.save('mnist_model.h5')

# 加载模型
loaded_model = tf.keras.models.load_model('mnist_model.h5')

# 使用加载的模型进行预测
predictions = loaded_model.predict(test_images)

结果:

Epoch 1/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.2586 - accuracy: 0.9265
Epoch 2/5
1875/1875 [==============================] - 3s 2ms/step - loss: 0.1136 - accuracy: 0.9656
Epoch 3/5
1875/1875 [==============================] - 3s 2ms/step - loss: 0.0773 - accuracy: 0.9768
Epoch 4/5
1875/1875 [==============================] - 3s 2ms/step - loss: 0.0587 - accuracy: 0.9823
Epoch 5/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.0462 - accuracy: 0.9855
313/313 [==============================] - 0s 1ms/step - loss: 0.0750 - accuracy: 0.9775

Test accuracy: 0.9775000214576721

识别准确率挺高,然后我们也得到了训练好的模型

应用测试:

import tensorflow as tf
import numpy as np
from PIL import Image

# 加载保存的模型
loaded_model = tf.keras.models.load_model('mnist_model.h5')

# 打开手写图片文件
image_path = 'pic/handwritten_digit_thick_5.png'  # 修改为你的手写图片文件路径
image = Image.open(image_path).convert('L')  # 转换为灰度图像

# 调整图片大小为 28x28 像素
image = image.resize((28, 28))

# 将图片转换为 NumPy 数组并进行归一化处理
image_array = np.array(image) / 255.0

# 将图片转换为模型输入的格式(添加批次维度)
input_image = np.expand_dims(image_array, axis=0)

# 使用模型进行预测
predictions = loaded_model.predict(input_image)

# 获取预测结果(最大概率的类别)
predicted_class = np.argmax(predictions)

print('Predicted digit:', predicted_class)

准备了4张图片,3张自己手写,1张摘自minst:

前两张画笔比较细,第三张是minst的5,第四张是用了粗笔自己写的5,最终结果是就minst预测对了。

Predicted digit: 2

Predicted digit: 8

Predicted digit: 5

Predicted digit: 3


结论:

可见这个模型的扩展适应性能还是不够,只能预测正确训练过的minst数字。

改进:

想办法提升训练的质量,让预测能力达标

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1441320.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Github 2024-02-09 开源项目日报 Top10

根据Github Trendings的统计,今日(2024-02-09统计)共有10个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量Python项目4Go项目2Scala项目1PLpgSQL项目1Ruby项目1HTML项目1Solidity项目1Lua项目1 开源个人理财应用 Mayb…

4.5 特效规范与拆分实现及程序的调用原理

一、特效基础流程 落地方案 连入游戏 需求 策划需求,美术需求 需要的SHADER,功能 测试/反馈/修改 效果迭代 满足功能的特效 概念设计 参考图,设计图 二、规范的设计原理与目的 节约沟通成本 保持项目的一致性 工作交接可以更加便捷 降低出错的概率 提升工作效率…

236. 二叉树的最近公共祖先 - 力扣(LeetCode)

题目描述 给定一个二叉树, 找到该树中两个指定节点的最近公共祖先。 百度百科中最近公共祖先的定义为:“对于有根树 T 的两个节点 p、q,最近公共祖先表示为一个节点 x,满足 x 是 p、q 的祖先且 x 的深度尽可能大(一个节点也可以…

【Linux】SystemV IPC

进程间通信 一、SystemV 共享内存1. 共享内存原理2. 系统调用接口(1)创建共享内存(2)形成 key(3)测试接口(4)关联进程(5)取消关联(6)释…

以用户为中心,酷开科技荣获“消费者服务之星”

在企业顺应消费升级的道路中,企业自身不仅要着力强化对于消费者服务意识的提升,并且要树立诚信自律的行业示范带头作用,助力消费环境稳中向好,不断满足人民群众对美好生活的期待。企业的发展需要消费者的认可,酷开科技…

震撼!谷歌推出AI大模型Gemini Ultra,7胜GPT-4!这是AI的新里程碑还是终结者?

谷歌的多模态AI模型Gemini再升级,其中的Ultra版本在基准测试中大放异彩,力压GPT-4! Gemini Ultra,处理文本、代码、图像、音频、视频等模态游刃有余,复杂推理也不在话下。在与GPT-4的较量中,它以7胜1负的…

C#,聚会数(相遇数,Rencontres Number)的算法与源代码

1 相遇数 相遇数(Rencontres Number,partial derangement numbers)是指部分扰动的数量,或与独立对象的r相遇的置换数(即具有固定点的独立对象的置换数)。 看不通。懂的朋友给解释一下哈。 2 源程序 using…

极值图论基础

目录 一,普通子图禁图 二,Turan问题 三,Turan定理、Turan图 1,Turan定理 2,Turan图 四,以完全二部图为禁图的Turan问题 1,最大边数的上界 2,最大边数的下界 五,…

按键扫描16Hz-单片机通用模板

按键扫描16Hz-单片机通用模板 一、按键扫描的原理1、直接检测高低电平类型2、矩阵扫描类型3、ADC检测类型二、key.c的实现1、void keyScan(void) 按键扫描函数①void FHiKey(void) 按键按下功能②void FSameKey(void) 按键长按功能③void FLowKey(void) 按键释放功能三、key.h的…

PlantUML绘制UML图教程

UML(Unified Modeling Language)是一种通用的建模语言,广泛用于软件开发中对系统进行可视化建模。PlantUML是一款强大的工具,通过简单的文本描述,能够生成UML图,包括类图、时序图、用例图等。PlantUML是一款…

FPGA_简单工程_无源蜂鸣器驱动实验

一 理论 蜂鸣器按其结构可分为电磁式蜂鸣器和压电式蜂鸣器2中类型,按其有无信号源,分为有源蜂鸣器和无源蜂鸣器。 有源蜂鸣器,内部装有集成电路,不需要音频驱动电路,就直接能发出声响,而无源蜂鸣器&#…

8个简约精美的WordPress外贸网站主题模板

Simplify WordPress外贸网站模板 Simplify WordPress外贸网站模板,简洁实用的外贸公司wordpress外贸建站模板。 查看演示 Invisible Trade WP外贸网站模板 WordPress Invisible Trade外贸网站模板,做进出口贸易公司官网的wordpress网站模板。 查看演…

网友:感谢华为救了我的下半生。

(关注数据结构和算法,了解更多新知识) 最近一位网友发视频称,华为Mate60 Pro帮他挡了子弹。视频配文:“一场意外,没有这个手机隔挡,下半生我可能就在轮椅上度过了!”视频中,手机摄像头右侧被击中…

TS学习与实践

文章目录 学习资料TypeScript 介绍TypeScript 是什么?TypeScript 增加了什么?TypeScript 开发环境搭建 基本类型编译选项类声明属性属性修饰符getter 与 setter方法static 静态方法实例方法 构造函数继承 与 super抽象类接口interface 定义接口implement…

git flow与分支管理

git flow与分支管理 一、git flow是什么二、分支管理1、主分支Master2、开发分支Develop3、临时性分支功能分支预发布分支修补bug分支 三、分支管理最佳实践1、分支名义规划2、环境与分支3、分支图 四、git flow缺点 一、git flow是什么 Git 作为一个源码管理系统,…

SQL--事务

事务简介 事务 是一组操作的集合,它是一个不可分割的工作单位,事务会把所有的操作作为一个整体一起向系 统提交或撤销操作请求,即这些操作要么同时成功,要么同时失败。 就比如: 张三给李四转账1000块钱,张三银行账户…

适用于 Windows 的 6 款 iPhone 数据恢复软件

数据恢复 已经取得了长足的进步。从仅提供恢复数据的可能性到保证数据恢复,有许多适用于 Windows的第三方 iPhone 数据恢复软件。 大多数软件都是高级工具,但是提供了出色的数据恢复解决方案。从iPhone恢复数据非常简单。 只需将 iPhone 连接到您的计算…

蓝桥杯刷题day08——完全日期

1、题目描述 如果一个日期中年月日的各位数字之和是完全平方数,则称为一个完全日期。 例如:2021年6月5日的各位数字之和为20216516,而16是一个完全平方数,它是4的平方。所以2021年6月5日是一个完全日期。 请问,从200…

计算机毕业设计Python+django医院后勤服务系统flask

结合目前流行的 B/S架构,将医疗后勤服务管理的各个方面都集中到数据库中,以便于用户的需要。该平台在确保平台稳定的前提下,能够实现多功能模块的设计和应用。该平台由管理员功能模块,工作人员模块,患者模块,患者家属模…

git 使用 (备查)

git忽略清单 添加忽略清单 SSH免登录 ssh协议可以实现免登录操作,身份验证通过密钥实现。 跨团队写作 解决冲突 拉取 克隆 拉取最新版本 推送 远程仓库别名 直接使用git push推送 多人协作开发 分支命令 合并分支命令在主分支使用,将develop分支合并到…