TenorFlow多层感知机识别手写体

news2024/11/22 8:55:29

文章目录

  • 数据准备
  • 建立模型
          • 建立输入层 x
          • 建立隐藏层h1
          • 建立隐藏层h2
          • 建立输出层
  • 定义训练方式
          • 建立训练数据label真实值 placeholder
          • 定义loss function
          • 选择optimizer
  • 定义评估模型的准确率
          • 计算每一项数据是否正确预测
          • 将计算预测正确结果,加总平均
  • 开始训练
          • 画出误差执行结果
          • 画出准确率执行结果
  • 评估模型的准确率
  • 进行预测
  • 找出预测错误

GITHUB地址https://github.com/fz861062923/TensorFlow
注意下载数据连接的是外网,有一股神秘力量让你403

数据准备

import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\h5py\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters


WARNING:tensorflow:From <ipython-input-1-2ee827ab903d>:4: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-images-idx3-ubyte.gz
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\base.py:252: _internal_retry.<locals>.wrap.<locals>.wrapped_fn (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please use urllib or similar directly.
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
print('train images     :', mnist.train.images.shape,
      'labels:'           , mnist.train.labels.shape)
print('validation images:', mnist.validation.images.shape,
      ' labels:'          , mnist.validation.labels.shape)
print('test images      :', mnist.test.images.shape,
      'labels:'           , mnist.test.labels.shape)
train images     : (55000, 784) labels: (55000, 10)
validation images: (5000, 784)  labels: (5000, 10)
test images      : (10000, 784) labels: (10000, 10)

建立模型

def layer(output_dim,input_dim,inputs, activation=None):#激活函数默认为None
    W = tf.Variable(tf.random_normal([input_dim, output_dim]))#以正态分布的随机数建立并且初始化权重W
    b = tf.Variable(tf.random_normal([1, output_dim]))
    XWb = tf.matmul(inputs, W) + b
    if activation is None:
        outputs = XWb
    else:
        outputs = activation(XWb)
    return outputs
建立输入层 x
x = tf.placeholder("float", [None, 784])
建立隐藏层h1
h1=layer(output_dim=1000,input_dim=784,
         inputs=x ,activation=tf.nn.relu)  

建立隐藏层h2
h2=layer(output_dim=1000,input_dim=1000,
         inputs=h1 ,activation=tf.nn.relu)  
建立输出层
y_predict=layer(output_dim=10,input_dim=1000,
                inputs=h2,activation=None)

定义训练方式

建立训练数据label真实值 placeholder
y_label = tf.placeholder("float", [None, 10])#训练数据的个数很多所以设置为None
定义loss function
# 深度学习模型的训练中使用交叉熵训练的效果比较好
loss_function = tf.reduce_mean(
                   tf.nn.softmax_cross_entropy_with_logits_v2
                       (logits=y_predict , 
                        labels=y_label))
选择optimizer
optimizer = tf.train.AdamOptimizer(learning_rate=0.001) \
                    .minimize(loss_function)
#使用Loss_function来计算误差,并且按照误差更新模型权重与偏差,使误差最小化

定义评估模型的准确率

计算每一项数据是否正确预测
correct_prediction = tf.equal(tf.argmax(y_label  , 1),
                              tf.argmax(y_predict, 1))#将one-hot encoding转化为1所在的位数,方便比较
将计算预测正确结果,加总平均
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

开始训练

trainEpochs = 15#执行15个训练周期
batchSize = 100#每一批的数量为100
totalBatchs = int(mnist.train.num_examples/batchSize)#计算每一个训练周期应该执行的次数
epoch_list=[];accuracy_list=[];loss_list=[];
from time import time
startTime=time()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for epoch in range(trainEpochs):
    #执行15个训练周期
    #每个训练周期执行550批次训练
    for i in range(totalBatchs):
        batch_x, batch_y = mnist.train.next_batch(batchSize)#用该函数批次读取数据
        sess.run(optimizer,feed_dict={x: batch_x,
                                      y_label: batch_y})
        
    #使用验证数据计算准确率
    loss,acc = sess.run([loss_function,accuracy],
                        feed_dict={x: mnist.validation.images, #验证数据的features
                                   y_label: mnist.validation.labels})#验证数据的label

    epoch_list.append(epoch)
    loss_list.append(loss);accuracy_list.append(acc)    
    
    print("Train Epoch:", '%02d' % (epoch+1), \
          "Loss=","{:.9f}".format(loss)," Accuracy=",acc)
    
duration =time()-startTime
print("Train Finished takes:",duration)        
Train Epoch: 01 Loss= 133.117172241  Accuracy= 0.9194
Train Epoch: 02 Loss= 88.949943542  Accuracy= 0.9392
Train Epoch: 03 Loss= 80.701606750  Accuracy= 0.9446
Train Epoch: 04 Loss= 72.045913696  Accuracy= 0.9506
Train Epoch: 05 Loss= 71.911483765  Accuracy= 0.9502
Train Epoch: 06 Loss= 63.642936707  Accuracy= 0.9558
Train Epoch: 07 Loss= 67.192626953  Accuracy= 0.9494
Train Epoch: 08 Loss= 55.959281921  Accuracy= 0.9618
Train Epoch: 09 Loss= 58.867351532  Accuracy= 0.9592
Train Epoch: 10 Loss= 61.904548645  Accuracy= 0.9612
Train Epoch: 11 Loss= 58.283069611  Accuracy= 0.9608
Train Epoch: 12 Loss= 54.332244873  Accuracy= 0.9646
Train Epoch: 13 Loss= 58.152175903  Accuracy= 0.9624
Train Epoch: 14 Loss= 51.552104950  Accuracy= 0.9688
Train Epoch: 15 Loss= 52.803482056  Accuracy= 0.9678
Train Finished takes: 545.0556836128235
画出误差执行结果
%matplotlib inline
import matplotlib.pyplot as plt
fig = plt.gcf()#获取当前的figure图
fig.set_size_inches(4,2)#设置图的大小
plt.plot(epoch_list, loss_list, label = 'loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['loss'], loc='upper left')
<matplotlib.legend.Legend at 0x1edb8d4c240>

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

画出准确率执行结果
plt.plot(epoch_list, accuracy_list,label="accuracy" )
fig = plt.gcf()
fig.set_size_inches(4,2)
plt.ylim(0.8,1)
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend()
plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

评估模型的准确率

print("Accuracy:", sess.run(accuracy,
                           feed_dict={x: mnist.test.images, 
                                      y_label: mnist.test.labels}))
Accuracy: 0.9643

进行预测

prediction_result=sess.run(tf.argmax(y_predict,1),
                           feed_dict={x: mnist.test.images })
prediction_result[:10]
array([7, 2, 1, 0, 4, 1, 4, 9, 6, 9], dtype=int64)
import matplotlib.pyplot as plt
import numpy as np
def plot_images_labels_prediction(images,labels,
                                  prediction,idx,num=10):
    fig = plt.gcf()
    fig.set_size_inches(12, 14)
    if num>25: num=25 
    for i in range(0, num):
        ax=plt.subplot(5,5, 1+i)
        
        ax.imshow(np.reshape(images[idx],(28, 28)), 
                  cmap='binary')
            
        title= "label=" +str(np.argmax(labels[idx]))
        if len(prediction)>0:
            title+=",predict="+str(prediction[idx]) 
            
        ax.set_title(title,fontsize=10) 
        ax.set_xticks([]);ax.set_yticks([])        
        idx+=1 
    plt.show()
plot_images_labels_prediction(mnist.test.images,
                              mnist.test.labels,
                              prediction_result,0)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

y_predict_Onehot=sess.run(y_predict,
                          feed_dict={x: mnist.test.images })
y_predict_Onehot[8]
array([-6185.544  , -5329.589  ,  1897.1707 , -3942.7764 ,   347.9809 ,
        5513.258  ,  6735.7153 , -5088.5273 ,   649.2062 ,    69.50408],
      dtype=float32)

找出预测错误

for i in range(400):
    if prediction_result[i]!=np.argmax(mnist.test.labels[i]):
        print("i="+str(i)+"   label=",np.argmax(mnist.test.labels[i]),
              "predict=",prediction_result[i])
i=8   label= 5 predict= 6
i=18   label= 3 predict= 8
i=149   label= 2 predict= 4
i=151   label= 9 predict= 8
i=233   label= 8 predict= 7
i=241   label= 9 predict= 8
i=245   label= 3 predict= 5
i=247   label= 4 predict= 2
i=259   label= 6 predict= 0
i=320   label= 9 predict= 1
i=340   label= 5 predict= 3
i=381   label= 3 predict= 7
i=386   label= 6 predict= 5
sess.close()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1452376.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C++初阶(十一) list

一、list的介绍及使用 1.1 list的介绍 list的文档介绍 1. list是可以在常数范围内在任意位置进行插入和删除的序列式容器&#xff0c;并且该容器可以前后双向迭代。 2. list的底层是双向链表结构&#xff0c;双向链表中每个元素存储在互不相关的独立节点中&#xff0c;在节点…

爱上JVM——常见问题(一):JVM组成

1 JVM组成 1.1 JVM由那些部分组成&#xff0c;运行流程是什么&#xff1f; 难易程度&#xff1a;☆☆☆ 出现频率&#xff1a;☆☆☆☆ JVM是什么 Java Virtual Machine Java程序的运行环境&#xff08;java二进制字节码的运行环境&#xff09; 好处&#xff1a; 一次编写&…

巨抽象的前端vue3

根据实践证明&#xff0c;越是简单的问题&#xff0c;越容易造成大bug 一个自定义组件的路径就废了我老半天了 各种查询&#xff0c;各种百度&#xff0c;各种问&#xff0c;结果规规矩矩去导入组件路径&#xff0c;成了&#xff01; 错误代码&#xff1a; <script setu…

canal监听binlog记录业务数据的变更;canalAdmin对instance做web配置

概述 平时在开发中会通过logback打印一些开发日志&#xff0c;有时也会需要记录一些业务日志&#xff0c;简单的就直接用log记录一下&#xff0c;但是系统中需要记录日志的地方越来越多时&#xff0c;不能每个地方都写一套log记录&#xff1b; 由于平常用的大多都是mysql&…

c语言遍历文件夹中的文件

文件目录如下&#xff0c;文件夹里还有一些txt文件未展示出来。 使用递归实现&#xff0c;深度优先遍历文件夹中的文件。 代码如下&#xff0c;用了一点C的语法。 #include <io.h> #include <iostream> using namespace std;#define MAX_PATH_LENGTH 100int Tr…

创新技巧|迁移到 Google Analytics 4 时如何保存历史 Universal Analytics 数据

Google Universal Analytics 从 2023 年 7 月起停止收集数据&#xff08;除了付费 GA360 之外&#xff09;。它被Google Analytics 4取代。为此&#xff0c;不少用户疑惑&#xff1a;是否可以将累积&#xff08;历史&#xff09;数据从 Google Analytics Universal 传输到 Goog…

@ControllerAdvice 的介绍及三种用法

ControllerAdvice 的介绍及三种用法 浅析ControllerAdvice 首先&#xff0c;ControllerAdvice本质上是一个Component&#xff0c;因此也会被当成组建扫描&#xff0c;一视同仁&#xff0c;扫扫扫。 然后&#xff0c;我们来看一下此类的注释&#xff1a; 这个类是为那些声明了&…

智胜未来,新时代IT技术人风口攻略-第四版(弃稿)

文章目录 前言鸿蒙生态科普调研人员画像高校助力鸿蒙高校鸿蒙课程开设占比教研力量并非唯一原因 企业布局规划全盘接纳仍需一段时间企业对鸿蒙的一些诉求 机构入场红利机构鸿蒙课程开设占比机构对鸿蒙的一些诉求 鸿蒙实际体验高校用户群体高度认同与影响体验企业用户群体未来可…

【数据分享】2001~2020年青藏高原植被净初级生产力数据集

各位同学们好&#xff0c;今天和大伙儿分享的是2001~2020年青藏高原植被净初级生产力数据集。如果大家有下载处理数据等方面的问题&#xff0c;您可以私信或评论。 朱军涛. (2022). 青藏高原植被净初级生产力数据集&#xff08;2001-2020&#xff09;. 国家青藏高原数据中心. …

基于STM32的老人心率监测系统

1. 系统设计 本次课题为基于STM32的老人心率监测系统&#xff0c;在此设计了如图2.1所示的系统结构框图&#xff0c;整个系统包括了MAX30102心率血氧检测模块&#xff0c;SIM800短信模块&#xff0c;液晶显示模块&#xff0c;按键&#xff0c;ESP8266无线通信模块以及主控制器s…

WIN11 WSL2 Ubuntu系统删除Docker镜像后磁盘空间未减少问题解决办法

因为 windows 中的 docker 使用虚拟磁盘&#xff08;VHDX&#xff09;来存储文件系统。 windows中&#xff0c;docker需在WSL2配置下才可使用。 &#xff08;WSL是windows推出的可让开发人员不需要安装虚拟机(vmware,virtbox)或者设置双系统启动就可以原生支持运行GNU/Linux的…

CSS 不同颜色的小圆角方块组成的旋转加载动画

<template><!-- 创建一个装载自定义旋转加载动画的容器 --><view class="spinner"><!-- 定义外部包裹容器,用于实现整体旋转动画 --><view class="outer"><!-- 定义四个内部小方块以形成十字形结构 --><view clas…

【初始消息队列】消息队列的各种类型

消息队列相关概念 什么是消息队列 MQ(message queue)&#xff0c;从字面意思上看&#xff0c;本质是个队列&#xff0c;FIFO 先入先出&#xff0c;只不过队列中存放的内容是 message 而已&#xff0c;还是一种跨进程的通信机制&#xff0c;用于上下游传递消息。在互联网架构中…

多线程的基本原理学习

由一个问题引发的思考 线程的合理使用能够提升程序的处理性能&#xff0c;主要有两个方面&#xff0c;第一个是能够利用多核cpu以及超线程技术来实现线程的并行执行&#xff1b;第二个是线程的异步化执行相比于同步执行来说&#xff0c;异步执行能够很好的优化程序的处理性能提…

智慧公厕管理软件

随着城市化的不断推进&#xff0c;城市公共设施逐渐完善&#xff0c;其中智慧公厕的建设也在不断提速。智慧公厕作为城市基础设施的重要组成部分&#xff0c;对城市卫生水平提升有着不可忽视的作用。而智慧公厕管理软件更是智慧公厕管理的基础&#xff0c;是公共厕所智慧化管理…

Netty Review - ByteBuf内存池源码解析

文章目录 Pre主要特点和工作原理类关系源码解析入口索引AbstractNioByteChannel.NioByteUnsafe#readallocHandle.allocate(allocator) 小结 Pre Netty Review - 直接内存的应用及源码分析 Netty Review - 底层零拷贝源码解析 主要特点和工作原理 ByteBuf 内存池是 Netty 中用…

二叉树入门算法题详解

二叉树入门题目详解 首先知道二叉树是什么&#xff1a; 代码随想录 (programmercarl.com) 了解后知道其实二叉树就是特殊的链表&#xff0c;只是每个根节点节点都与两个子节点相连而其实图也是特殊的链表&#xff0c;是很多节点互相连接&#xff1b;这样说只是便于理解和定义…

Sora技术报告——Video generation models as world simulators

文章目录 1. 视频生成模型&#xff0c;可以视为一个世界模拟器2. 技术内容2.1 将可视数据转换成patches2.2 视频压缩网络2.3 Spacetime Latent Patches2.4 Scaling transformers 用于视频生成2.5 可变的持续时间&#xff0c;分辨率&#xff0c;宽高比2.6 抽样的灵活性2.7 改进框…

软件工程师,OpenAI Sora驾到,快来围观

概述 近期&#xff0c;OpenAI在其官方网站上公布了Sora文生视频模型的详细信息&#xff0c;展示了其令人印象深刻的能力&#xff0c;包括根据文本输入快速生成长达一分钟的高清视频。Sora的强大之处在于其能够根据文本描述&#xff0c;生成长达60秒的视频&#xff0c;其中包含&…

【教学类-19-09】20240214《ABAB式-规律黏贴18格-手工纸15*15CM-一页3种图案,AB满,纵向、无边框》(中班)

背景需求 利用15*15CM手工纸制作AB色块手环&#xff08;手工纸自带色彩&#xff09;&#xff0c;一页3个图案&#xff0c;2条为一组&#xff0c;黏贴成一个手环 素材准备 代码展示 # # 作者&#xff1a;阿夏 # 时间&#xff1a;2024年2月14日 # 名称&#xff1a;正方形数字卡…