数字识别问题

news2024/11/14 21:01:35

文章目录

  • 6.1 MNIST数据处理
  • 6.2.1 训练数据
  • 6.2.2 变量管理
  • 6.3.1 保存模型
  • 6.3.1 加载计算图
  • 6.3.1 加载模型
  • 6.3.2 导出元图

6.1 MNIST数据处理

在直接在第6章的目录下面创建文件
在这里插入图片描述
compat.v1.是tensorflow2.x的语法,全部删掉
在这里插入图片描述
删除compat.v1.后的代码

# -*- coding: utf-8 -*-
"""
Created on Sat May  7 21:29:18 2022

@author: HRH
"""


from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
tf.disable_eager_execution()
file = "./MNIST"
mnist = input_data.read_data_sets(file, one_hot=True)
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.matmul(x, W) + b
Y = tf.placeholder(tf.float32, [None, 10])
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=Y, logits=y))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, Y: batch_ys})
    if i%100 ==0:
       print(sess.run(accuracy, feed_dict={x: mnist.test.images, Y:    
mnist.test.labels}))  
print ("优化完成")
print ("模型的准确率为",sess.run(accuracy, feed_dict = {x:mnist.test.images, Y: mnist.test.labels}))

运行结果:
在这里插入图片描述

6.2.1 训练数据

需要运行这三个程序,运行顺序如后标所示
在这里插入图片描述

第一个程序:
在这里插入图片描述
mnist_inference.py 文件代码

import tensorflow as tf
tf.disable_eager_execution()
INPUT_NODE = 784
OUTPUT_NODE = 10
LAYER1_NODE = 500

def get_weight_variable(shape, regularizer):
    weights = tf.get_variable("weights", shape,initializer=tf.truncated_normal_initializer(stddev=0.1))
    if regularizer != None:
        tf.add_to_collection('losses', regularizer(weights))
    return weights
def inference(input_tensor, regularizer):
    with tf.variable_scope('layer1'):
        weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer)
        biases = tf.get_variable("biases", [LAYER1_NODE],initializer=tf.constant_initializer(0.0))
        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights)+biases)
    with tf.variable_scope('layer2'):
        weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer)
        biases = tf.get_variable("biases", [OUTPUT_NODE],initializer=tf.constant_initializer(0.0))
        layer2 = tf.matmul(layer1, weights) + biases
    return layer2

第二个程序:
在这里插入图片描述
在这里插入图片描述
300次数据训练完成:
在这里插入图片描述
mnist_train.py 文件代码

import os
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_inference
tf.disable_eager_execution()
tf.reset_default_graph()
BATCH_SIZE = 100
LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.99
REGULARIZATION_RATE = 0.0001
TRAINING_STEPS = 300
MOVING_AVERAGE_DECAY = 0.99
MODEL_SAVE_PATH = "./"
MODEL_NAME = "model.ckpt"
def train(mnist):
    print("开始训练!")
    # 定义输入输出placeholder。
    x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE],name='x-input')
    y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE],name='y-input')
#    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    regularizer = tf.keras.regularizers.l2(REGULARIZATION_RATE)
    # 直接使用mnist_inference.py中定义的前向传播过程
    y = mnist_inference.inference(x, regularizer)
    global_step = tf.Variable(0, trainable=False)
    # 定义损失函数、学习率、滑动平均操作以及训练过程
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    variable_averages_op = variable_averages.apply(tf.trainable_variables())
    # 交叉熵与softmax函数一起使用
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))
    learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE,global_step,mnist.train.num_examples / BATCH_SIZE,LEARNING_RATE_DECAY)
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
    with tf.control_dependencies([train_step, variable_averages_op]):
        train_op = tf.no_op(name='train')
    saver = tf.train.Saver()
    with tf.Session() as sess:
        print("变量初始化!")
        tf.global_variables_initializer().run()
        for i in range(TRAINING_STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            _, loss_value, step = sess.run([train_op, loss, global_step],feed_dict={x: xs, y_: ys})
            # 每1000轮保存一次模型
            #if i+1 % 10 == 0:
            print("After %d training step(s), loss on training ""batch is %g." % (step, loss_value))
            saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME),global_step=global_step)

def main(argv=None):
    print("进入主函数!")
    mnist = input_data.read_data_sets(r".\mnist", one_hot=True)
    print("准备训练!")
    train(mnist)

if __name__ == "__main__":
    tf.app.run()

第三个程序:
在这里插入图片描述
mnist_eval.py文件

import time
import tensorflow.compat.v1 as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_inference
import mnist_train
tf.disable_eager_execution()
tf.reset_default_graph()
EVAL_INTERVAL_SECS = 10
def evaluate(mnist):
    with tf.Graph().as_default() as g:
        #定义输入与输出的格式
        x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
        y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')
        validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
	       #直接调用封装好的函数来计算前向传播的结果
        y = mnist_inference.inference(x, None)
	 	  #计算正确率
        correcgt_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correcgt_prediction, tf.float32))
	       #通过变量重命名的方式加载模型
        variable_averages = tf.train.ExponentialMovingAverage(0.99)
        variable_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variable_to_restore)
        #每隔10秒调用一次计算正确率的过程以检测训练过程中正确率的变化
        while True:
            with tf.Session() as sess:
                ckpt = tf.train.get_checkpoint_state(r"./")
                if ckpt and ckpt.model_checkpoint_path:
                    #load the model
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                    accuracy_score = sess.run(accuracy, feed_dict=validate_feed)
                    print("After %s training steps, validation accuracy = %g" % (global_step, accuracy_score))
                    return
                else:
                    print('No checkpoint file found')
                    return
            time.sleep(EVAL_INTERVAL_SECS)
def main(argv=None):
    mnist = input_data.read_data_sets(r".\mnist", one_hot=True)
    evaluate(mnist)
if __name__ == '__main__':
    tf.app.run()

训练结果:
在这里插入图片描述

6.2.2 变量管理

在这里插入图片描述
程序报错:
在这里插入图片描述
报错解决:
在这里插入图片描述
运行结果:
在这里插入图片描述

import tensorflow as tf
tf.disable_v2_behavior()
tf.reset_default_graph()
# 在名字为foo的命名空间内创建名字为v的变量
with tf.variable_scope("foo"):
    v = tf.get_variable("v", [1], initializer=tf.constant_initializer(1.0))
'''
# 因为命名空间foo内已经存在变量v,再次创建则报错
with tf.variable_scope("foo"):
    v = tf.get_variable("v", [1])
# ValueError: Variable foo/v already exists, disallowed.
# Did you mean to set reuse=True in VarScope?
'''
# 将参数reuse参数设置为True,则tf.get_variable可直接获取已声明的变量
with tf.variable_scope("foo", reuse=True):
    v1 = tf.get_variable("v", [1])
    print(v == v1) # True
'''
# 当reuse=True时,tf.get_variable只能获取指定命名空间内的已创建的变量
with tf.variable_scope("bar", reuse=True):
    v2 = tf.get_variable("v", [1])
# ValueError: Variable bar/v does not exist, or was not created with
# tf.get_variable(). Did you mean to set reuse=None in VarScope?
'''
with tf.variable_scope("root"):
    # 通过tf.get_variable_scope().reuse函数获取当前上下文管理器内的reuse参数取值
    print(tf.get_variable_scope().reuse) # False
    with tf.variable_scope("foo1", reuse=True):
        print(tf.get_variable_scope().reuse) # True
        with tf.variable_scope("bar1"):
            # 嵌套在上下文管理器foo1内的bar1内未指定reuse参数,则保持与外层一致
            print(tf.get_variable_scope().reuse) # True
    print(tf.get_variable_scope().reuse) # False
# tf.variable_scope函数提供了一个管理变量命名空间的方式
u1 = tf.get_variable("u", [1])
print(u1.name) 
with tf.variable_scope("foou"):
    u2 = tf.get_variable("u", [1])
    print(u2.name) 
with tf.variable_scope("foou"):
    with tf.variable_scope("baru"):
        u3 = tf.get_variable("u", [1])
        print(u3.name)
    u4 = tf.get_variable("u1", [1])
    print(u4.name) 
# 可直接通过带命名空间名称的变量名来获取其命名空间下的变量
with tf.variable_scope("", reuse=True):
    u5 = tf.get_variable("foou/baru/u", [1])
    print(u5.name)  
    print(u5 == u3) 
    u6 = tf.get_variable("foou/u1", [1])
    print(u6.name) 
    print(u6 == u4) 

6.3.1 保存模型

新建一个Model文件夹用来保存模型
在这里插入图片描述
在这里插入图片描述
运行程序:
在这里插入图片描述

6.3.1 加载计算图

在这里插入图片描述

6.3.1 加载模型

在这里插入图片描述

6.3.2 导出元图

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/510508.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【SCI一区】考虑P2G和碳捕集设备的热电联供综合能源系统优化调度模型(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

学术必备的21个论文网站,建议收藏!

1、综合型论文网站(国内) (1)知网 介绍:国内知名度最高的网站,拥有上亿篇各种论文期刊,包含中国学术文献、 外文文献、学位论文、报纸、会议、年鉴、工具书等各类资源统一检索、统一导 航、…

第四章 图像的形态学操作

文章目录 前言一、阈值控制二、腐蚀与膨胀1.腐蚀2.膨胀3.形态学操作 总结 前言 前面讲解了图像基础理论、图像的变换以及图像滤波等操作,本章,将会介绍图像的形态学操作。 图像的形态学指的是一组数学方法和工具,用于图像分析和处理。形态学…

(7)Qt---文件IO

目录 1. QFileDialog 文件选择对话框** 2. QFileInfo 文件信息类** 3. QFile 文件读写类*** 4. UI与耗时操作** 5. QThread 线程类 5.1 复现阻塞 5.2 新建并启动子线程 5.3 异步刷新 5.4 停止线程 1. QFileDialog 文件选择对话框** 操作系统会提供一个统一样式的文件选择对话框…

从本地到云端:豆瓣如何使用 JuiceFS 实现统一的数据存储

豆瓣成立于 2005 年,是中国最早的社交网站之一。在 2009 到 2019 的十年间,豆瓣数据平台经历了几轮变迁,形成了 DPark Mesos MooseFS 的架构。 由机房全面上云的过程中,原有这套架构并不能很好的利用云的特性,豆瓣需…

少林派问题汇总

少林派问题汇总: Q: A:缺少bmodel,没有指定bmodel的路径,测试图片不在同一文件路径下 复制过来就解决了 Q: docker容器下运行./install_lib.sh nntc会rm不到文件怎么回事? A:文件已经被删除 Q: 我将pytorch的模型用export工具转换成.torch…

聚焦丨酷雷曼荣列XRMA联盟成员单位

自“元宇宙”概念兴起之初,酷雷曼VR所属北京同创蓝天云科技有限公司就积极布局、探索和实践。2022年12月,酷雷曼VR成功加入虚拟现实与元宇宙产业联盟(XRMA),正式被接纳为联盟成员单位,意味着酷雷曼公司将进…

详细版易学版TypeScript - 元组 枚举详解

一、元组(Tuple) 数组:合并了相同类型的对象 const myArr: Array<number> [1, 2, 3]; 元组(Tuple):合并了不同类型的对象 // 定义元组时就要确定好数据的类型&#xff0c;并一一对应 const tuple: [number, string] [12, "hi"]; // 添加内容时&#xff0c;不…

代码随想录算法训练营day35 | 860.柠檬水找零,406.根据身高重建队列,452. 用最少数量的箭引爆气球

代码随想录算法训练营day35 | 860.柠檬水找零&#xff0c;406.根据身高重建队列&#xff0c;452. 用最少数量的箭引爆气球 860.柠檬水找零406.根据身高重建队列452. 用最少数量的箭引爆气球 860.柠檬水找零 教程视频&#xff1a;https://www.bilibili.com/video/BV12x4y1j7DD/…

3.SpringBoot开发实用篇

SpringBoot开发实用篇 ​ 开发实用篇中因为牵扯到SpringBoot整合各种各样的技术&#xff0c;由于不是每个小伙伴对各种技术都有所掌握&#xff0c;所以在整合每一个技术之前&#xff0c;都会做一个快速的普及&#xff0c;这样的话内容整个开发实用篇所包含的内容就会比较多。各…

推荐系统综述

这里写目录标题 推荐系统架构1、传统推荐方式1.1 基于内容推荐&#xff08;Content-Based recommendation&#xff0c;CB&#xff09;1.2 协同过滤推荐&#xff08;Collaborative Filtering recommendation&#xff0c; CF&#xff09;1.2.0 UserCF举例&#xff1a;1. 2. 1 基于…

window-2016服务器;服务——活动目录

♥️作者&#xff1a;小刘在C站 ♥️个人主页&#xff1a;小刘主页 ♥️每天分享云计算网络运维课堂笔记&#xff0c;努力不一定有收获&#xff0c;但一定会有收获加油&#xff01;一起努力&#xff0c;共赴美好人生&#xff01; ♥️树高千尺&#xff0c;落叶归根人生不易&…

OpenCL编程指南-2.1HelloWorld

在windows下编写HelloWorld 按照前面文章搭建好OpenCL的环境https://blog.csdn.net/qq_36314864/article/details/130513584 main函数完成以下操作&#xff1a; 1&#xff09;在第一个可用平台上创建OpenCL上下文 2&#xff09;在第一个可用设备上创建命令队列 3&#xff09;…

第二期 | ICASSP 2023 论文预讲会

ICASSP 2023 论文预讲会是由CCF语音对话与听觉专委会、语音之家主办&#xff0c;旨在为学者们提供更多的交流机会&#xff0c;更方便、快捷地了解领域前沿。活动将邀请 ICASSP 2023 录用论文的作者进行报告交流。 ICASSP 2023 论文预讲会邀请到清华大学人机语音交互实验室&…

单细胞跨模态分析综述

单细胞技术的最新进展使跨模态和组织位置的细胞高通量分子分析成为可能。单细胞转录组数据现在可以通过染色质可及性、表面蛋白表达、适应性免疫受体库分析和空间信息进行补充。跨模态单细胞数据的可用性越来越高&#xff0c;推动出新的计算方法&#xff0c;以帮助科学家获得生…

图的遍历——深度优先搜索(DFS)与广度优先搜索(BFS)(附带C语言源码)

个人主页&#xff1a;【&#x1f60a;个人主页】 系列专栏&#xff1a;【❤️数据结构与算法】 学习名言&#xff1a;天子重英豪&#xff0c;文章教儿曹。万般皆下品&#xff0c;惟有读书高——《神童诗劝学》 系列文章目录 第一章 ❤️ 学前知识 第二章 ❤️ 单向链表 第三章…

mysql数据迁移与同步常用解决方案总结

目录 一、前言 二、数据迁移场景 2.1 整库迁移 2.2 表数据迁移 2.3 mysql版本变更 2.4 mysql数据迁移至其他存储介质 2.5 自建数据到上云环境 2.6 mysql数据到其他国产数据库 三、数据库物理迁移实施方案 3.1 数据库物理迁移概述 3.1.1 物理迁移适用场景 3.1.2 物理…

杂记 2023.5.10

目录 韦伯和斯托亚科维奇是谁&#xff1f; 介绍一下kali FastDFS和Sentinel是什么&#xff1f; Inferno 找工作的影响因素 1. 背景&#xff1a; 2. 学习过程&#xff1a; 2.1 计算机基础&#xff1a; 2.2 语言&#xff1a; 2.3 数据库等&#xff1a; 2.4 JVM&#…

月薪17k需要什么水平?98年测试员的面试全过程…

我的情况 大概介绍一下个人情况&#xff0c;男&#xff0c;本科&#xff0c;三年多测试工作经验&#xff0c;懂python&#xff0c;会写脚本&#xff0c;会selenium&#xff0c;会性能&#xff0c;然而到今天都没有收到一份offer&#xff01;从年后就开始准备简历&#xff0c;年…

Linux操作系统如何查看CPU型号信息?一条命令搞定

Linux操作系统服务器如何查看CPU处理器信息&#xff1f;使用命令cat /proc/cpuinfo可以查看CPU详细信息&#xff0c;包括CPU核数、逻辑CPU、物理CPU个数、CPU是否启用超线程等&#xff0c;阿里云服务器网分享Linux服务器查看CPU信息命令&#xff1a; 目录 Linux服务器查看CPU…