tensorflow2 minist手写数字识别数据训练

news2025/3/1 21:03:46

✨ 博客主页:小小马车夫的主页
✨ 所属专栏:Tensorflow

请添加图片描述

文章目录

  • 前言
  • 一、tenosrflow minist手写数字识别代码
  • 二、输出
  • 三、参考资料
  • 总结


前言

刚开始学习tensorflow, 首先接触的是minist手写数字识别,用的梯度下降算法,记录一下以备后续复习和供其他初学者参考,如有错误请不吝指正,万分感谢。

环境:

  • python 3.9.13
  • Tensorflow 2.11.0
  • Tensorboard 2.11.0

一、tenosrflow minist手写数字识别代码

将说明加在代码注释,方便查看复习。

import tensorflow as tf
from tensorflow import keras
from keras import layers, optimizers, datasets

#加载minist数据集,分成训练集和测试集,每个样本包含图像和标签
(x, y), (x_val, y_val) = datasets.mnist.load_data()
print('datasets', x.shape, y.shape, x.min(), y.min())
#训练集图像数据归一化到0-1之前
x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.
#构建数据集对象
db = tf.data.Dataset.from_tensor_slices((x, y))
#批量训练,并行计算一次32个样本、所有数据集迭代20次
db = db.batch(32).repeat(10)
#构建Sequential窗口,一共3层网络,并且前一个网络的输出作为后一个网络的输入
model = keras.Sequential([
    layers.Dense(256, activation='relu'),
    layers.Dense(128, activation='relu'),
    layers.Dense(10)
])
#指定输入大小
model.build(input_shape=(None, 28*28))
#打印出网络的结构和参数量
model.summary()
#optimizers用于更新梯度下降算法参数,0.01为学习率
optimizer = optimizers.SGD(lr=0.01)
#准备率
acc_meter = keras.metrics.Accuracy()
#创建参数文件
summary_writer = tf.summary.create_file_writer('/Users/qcr/tf_log')
#循环数据集
for step, (xx, yy) in enumerate(db):
    #上下文
    with tf.GradientTape() as tape:
        #图像样本大小重置(-1, 28*28)
        xx = tf.reshape(xx, (-1, 28*28))
        #获取输出
        out = model(xx)
        #实际标签转为onehot编码
        y_onehot = tf.one_hot(yy, depth=10)
        #计算误差
        loss = tf.square(out-y_onehot)
        loss = tf.reduce_sum(loss/xx.shape[0])
    #更新准备率
    acc_meter.update_state(tf.argmax(out, axis=1), yy)
    #更新梯度参数
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    #参数存储,便于查看曲线图
    with summary_writer.as_default():
        tf.summary.scalar('train-loss', float(loss), step=step)
        tf.summary.scalar('test-acc', acc_meter.result().numpy(), step=step)
        #tf.summary.image('val-onebyone-images', val)
    if step % 1000 == 0:
        print(step, 'loss:', float(loss), 'acc:', acc_meter.result().numpy())
        acc_meter.reset_states()

二、输出

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense (Dense)               (None, 256)               200960    
                                                                 
 dense_1 (Dense)             (None, 128)               32896     
                                                                 
 dense_2 (Dense)             (None, 10)                1290      
                                                                 
=================================================================
Total params: 235,146
Trainable params: 235,146
Non-trainable params: 0

0 loss: 1.6363120079040527 acc: 0.0625
1000 loss: 0.3189510107040405 acc: 0.83390623
2000 loss: 0.2195253074169159 acc: 0.91753125
3000 loss: 0.2733377516269684 acc: 0.93
4000 loss: 0.20172631740570068 acc: 0.9415
5000 loss: 0.13919278979301453 acc: 0.94378126
6000 loss: 0.14041364192962646 acc: 0.951625
7000 loss: 0.0935342013835907 acc: 0.9514375
8000 loss: 0.1644362509250641 acc: 0.95728123
9000 loss: 0.11363211274147034 acc: 0.9559063
10000 loss: 0.15755562484264374 acc: 0.9628125
11000 loss: 0.0880645364522934 acc: 0.959375
12000 loss: 0.08858028799295425 acc: 0.9657813
13000 loss: 0.0917932391166687 acc: 0.96296877
14000 loss: 0.06503693014383316 acc: 0.9683125
15000 loss: 0.09167198836803436 acc: 0.9665625
16000 loss: 0.1386248767375946 acc: 0.96834373
17000 loss: 0.10692787915468216 acc: 0.96953124
18000 loss: 0.10871071368455887 acc: 0.9697813

误差和准确率曲线

tensorflow2 loss
tensorflow2 acc

三、参考资料

[1] 《Tensorflow深度学习》


总结

以上就是本次的内容,来总结一下吧:
主要介绍了tensorflow2梯度下降算法实现minist手写数字数据集的训练,并对结果进行可视化展示。

如果觉得有些帮助或觉得文章还不错,请关注一下博主,你的关注是我持续写作的动力。另外,如果有什么问题,可以在评论区留言,或者私信博主,博主看到后会第一时间进行回复。
【间歇性的努力和蒙混过日子,都是对之前努力的清零】
欢迎转载,转载请注明出处:https://blog.csdn.net/xxm524/article/details/128054377

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/38773.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

linux下mysql的三种安装方法

目录1. 离线安装(.tar.gz安装包)2. 离线安装(rpm安装包)3. 在线安装(yum安装)前言安装环境 : Redhat Enterprise Linux 81. 离线安装(.tar.gz安装包) 参考这篇博客 2. 离线安装&am…

HTB-Toolbox

HTB-Toolbox信息收集提权信息收集 使用nmap简单扫描一下网站端口。 21 ftp22 ssh443 https? 去https看看。 网站基本是静态的。因为是https,所以有ssl协议,去看看吧。 more information里面能找到协议。 找到admin.megalogistic.com子域。…

Spring(完整版)

文章目录一、Spring(一)、Spring简介1、Spring概述2、Spring家族3、Spring Framework1、Spring Framework五大功能模块2、Spring Framework特性(二)、控制反转IOC1、IOC容器1、IOC思想2、IOC容器在Spring中的两种实现方式①BeanFactory②ApplicationContext③ApplicationContex…

配置elasticsearch用windows account(AD)登录

配置elasticsearch用windows account(AD)登录编辑es的配置文件创建role mapping文件添加windows account的密码给role mapping和cacert文件正确的权限重启kibana和elasticsearch在kibana页面登陆elasticsearch参考文章:• https://www.elasti…

vue中如何使用swiper以及解决swiper初始化过早的问题

后端的返回的数据用数组接收; 把swiper放到根组件里的mounted,也会出现swiper先初始化,dom再加载的问题: swiper初始化在mounted里执行;setTimeout定时器是从后端请求回来的数据; 代码执行顺序是:当组件创…

Eureka Server配置

01.Eureka Server配置 Eureka Server提供注册服务,各个节点启动后,会在EurekaServer中进行注册,Eureka Server会存储所有可用信息的服务节点,其信息可以在界面中直观的观察到。(服务注册中心 CAP核心理论 一个分布式…

(十六)Spring对事务的支持

文章目录环境事务概述引入事务场景第一步:准备环境第二步:编写持久层第三步:编写业务层第四步:编写Spring配置文件第五步:编写表示层(测试程序)模拟异常Spring对事务的支持Spring事务管理API声明…

mysql-6-主从复制搭建

1 总结 1:主从复制最大缺陷就是延迟。 2 搭建前的准备 2.1复制的基本原则 每个slave只有一个master每个slave只能有一个唯一的服务器ID每个master可以有多个slavemysql版本尽量一致,防止出问题。两台服务能ping通MySQL主从是基于binlog的,主上…

Face Global | 创龙科技2款新品登陆TI全球官网

日前,创龙科技AM62x、AM64x处理器平台齐登TI全球官方网站,向全球TI用户提供高可靠性的工业核心板以及工业评估套件。 图 1 TI全球官网截图-AM64x 图 2 TI全球官网截图-AM62x 创龙科技(Tronlong)作为TI中国官方合作伙伴,自2013年成立以来,已基于TI Sitara、C6000、DaVinci、…

【网络安全必看】如何提升自身WEB渗透能力?

前言 web渗透这个东西学起来如果没有头绪和路线的话,是非常烧脑的。 理清web渗透学习思路,把自己的学习方案和需要学习的点全部整理,你会发现突然渗透思路就有点眉目了。 程序员之间流行一个词,叫35岁危机,&#xf…

OSPF路由策略引入

功能介绍: distribute-list 分发列表,通过distribute-list 工具对路由更新进行控制,只能进行路由条目过滤,不能修改路由的属性。 一、组网要求 在R2上把rip路由重分发进ospf,并且在重分发时进行路由过滤,…

【云原生】Docker-compose单机容器集群编排

内容预知 1.Compose的相关知识 1. Compose的相关概念 2. 为何需要docker-compose docker镜像管理的问题 Docker Compose的解决方案 3. Compose的特征 2. Docker-compose的安装 3. Compose配置常用字段和YAML 文件编写 3.1 YAML 文件格式及编写注意事项 (1&…

[附源码]Python计算机毕业设计儿童闲置物品交易网站

项目运行 环境配置: Pychram社区版 python3.7.7 Mysql5.7 HBuilderXlist pipNavicat11Djangonodejs。 项目技术: django python Vue 等等组成,B/S模式 pychram管理等等。 环境需要 1.运行环境:最好是python3.7.7,…

Java面试题-为什么重写equals就一定要重写hashCode方法呢?

目录 1、为什么要重写equals 方法 2、hashCode 方法 3、为什么要一起重写? 4 原因分析 总结 先放结论: hashCode 和 equals 两个方法是用来协同判断两个对象是否相等的,采用这种方式的原因是可以提高程序插入和查询的速度。如果只重写equ…

第五届传智杯-初赛【A组-题解】

B题: 题目背景 【 题目背景和题目描述的两个题面是完全等价的,您可以选择阅读其中一部分。】 专攻超统一物理学的莲子,对机械结构的运动颇有了解。如下图所示,是一个三进制加法计算器的(超简化)示意图。…

idea,web开发中jsp页面中不提示控制层的请求地址

随着开发的进行,打开spring配置文件会有如下提示 同时工程管理里如下 删掉后,发现打开sping配置文件不告警了,可是jsp页面中也没有了地址的提示 这个提示没有了 正确的做法是删掉Spring Application Context 因为其他配置文件都导入App…

Java_接口

目录 1.接口的语法规则 2.接口使用 3.接口特性 4.实现多个接口 1)下面通过类来表示一组动物; 2)另外再提供一组接口, 分别表示 "会跑的", "会飞的", "会游泳的"; 3)接下来我们创建…

黑马点评--达人探店

查看探店笔记: private void queryBlogUser(Blog blog) {Long userId blog.getUserId();User user userService.getById(userId);blog.setName(user.getNickName());blog.setIcon(user.getIcon()); }Override public Result queryBlogById(Long id) {//1.查询blo…

首1标准型和尾1标准型

目录 (1)系统的传递函数; (2) 系统的增益; (3) 系统的特征根及相应的模态; (4) 画出对应的零极点图; (5) 求系统的单位脉冲响应&#…

Linux下JAVA使用JNA调用C++的动态链接库(g++或者gcc编译的.so文件)

目录 前言 一、准备工作 二、JAVA项目加载JNA 三、JNA的使用 3.1 生成.so文件 3.1.1 gcc生成的.so 3.1.2 g生成的.so 3.2 JNA调用.so 四、JAVA与C的类型对应 五、总结 前言 在没上班之前,我曾在CSDN写过《程序员的自我修养》的读书…