【TensorFlow】T1:实现mnist手写数字识别

news2025/2/7 2:20:52
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

1、设置GPU

import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0]
    tf.config.experimental.set_memory_growth(gpu0, True)
    tf.config.set_visible_devices([gpu0],"GPU")
    
print(gpus)
# 输出:[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

固定模板,直接调用即可

2、导入数据

import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt

# 导入mnist数据集
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
# train_images -- 训练集图片
# train_labels -- 训练集标签
# test_images  -- 测试集图片
# test_labels  -- 测试机标签

此处所需数据直接联网下载,故无需数据集

3、归一化

将图像数据从0-255缩放到0-1,主要是为了加快训练速度和提高模型的性能,使得模型训练过程更加稳定高效:

  1. 加速收敛:帮助优化算法更快找到最优解。
  2. 提升性能:确保不同规模的数据对模型的影响均衡,提高预测准确性。
  3. 简化调参:使超参数的选择更加简单有效。
  4. 数值稳定:避免因数据尺度问题导致的计算异常,如梯度爆炸。
train_images, test_images = train_images / 255.0, test_images / 255.0
print(train_images.shape, train_labels.shape,test_images.shape, test_labels.shape)
# 输出:(60000, 28, 28) (60000,) (10000, 28, 28) (10000,)

4、可视化

# figure函数--设置画布大小
# figsize--指定了画布的宽度(20英寸)和高度(10英寸)
plt.figure(figsize=(20, 10))

# 遍历前20张图片
for i in range(20):
    # subplot函数--创建子图
    # 2--2行,10--10列,i+1--第i+1个位置
    plt.subplot(2, 10, i+1)

    # xticks函数--隐藏x轴
    plt.xticks([])

    # yticks函数--隐藏y轴
    plt.yticks([])

    # grid函数--关闭网格线
    plt.grid(False)

    # imshow函数--绘制图像
    # train_images[i]--需要显示的图片,cmap--设置颜色映射(binary)
    plt.imshow(train_images[i], cmap=plt.cm.binary)

    # xlabel函数--给图像添加标签
    plt.xlabel(train_labels[i])

plt.show()

输出:
在这里插入图片描述

5、调整格式

# reshape函数--只改变数据形状,不改变数据内容
# 6000--6000个样本,28--28*28像素,1--颜色通道数(灰度)
train_images = train_images.reshape((60000, 28, 28, 1))
# 1000--1000个样本,28--28*28像素,1--颜色通道数(灰度)
test_images = test_images.reshape((10000, 28, 28, 1))

print(train_images.shape, train_labels.shape,test_images.shape, test_labels.shape)
# 输出:(60000, 28, 28, 1) (60000,) (10000, 28, 28, 1) (10000,)

6、构建CNN网络模型

关键组件:

  1. 卷积层(Conv2D):主要用于提取图像的特征。它通过滑动窗口(即滤波器或核)在输入图像上移动,计算每个位置上的点积来生成特征图。每个滤波器可以识别特定类型的特征,如边缘或纹理。
  2. 激活函数(Activation function):如ReLU(Rectified Linear Unit),用于引入非线性因素,使模型能够学习更复杂的模式。
  3. 最大池化层(MaxPooling2D):用于减小特征图的空间维度(宽度和高度),同时保留最重要的信息。这有助于减少计算量并控制过拟合。
  4. 展平层(Flatten)将多维的卷积层输出转换为一维向量,以便将其输入到全连接层中。
  5. 全连接层(Dense):在这里,每个神经元与前一层的所有神经元相连,用于最终的分类任务。最后一层通常不使用激活函数(或者使用softmax函数),以直接输出每个类别的得分。
# Sequential函数--创建了一个Sequential模型,允许按顺序堆叠各层
model = models.Sequential([
    # Conv2D--第一个二维卷积层
    # 32--32个滤波器,(3, 3)--每个滤波器的大小为(3, 3),activation0--激活函数(relu),input_shape--输入数据的大小(28*28像素、单通道)
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),

    # MaxPooling2D--第一个二维最大池化层
    # 2--池化窗口的大小为2*2,用于下采样时减少特征维度
    layers.MaxPooling2D(2, 2),

    # Conv2D--第二个二维卷积层
    layers.Conv2D(64, (3, 3), activation='relu'),

    # MaxPooling2D--第二个二维最大池化层
    layers.MaxPooling2D(2, 2),

    # Flatten--将多维的卷积层输出展平为一维向量,方便输入全连接层
    layers.Flatten(),

    # Dense--全连接层
    # 64--64个节点,activation0--激活函数(relu)
    layers.Dense(64, activation='relu'),

    # Dense--输出层
    # 10--10个节点(此处对应0-9),默认线性激活函数
    layers.Dense(10)
])

# 打印模型结构和参数信息,包括每一层的输出形状和参数数量
print(model.summary())

输出:
在这里插入图片描述

7、编译模型

# compile函数--配置模型的编译设置
model.compile(
    # optimizer--优化器(adam)
    # Adam:一种基于一阶梯度的优化算法,能自适应地调整不同参数的学习率,适用于大规模数据和高维空间问题
    optimizer='adam',

    # loss--损失函数(Sparse Categorical Crossentropy)
    # Sparse Categorical Crossentropy(稀疏分类交叉熵损失):适用于多分类问题,特别是标签是整数形式时
    # from_logits=True:表示网络输出未经过softmax激活函数处理。这种情况下,损失函数会自动应用softmax来计算最终的分类概率
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),

    # metrics--性能指标
    # accuracy(准确率):表示预测正确的样本数占总样本数的比例
    metrics=['accuracy']
)

8、训练模型

  1. 主要参数:
  • x训练数据的输入(特征)。它可以是一个 Numpy 数组或者一个 TensorFlow 数据集等。对于图像数据,这通常是一个形状为 (样本数量, 高度, 宽度, 颜色通道) 的四维数组。
  • y目标数据(标签)。与 x 对应,表示每个输入样本的真实类别或值。对于分类任务,这通常是一个一维数组,长度等于训练样本的数量;对于多分类问题,可能是一个二维数组,采用 one-hot 编码。
  • batch_size每次梯度更新时使用的样本数。默认值是 32。较大的批量大小可以使计算更高效,但需要更多的内存。
  • epochs整个训练集迭代的次数。每次迭代称为一个 epoch。增加 epoch 数量可以让模型有更多机会学习数据中的模式,但也增加了过拟合的风险。
  • verbose:日志显示模式。0 表示不输出日志到屏幕上,1 表示输出进度条记录,2 表示每个epoch输出一行记录。
  • validation_data:在每个 epoch 结束后用于评估模型性能的数据集。这是一个元组 (x_val, y_val) 或者一个包含输入和目标数据的列表。这对于监控模型在未见过的数据上的表现非常重要。
  • validation_split:从训练数据中划分出一定比例的数据作为验证集,取值范围是 (0, 1)。注意,这个参数只会在 x 是 Numpy 数组时有效。
  • shuffle:在每个 epoch 开始前是否打乱训练数据。这对于确保模型不会因为数据顺序而产生偏差非常重要,默认值为 True。
  • callbacks:一个列表,其中包含各种回调函数对象,在训练过程中这些回调函数会在特定时间点被调用(如每轮结束、训练开始或结束等),可用于实现提前停止、动态调整学习率等功能。
  1. 返回值:
  • History对象:该对象包含了一个 history 属性,这是一个字典,包含了训练过程中损失值评估指标的变化情况。可以访问 history.history[‘loss’] 来获取每个 epoch 的训练损失值,或 history.history[‘val_accuracy’] 获取每个 epoch 的验证准确率。
    “”"
history = model.fit(
    train_images,
    train_labels,
    epochs=10,
    validation_data=(test_images, test_labels),
)

9、模型预测

plt.imshow(test_images[1])
plt.show()

pre = model.predict(test_images)
print(pre[1])
# 输出:[-0.0146451 0.11037247 -0.01110678 0.03087252 -0.02923543 -0.10968889 -0.00841374 0.04551534  -0.02969249 -0.00869128]

输出:
在这里插入图片描述

10、完整代码

# 1.设置GPU
import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0]
    tf.config.experimental.set_memory_growth(gpu0, True)
    tf.config.set_visible_devices([gpu0], "GPU")

print(gpus)

# 2.导入数据
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt

(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()

# 3.归一化
train_images, test_images = train_images / 255.0, test_images / 255.0

print(train_images.shape,test_images.shape,train_labels.shape,test_labels.shape)

# 4.可视化
plt.figure(figsize=(20,10))
for i in range(20):
    plt.subplot(2,10,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i], cmap=plt.cm.binary)
    plt.xlabel(train_labels[i])
    
plt.show()

# 5.调整图片格式
train_images = train_images.reshape((60000, 28, 28, 1))
test_images = test_images.reshape((10000, 28, 28, 1))

print(train_images.shape,test_images.shape,train_labels.shape,test_labels.shape)

# 6.构建CNN网络模型
model = models.Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),

    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(10)
])
print(model.summary())

# 7.编译模型
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])

# 8.训练模型
history = model.fit(
	train_images,
	train_labels,
	epochs=10,
    validation_data=(test_images, test_labels))

# 9.模型预测
plt.imshow(test_images[1])
plt.show()

pre = model.predict(test_images)
print(pre[1])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2294053.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【ArcGIS_Python】使用arcpy脚本将shape数据转换为三维白膜数据

说明: 该专栏之前的文章中python脚本使用的是ArcMap10.6自带的arcpy(好几年前的文章),从本篇开始使用的是ArcGIS Pro 3.3.2版本自带的arcpy,需要注意不同版本对应的arcpy函数是存在差异的 数据准备:准备一…

动静态库的学习

动静态库中,不需要包含main函数 文件分为内存级(被打开的)文件和磁盘级文件 库 每个程序都要依赖很多基础的底层库,本质上来说库是一种可执行代码的二进制形式,可以被载入内存执行 静态库 linux .a windows .lib 动态库 linux .…

DeepSeek的多模态AI模型-Janus-pro,可生图,可读图

简介 Janus-Pro 是由 DeepSeek 开发的一款多模态理解与生成模型,是 Janus 模型的升级版。它能够同时处理文本和图像,既能理解图像内容,又能根据文本描述生成高质量图像。Janus-Pro 的核心目标是通过解耦视觉编码路径,解决多模态理…

Python爬虫实战:一键采集电商数据,掌握市场动态!

电商数据分析是个香饽饽,可市面上的数据采集工具要不贵得吓人,要不就是各种广告弹窗。干脆自己动手写个爬虫,想抓啥抓啥,还能学点技术。今天咱聊聊怎么用Python写个简单的电商数据爬虫。 打好基础:搞定请求头 别看爬虫…

游戏引擎 Unity - Unity 下载与安装

Unity Unity 首次发布于 2005 年,属于 Unity Technologies Unity 使用的开发技术有:C# Unity 的适用平台:PC、主机、移动设备、VR / AR、Web 等 Unity 的适用领域:开发中等画质中小型项目 Unity 适合初学者或需要快速上手的开…

理解 C 与 C++ 中的 const 常量与数组大小的关系

博客主页: [小ᶻ☡꙳ᵃⁱᵍᶜ꙳] 本文专栏: C语言 文章目录 💯前言💯数组大小的常量要求💯C 语言中的数组大小要求💯C 中的数组大小要求💯为什么 C 中 const 变量可以作为数组大小💯进一步的…

孟加拉国_行政边界省市边界arcgis数据shp格式wgs84坐标

这篇内容将深入探讨孟加拉国的行政边界省市边界数据,该数据是以arcgis的shp格式提供的,并采用WGS84坐标系统。ArcGIS是一款广泛应用于地理信息系统(GIS)的专业软件,它允许用户处理、分析和展示地理空间数据。在GIS领域…

Java设计模式:行为型模式→状态模式

Java 状态模式详解 1. 定义 状态模式(State Pattern)是一种行为型设计模式,它允许对象在内部状态改变时改变其行为。状态模式通过将状态需要的行为封装在不同的状态类中,实现对象行为的动态改变。该模式的核心思想是分离不同状态…

快速幂,错位排序笔记

​ 记一下刚学明白的快速幂和错位排序的原理和代码 快速幂 原理: a^b (a^(b/2)) ^ 2(b为偶数) a^b a*(a^( (b-1)/2))^2(b为奇数) 指数为偶数时…

机器人基础深度学习基础

参考: (1)【具身抓取课程-1】机器人基础 (2)【具身抓取课程-2】深度学习基础 1 机器人基础 从平面二连杆理解机器人学 正运动学:从关节角度到末端执行器位置的一个映射 逆运动学:已知末端位置…

Java语法进阶

目录: Object类、常用APICollection、泛型List、Set、数据结构、CollectionsMap与斗地主案例异常、线程线程、同步等待与唤醒案例、线程池、Lambda表达式File类、递归字节流、字符流缓冲流、转换流、序列化流、Files网络编程 十二、函数式接口Stream流、方法引用 一…

《chatwise:DeepSeek的界面部署》

ChatWise:DeepSeek的界面部署 摘要 本文详细描述了DeepSeek公司针对其核心业务系统进行的界面部署工作。从需求分析到技术实现,再到测试与优化,全面阐述了整个部署过程中的关键步骤和解决方案。通过本文,读者可以深入了解DeepSee…

单节锂电池外部供电自动切换的电路学习

文章目录 前言一、原理分析:①当VBUS处有外部电源输入时②当VBUS处无外部电源输入时 二、器件选择1、二极管2、MOS管3、其他 总结 前言 学习一种广泛应用的锂电池供电自动切换电路 电路存在外部电源时,优先使用外部电源供电,并为电池供电&…

数据结构-堆和PriorityQueue

1.堆&#xff08;Heap&#xff09; 1.1堆的概念 堆是一种非常重要的数据结构&#xff0c;通常被实现为一种特殊的完全二叉树 如果有一个关键码的集合K{k0,k1,k2,...,kn-1}&#xff0c;把它所有的元素按照完全二叉树的顺序存储在一个一维数组中&#xff0c;如果满足ki<k2i…

如何打造一个更友好的网站结构?

在SEO优化中&#xff0c;网站的结构往往被忽略&#xff0c;但它其实是决定谷歌爬虫抓取效率的关键因素之一。一个清晰、逻辑合理的网站结构&#xff0c;不仅能让用户更方便地找到他们需要的信息&#xff0c;还能提升搜索引擎的抓取效率 理想的网站结构应该像一棵树&#xff0c;…

每日Attention学习20——Group Shuffle Attention

模块出处 [MICCAI 24] [link] LB-UNet: A Lightweight Boundary-Assisted UNet for Skin Lesion Segmentation 模块名称 Group Shuffle Attention (GSA) 模块作用 轻量特征学习 模块结构 模块特点 使用分组(Group)卷积降低计算量引入External Attention机制更好的学习特征S…

VUE之组件通信(二)

1、v-model v-model的底层原理&#xff1a;是:value值和input事件的结合 $event到底是啥&#xff1f;啥时候能.target 对于原生事件&#xff0c;$event就是事件对象 &#xff0c;能.target对应自定义事件&#xff0c;$event就是触发事件时&#xff0c;所传递的数据&#xff…

[x86 ubuntu22.04]进入S4失败

目录 1 问题描述 2 解决过程 2.1 查看内核日志 2.2 新建一个交换分区 2.3 指定交换分区的位置 1 问题描述 CPU&#xff1a;G6900E OS&#xff1a;ubuntu22.04 Kernel&#xff1a;6.8.0-49-generic 使用“echo disk > /sys/power/state”命令进入 S4&#xff0c;但是无法…

idea隐藏无关文件

idea隐藏无关文件 如果你想隐藏某些特定类型的文件&#xff08;例如 .log 文件或 .tmp 文件&#xff09;&#xff0c;可以通过以下步骤设置&#xff1a; 打开设置 在菜单栏中选择 File > Settings&#xff08;Windows/Linux&#xff09;或 IntelliJ IDEA > Preference…

文献阅读 250205-Global patterns and drivers of tropical aboveground carbon changes

Global patterns and drivers of tropical aboveground carbon changes 来自 <Global patterns and drivers of tropical aboveground carbon changes | Nature Climate Change> 热带地上碳变化的全球模式和驱动因素 ## Abstract: Tropical terrestrial ecosystems play …