第T1周:Tensorflow实现mnist手写数字识别

news2024/9/20 1:53:45
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

目标

具体实现
(一)环境
语言环境:Python 3.10
编 译 器: PyCharm
框架: TensorFlow
**(二)具体步骤:

  1. 安装TensorFlow
    第一次使用这个框架,先安装,打开官网:TensorFlow:
    image.png
# 先把PIP升级到最新版本
$ pip install --upgrade pip  
# 安装稳定版,支持CPU和GPU
$ pip install tensorflow

演示一下官方的代码看看能不能跑(我也看不懂是什么意思,就当是hello world,看看TF正常不):
image.png
image.png
跑成功了(下图),那说明我们安装也成功了。
image.png
下面就通过具体代码来熟悉熟悉TF的使用。
2. 使用TensorFlow实现MNIST手写数字识别
2.1 设置GPU
一上来就整高阶的GPU运算,大家如果没有显卡 ,可以使用CPU(应该默认就是使用CPU),那么本步骤可以直接忽略.

import tensorflow as tf  
print("可用的GPU数量: ", len(tf.config.list_physical_devices('GPU')))

image.png
我的机器明明有显卡,但是显示0,不管了,后面再研究。
选择GPU的代码:

import tensorflow as tf  
print("可用的GPU数量: ", len(tf.config.list_physical_devices('GPU')))  
  
gpus = tf.config.list_physical_devices("GPU")  
  
if gpus:  
    gpu0 = gpus[0]  # 如果有多个GPU,则使用第0个GPU  
    tf.config.experiment.set_memory_growth(gpu0, True)  
    tf.config.set_variable_device([gpu0], "GPU")

2.2 导入MINST数据

from tensorflow.keras import datasets, layers, models
# 导入mnist数据,依次分别为训练集图片、训练集标签、测试集图片、测试集标签  
(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()

2.3 归一化

# 归一化  
train_images, test_images = train_images / 255.0, test_images / 255.0  
#查看数据形状
print('train_images.shape:', train_images.shape)  
print('test_images.shape: ', test_images.shape)  
print('train_labels: ', train_labels.shape)  
print('test_labels: ', test_labels.shape)

image.png
2.4 把数据可视化看看

# 显示数据集前50个图片数据看看  
plt.figure(figsize=(20, 10))  # 将图片显示大小设置为 20宽,10长的大小 ,单位英寸(inch)  
# 遍历MNIST数据集,下标0-49  
for i in range(50):  
    # 将整个figure分成5行10列,绘制第i+1个子图  
    plt.subplot(5, 10, i+1)  
    plt.xticks([])  # 不显示X轴刻度  
    plt.yticks([])  # 不显示Y轴刻度  
    plt.grid(False)     # 不显示网格线  
    plt.imshow(train_images[i], cmap=plt.cm.binary)  # 显示图片  
    plt.xlabel(train_labels[i])  # 显示图片对应的数字(标签)  
  
plt.show()

image.png
2.5 调整图片格式(数据形状)
为啥要调整图片格式呢,导入数据的时候,图片的形状是这样的(60000, 28,28)意思是有6000张28X28像素的图片,现在要调整成(60000, 28, 28, 1)的形状,为啥要调整形状?因为神经网络使用的数量(图像表)它的形状应该是(样本数、宽、高、通道数),对应到(60000, 28, 28)就是样本数60000张图片,宽28,高28都有了,差一个通道数。按我的理解,MNIST数据集图片是单通道图片,因此后面应该通道数是1。可以先学习一下什么叫张量表示(张量简介  |  TensorFlow Core):
image.png
image.png
image.png

再理解一张图片,通常是由RGB三通道构成的,如下:
image.png
image.png

# 调整数据格式,使用reshape来调整  
test_images = test_images.reshape((60000, 28, 28, 1))  
train_images = train_images.reshape((60000, 28, 28, 1))  
  
# 查看数据形状  
print('train_images.shape:', train_images.shape)  
print('test_images.shape: ', test_images.shape)  
print('train_labels: ', train_labels.shape)  
print('test_labels: ', test_labels.shape)

image.png
格式已经调整过来了。有人可能问train_labels和test_labels怎么不调整格式,记住这两个不是图片,是标签值,不用调整。
2.6 构建CNN网络模型(重头戏)
CNN的概念:Convolutional Neural Network,卷积神经网络)是一种前馈神经网络,特别适用于处理具有网格结构的数据,如图像或时间序列数据。CNN最初是为‌图像识别任务而设计的,但后来也被广泛应用于其他领域。
CNN的工作原理:CNN通过一系列方法成功将数据量庞大的图像识别问题不断降维,最终使其能够被训练。其工作原理主要包括以下几个部分:

  1. 卷积层:用于提取输入数据的局部特征。
  2. 池化层:用于降低特征的空间维度,减少计算量。
  3. 全连接层:用于分类或回归任务。
    CNN的特点包括局部连接、权重共享和空间层次结构,这些特点使得CNN在处理图像等数据时非常高效。
    CNN的应用领域:由于CNN在图像识别方面的出色表现,它已经被广泛应用于各种图像处理任务中。此外,CNN也被应用于‌自然语言处理和‌语音识别等领域。近年来,随着‌深度学习的发展,CNN已经成为图像分类的黄金标准。
    CNN的构建:话说把一头大像装冰箱总共需要几步,也是三步:第一步打开冰箱门,第二步放进冰箱(怎么装时冰箱的你别管),第三步关上冰箱门。CNN的构建简单来讲就是三步,第一步准备数据集(输入层),第二步做卷积运算(怎么运算的,目前还是一个黑盒子),第三步是输出结果(输出层,我们期望它输出的内容)。输入和输出大家都已经清楚了,就是这个第二步一直没闹明白,这个黑盒子内部是怎么搞的,就这么神奇的实现了各种分类,归类。今天就是来实现第二步,看看怎么搞出来的。
    CNN的模型:其实第二步的黑盒子就是神经网络模型,而模型有千千万(你自己也可以搞,只是效果怎么样就不知道了),知名的有:
    image.png
    TF框架给了我们搭建模型的方法,我们就找一个模型来试试:
    image.png
    从左到右,一步一步通过TF的方法来实现这个神经网络模型,代码中解释:
  
# 搭建模型  
model = models.Sequential([  
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),  
    layers.MaxPooling2D((2, 2)),  
    layers.Conv2D(64, (3, 3), activation='relu'),  
    layers.MaxPooling2D((2, 2)),  
  
    layers.Flatten(),  
    layers.Dense(64, activation='relu'),  
    layers.Dense(10)  
])  
  
# 打印网络结构  
print(model.summary())

image.png
解释:
Conv2D:二维卷积层,基本都用这个。
activation=‘relu’:激活函数使用ReLu函数。
MaxPooling2D: 池化层
Flatten:连接卷积层和全连接层。把张量展平。
Dense: 全连接层和输出层
image.png
2.6 编译模型

# 编译模型  
model.compile(  
    optimizer="adam",   # 设置优化器为Adam优化器  
    # 设置损失函数为交叉熵损失函数(tf.keras.losses.SparseCategoricalCrossentropy())  
    # from_logits为True时,会将y_pred转化为概率(用softmax),否则不进行转换,通常情况下用True结果更稳定  
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),  
    # 设置性能指标列表,将在模型训练时监控列表中的指标  
    metrics=['accuracy']

2.7 训练模型

# 训练模型  
history = model.fit(  
    # 输入训练集数据  
    train_images,  
    # 输入训练集标签  
    train_labels,  
    # 设置epoch为10,第一个epoch将会把所有数据输入模型完成一次训练  
    epochs=10,  
    # 设置验证集  
    validation_data=(test_images, test_labels)  
)

image.png
2.8 用这个网络模型来进行预测吧
找一张test_images里的照片先预测一下,看实际图片是什么:

plt.imshow(test_images[1])   # 上面的代码中test_images已经归一化了,可能显示不出来,可以使用归一化前的test_images看图片

image.png
拿这张照片预测一下:

pre = model.predict(test_images)  
print(pre[1])

image.png
数值最大的是第3个数,按照对应,第3个数就是2(0,1,2…这个顺序)。所以预测是对的。
改进一下,直观一点:

# 预测  
print(test_images[1].shape)  
plt.imshow(test_images[1].reshape(28, 28))  
pre = model.predict(test_images)  
print(pre[1])  
print("预测结果是:", np.argmax(pre[1]))

image.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2139946.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Mixtral 8x7B:开源稀疏混合专家模型的新里程碑

人工智能咨询培训老师叶梓 转载标明出处 随着大模型规模的增大,计算成本和资源消耗也相应增加,这限制了它们的应用范围和效率。本论文介绍了一种新的稀疏混合专家模型(SMoE)——Mixtral 8x7B,它在保持较小计算成本的同…

【C++】c++ 11

目录 前言 列表初始化 std::initializer_list 右值引用和移动拷贝 左值和右值 左值引用和右值引用的区别 万能引用(引用折叠) 完美转发 默认成员函数控制 列表初始化 在C98中,标准允许使用花括号{}对数组或者结构体元素进行统一的列…

Gartner 成熟度曲线报告解读(一)| 2024中国IT基础设施使用趋势、影响中国IT使用的4大因素

近些年,面对数字化转型、信息化发展、政策监管与地缘政治等外部因素,以及降本增效的内部需求,不少中国企业在制定 IT 基础设施发展策略时遇到多重挑战。为帮助国内企业用户优化基础设施战略,Gartner 近日发布《中国 IT 基础设施技…

【HCIA-Datacom】华为VRP系统

| 👉个人主页:Reuuse 希望各位多多支持!❀ | 👉往期博客:网络参考模型 | 最后如果对你们有帮助的话希望有一个大大的赞! | ⭐你们的支持是我最大的动力!⭐ | 目录 1. 华为VRP系统概述VRP概念设备…

Docker-compose:管理多个容器

Docker-Compose 是 Docker 公司推出的一个开源工具软件,可以管理多个 Docker 容器组成一个应用。用户需要定义一个 YAML 格式的配置文件 docker-compose.yml,写好多个容器之间的调用关系。然后,只要一个命令,就能同时启动/关闭这些…

七、垃圾收集器ParNewCMS与底层三色标记算法详解

文章目录 垃圾收集算法分代收集理论标记-复制算法标记-清除算法标记-整理算法 垃圾收集器1.1 Serial收集器(-XX:UseSerialGC -XX:UseSerialOldGC)1.2 Parallel Scavenge收集器(-XX:UseParallelGC(年轻代),-XX:UseParallelOldGC(老年代))1.3 ParNew收集器(-XX:UseParNewGC)1.4 C…

POSIX信号量以及利用POSIX信号量实现基于循环队列的高效生产者消费者模型

🍑个人主页:Jupiter. 🚀 所属专栏:Linux从入门到进阶 欢迎大家点赞收藏评论😊 目录 🍁POSIX信号量 🍁信号量的相关接口介绍*初始化信号量**销毁信号量**等待信号量**发布信号量* 🍁&…

YOLOv9 简介

YOLO v9 是目前表现最佳的目标检测器之一,被视为现有 YOLO 变体(如 YOLO v5、YOLOX 和 YOLO v8)的改进版本。 YOLOv9 在实时目标检测领域取得了重大进展,引入了诸如可编程梯度信息(PGI)和通用高效层聚合网…

后端开发刷题 | 打家劫舍

描述 你是一个经验丰富的小偷,准备偷沿街的一排房间,每个房间都存有一定的现金,为了防止被发现,你不能偷相邻的两家,即,如果偷了第一家,就不能再偷第二家;如果偷了第二家&#xff0…

Dina靶机详解

靶机下载 https://www.vulnhub.com/entry/dina-101,200/ 靶机配置 默认是桥接模式,切换为NAT模式后重启靶机 主机发现 arp-scan -l 端口扫描 nmap -sV -A -T4 192.168.229.157 发现80端口开启,访问 访问网站 目录扫描 python dirsearch.py -u http…

1.2 交换技术

欢迎大家订阅【计算机网络】学习专栏,开启你的计算机网络学习之旅! 文章目录 前言一、电路交换1. 定义与原理2. 工作过程3. 优点与局限 二、分组交换1. 定义与原理2. 工作过程3. 优点与局限 三、报文交换1. 定义与原理2. 工作过程3. 优点与局限 四、比较…

改进RRT*的路径规划算法

一、RRT算法 RRT 算法是一种基于随机采样的快速搜索算法。该算法的主要思想是通过随机采样来创建一个快速探索的树,从而生长出一条从起点到终点的路径。如图为随机树的生长过程。 初始化。首先,初始化起始点和目标点位置,并将起点作为根节点…

printf()函数的全面介绍及用法——简单易懂

printf()函数介绍 目录 printf()函数介绍 一:头文件 二:格式控制字符串 1.格式字符。 2.转义字符。 3.普通字符。 三:格式字符输出示例 1. %c-----------输出字符 2. %s-----------输…

Linux中断实操-概念

1、裸机中的中断处理方法: (1)使能中断、初始化相应寄存器 (2)注册中断服务函数,向irqTable数组的指定标号处写入中断服务函数 (3)中断发生后进入IRQ中断服务函数,执行对…

【0~1】实现一个精简版的Tomcat服务器

真正的勇气,是在知道生活的真相之后,依然热爱生活。 《To Kill a Mockingbird》 01 Tomcat 介绍 Tomcat 是一个开源的 Java 应用服务器,主要用来运行基于 Servlet 和 JSP 技术的 Web 应用。Tomcat 实现了 Servlet 规范和 JSP 规范&#xff0…

一次RPC调用过程是怎么样的?

注册中心 RPC(Remote Procedure Call)翻译成中文就是 {远程过程调用}。RPC 框架起到的作用就是为了实现,调用远程方法时,能够做到和调用本地方法一样,让开发人员更专注于业务开发,不用去考虑网络编程等细节…

【开源免费】基于SpringBoot+Vue.JS企业客户管理系统(JAVA毕业设计)

本文项目编号 T 036 ,文末自助获取源码 \color{red}{T036,文末自助获取源码} T036,文末自助获取源码 目录 一、系统介绍1.1 管理员角色1.2 普通员工角色1.3 系统特点 二、演示录屏三、启动教程四、功能截图五、文案资料5.1 选题背景5.2 国内…

苹果手机备份照片怎么删除

在数字时代,备份照片是保护我们珍贵记忆不受意外丢失影响的一种重要方式。苹果手机用户通常利用iCloud或iTunes来备份他们的照片,确保数据的安全。然而,随着时间的推移,这些备份可能会积累大量不再需要的照片,占用宝贵…

鸿蒙开发之ArkTS 基础二

ArkTS常用的基础数据类型 1.字符串 关键字是string 2.数字 关键字是number 3.布尔 关键字是boolean 语法格式是:let 变量名:变量类型 变量值 其中let是关键表示变量,可以修改,可以改变一只对应的是const 修饰,常量不能修改,…

Python画笔案例-050 绘制天空之眼

1、绘制天空之眼 通过 python 的turtle 库绘制 天空之眼,如下图: 2、实现代码 绘制 天空之眼,以下为实现代码: """天空之眼.py """ import math import turtledef draw_square(length,level):if l…