6.2.1mnist _eval

news2025/1/28 1:11:57

之前在调试6.2.1mnist _eval代码的时候,出现了下面的错误

//下面不阐述本人遇到的错误,直接告诉大家解决办法(以老师给的源代码进行演示)

首先,打开第6章的源代码

 

//点击程序与数据拆分的文件夹, 并将三个文件夹复制到相关路径

 //D:\anaconda\envs\yxy\Lib\site-packages,这是本人的路径,envs\yxy是自己创的那个环境

 

 还需要把MNIST文件夹中的4个压缩包放到如下目录,D:\anaconda\envs\yxy\Lib\site-packages\tensorboard\mnist

 

//之所以放到这个目录是因为

 //当然这里面的路径根据个人情况去改,也可以直接简化为“.\mnist”,系统会自己去寻找。

随后,打开jupyter,先运行mnist_inference代码

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()
INPUT_NODE = 784
OUTPUT_NODE = 10
LAYER1_NODE = 500

def get_weight_variable(shape, regularizer):
    weights = tf.get_variable("weights", shape,initializer=tf.truncated_normal_initializer(stddev=0.1))
    if regularizer != None:
        tf.add_to_collection('losses', regularizer(weights))
    return weights
def inference(input_tensor, regularizer):
    with tf.variable_scope('layer1'):
        weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer)
        biases = tf.get_variable("biases", [LAYER1_NODE],initializer=tf.constant_initializer(0.0))
        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights)+biases)
    with tf.variable_scope('layer2'):
        weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer)
        biases = tf.get_variable("biases", [OUTPUT_NODE],initializer=tf.constant_initializer(0.0))
        layer2 = tf.matmul(layer1, weights) + biases
    return layer2

再运行mnist_train代码,循环次数可以更改,书上结果是30000次,不过为了节约时间,大家可以改小一点。

import os
import tensorflow.compat.v1 as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_inference
tf.disable_eager_execution()
tf.reset_default_graph()
BATCH_SIZE = 100
LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.99
REGULARIZATION_RATE = 0.0001
TRAINING_STEPS = 30000//循环次数
MOVING_AVERAGE_DECAY = 0.99
MODEL_SAVE_PATH = "./"
MODEL_NAME = "model.ckpt"
def train(mnist):
    print("开始训练!")
    # 定义输入输出placeholder。
    x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE],name='x-input')
    y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE],name='y-input')
#    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    regularizer = tf.keras.regularizers.l2(REGULARIZATION_RATE)
    # 直接使用mnist_inference.py中定义的前向传播过程
    y = mnist_inference.inference(x, regularizer)
    global_step = tf.Variable(0, trainable=False)
    # 定义损失函数、学习率、滑动平均操作以及训练过程
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    variable_averages_op = variable_averages.apply(tf.trainable_variables())
    # 交叉熵与softmax函数一起使用
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))
    learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE,global_step,mnist.train.num_examples / BATCH_SIZE,LEARNING_RATE_DECAY)
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
    with tf.control_dependencies([train_step, variable_averages_op]):
        train_op = tf.no_op(name='train')
    saver = tf.train.Saver()
    with tf.Session() as sess:
        print("变量初始化!")
        tf.global_variables_initializer().run()
        for i in range(TRAINING_STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            _, loss_value, step = sess.run([train_op, loss, global_step],feed_dict={x: xs, y_: ys})
            # 每1000轮保存一次模型
            #if i+1 % 10 == 0:
            print("After %d training step(s), loss on training ""batch is %g." % (step, loss_value))
            saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME),global_step=global_step)

def main(argv=None):
    print("进入主函数!")
    mnist = input_data.read_data_sets(r"D:\Anaconda123\Lib\site-packages\tensorboard\mnist", one_hot=True)
    print("准备训练!")
    train(mnist)

if __name__ == "__main__":
    tf.app.run()

 

 

最后运行mnist_eval代码

import time
import tensorflow.compat.v1 as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_inference
import mnist_train
tf.disable_eager_execution()
tf.reset_default_graph()
EVAL_INTERVAL_SECS = 10
def evaluate(mnist):
    with tf.Graph().as_default() as g:
        #定义输入与输出的格式
        x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
        y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')
        validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
	       #直接调用封装好的函数来计算前向传播的结果
        y = mnist_inference.inference(x, None)
	 	  #计算正确率
        correcgt_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correcgt_prediction, tf.float32))
	       #通过变量重命名的方式加载模型
        variable_averages = tf.train.ExponentialMovingAverage(0.99)
        variable_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variable_to_restore)
        #每隔10秒调用一次计算正确率的过程以检测训练过程中正确率的变化
        while True:
            with tf.Session() as sess:
                ckpt = tf.train.get_checkpoint_state(r"./")
                if ckpt and ckpt.model_checkpoint_path:
                    #load the model
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                    accuracy_score = sess.run(accuracy, feed_dict=validate_feed)
                    print("After %s training steps, validation accuracy = %g" % (global_step, accuracy_score))
                    return
                else:
                    print('No checkpoint file found')
                    return
            time.sleep(EVAL_INTERVAL_SECS)
def main(argv=None):
    mnist = input_data.read_data_sets(r"D:\Anaconda123\Lib\site-packages\tensorboard\mnist", one_hot=True)
    evaluate(mnist)
if __name__ == '__main__':
    tf.app.run()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/493063.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

3、Flutter项目搭建

一、搭建项目 1.1 搭建空壳项目 接上篇的项目搭建、本篇将继续搭建各个界面.当BottomNavigationBar搭建起来后,在各个界面,没有显示对应的元素,因此我们在包含它的Scaffold中,添加body,这样让每个界面撑起来.每次点击就切换对应的界面. 那么我们创建一个_RootPageState中的私…

【Python】scikit-plot可视化模型(含源代码)

文章目录 一、前言二、功能1:评估指标可视化2.1 scikitplot.metrics.plot_confusion_matrix2.2 scikitplot.metrics.plot_roc2.3 scikitplot.metrics.plot_ks_statistic2.4 scikitplot.metrics.plot_precision_recall2.5 scikitplot.metrics.plot_silhouette2.6 sci…

操作系统学习01

1、什么是操作系统? 通过以下四点可以概括操作系统到底是什么: 操作系统(Operating System,简称 OS)是管理计算机硬件与软件资源的程序,是计算机的基石。操作系统本质上是一个运行在计算机上的软件程序 &a…

微前端 qiankun@2.10.5 源码分析(一)

微前端 qiankun2.10.5 源码分析(一) 前言 微前端是一种多个团队通过独立发布功能的方式来共同构建现代化 web 应用的技术手段及方法策略。 Techniques, strategies and recipes for building a modern web app with multiple teams that can ship feat…

Figma转换为sketch,分享这3款工具

在我们的设计工作中,我们经常会遇到各种各样的设计文件相互转换的问题。 你经常为此头疼吗?当你遇到Figma转换Sketch文件的问题时,你是如何解决的?Figma转换Sketch文件有工具吗? 根据众多设计师的经验,本…

在竞争激烈的移动应用市场中获得成功,掌握决胜Framework技术

为何要学习framework? Framework,指的是对应用程序开发所需的核心工具和组件的封装和提供。在Android开发中,Framework是整个开发过程中的核心组成部分,提供了许多功能和服务,包括UI组件、数据存储、网络通信、多媒体…

第二十四章 策略模式

文章目录 前言传统方式解决鸭子问题完整代码抽象鸭子类野鸭子类北京鸭子类玩具鸭子类 一、策略模式基本介绍二、策略模式解决鸭子问题完整代码飞翔接口 FlyBehavior飞翔接口的子类实现飞翔技术高超 GoodFlyBehavior不会飞翔 NoFlyBehavior飞翔技术一般 BadFlyBehavior其他行为接…

文献阅读 Meta-SR: A Magnification-Arbitrary Network for Super-Resolution

题目 Meta-SR: A Magnification-Arbitrary Network for Super-Resolution Meta-SR: 用于超分辨率的任何放大网络 摘要 由于DCNN的发展,最近关于超分辨率的研究取得了巨大成功。然而,任意比例因子的超分辨率长期以来一直被忽视。以往的研究者大多将不同…

Stable-Diffusion AI画画本地搭建详细步骤

ChatGPT出来后,第一次感觉到人工智能真的可能要来了,因此也顺便尝试了下开源AI画画的搭建。网络上写的教程总是不那么面面俱到,因此本文参考了3篇文章才成功把Stable-Diffusion 本地搭建搭建了起来。参考教程在文末。 本文是本地搭建AI画画&a…

C/C++内存泄露检查利器—valgrind

1、Valgrind概述 Valgrind是一套Linux下,开放源代码(GPL V2)的仿真调试工具的集合。 Valgrind由内核(core)以及基于内核的其他调试工具组成。内核类似于一个框架(framework),它模拟…

Android中的GPS开发

GPS简介 Gobal Positioning System,全球定位系统,是美国在20世纪70年代研制的一种以人造地球卫星为基础的高精度无线电导航的定位系统,它在全球任何地方以及近地空间都能够提供准确的地理位置、车行速度及精确的时间信息;它是具有…

2023年房地产抵押贷款研究报告

第一章 概述 房地产抵押贷款是一种以房地产为抵押品的贷款形式,包括个人和企业两种情况。个人房地产抵押贷款是指个人将名下房产作为抵押品向银行或其他金融机构申请贷款,而企业房地产抵押贷款则是指企业将自己名下的商业房产作为抵押品向金融机构申请贷…

202309读书笔记|《野性之美:非洲野生动物初窥》——走进自然界的野性之美

《野性之美: 非洲野生动物初窥》微读的一本书,图片居多,非常有视觉上的震撼。拍摄者也是我们孙姓的一员,孙长智。正如作者所说,与自然对话,你会感悟到生命之美、竞争之美、进化之美、和谐之美! 我喜欢自然…

SPSS如何绘制常用统计图之案例实训?

文章目录 0.引言1.绘制简单条形图2.绘制分类条形图3.绘制分段条形图4.绘制简单线图5.绘制多重线图6.绘制垂直线图7.绘制简单面积图8.绘制堆积面积图9.绘制饼图10.绘制直方图11.绘制简单散点图12.绘制重叠散点图13.绘制矩阵散点图14.绘制三维散点图15.绘制简单箱图16.绘制分类箱…

【markdown工具配合图床】PicGo图床配置教程,一秒读懂配置

前言 看到这篇文章的大佬,我默认大家都会配置git,已经配置好ssh公钥。 此时你看到的这篇文章就是基于markdown工具(VSCode,Typora)编写的。 PicGo作为图床转换工具,并配合gitee作为图片服务器&#xff0…

java元注解和自定义注解的区别

Java的元注解和自定义注解是两个不同的概念。 元注解是Java内置的一组用于修饰其他注解的注解,包括Retention、Target、Inherited和Documented。它们可以控制被修饰的注解的保留策略、目标范围、是否继承等属性,并且可以在编写自定义注解时使用。 Retent…

国考省考结构化面试:综合分析题,社会现象(积极消极政策)、名言哲理(警句观点启示)、漫画反驳题等

国考省考结构化面试:综合分析题,社会现象(积极消极政策)、名言哲理(警句观点启示)、漫画反驳题等 2022找工作是学历、能力和运气的超强结合体! 公务员特招重点就是专业技能,附带行测和申论&…

【Java数据结构】优先级队列(堆)

优先级队列(堆) 概念模拟实现堆的概念堆的存储方式堆的创建向下调整堆的创建建堆的时间复杂度 堆的插入和删除堆的插入堆的删除 用堆模拟实现优先级队列 常用接口PriorityQueue的特性PriorityQueue常用接口介绍构造方法插入/删除/获取优先级最高的元素 P…

孙溟㠭篆刻,红木上的‘’椎凿稚趣‘’

了解中国传统篆刻的人,一定知道篆刻作品中追求的“金石气”。作为拥有3700多年历史的中国传统艺术,篆刻艺术是将书法(主要是篆书)和镌刻(包括凿、铸)相结合,制作印章,亦是汉字独有的…

Vivado 仿真器中以批处理或脚本模式(Batch or Scripted Mode)进行仿真

以下说明来自ug900:在 Vivado 仿真器中以批处理或脚本模式进行仿真 具体可以内容可自行查找 其中代码运行截图为自己实践的实例 Note: xelab, xvlog and xvhdl are not Tcl commands. The xvlog, xvhdl, xelab are Vivado-independent compiler executables. Hence, there is…