【深度学习框架TensorFlow】使用TensorFlow框架构建全连接的神经网络,实现手写数字识别

news2025/1/20 10:56:23

文章目录

  • 一.TensorFlow
    • 1.1 内容介绍
  • 二.开始实验
    • 2.1TensorFlow的基本使用
    • 2.2基于全连接神经网络的手写数字识别
    • 2.3 结论

一.TensorFlow

使用深度学习框架TensorFlow。
目标:

1.了解TensorFlow的基本用法;
2.学习使用TensorFlow构建全连接的神经网络,实现手写数字识别,
3.学习使用TensorFlow构建CNN网络,实现手写数字识别程序,
4.比较两种网络结构的识别精度,以进一步了解深度学习和TensorFlow。

1.1 内容介绍

TensorFlow是一个开源软件库,使用数据流图进行数值计算。

(Operation)表示图中的数学运算,而图中的边表示节点之间互连的多维数据数组,即张量。
Tensorflow作为最流行的深度学习框架之一,具有非常好的性能和强大的功能。它广泛应用于各个领域。

自2015年TensorFlow诞生以来,已有三个主要版本。
TensorFlow更适合大规模部署,特别是跨平台和嵌入式部署

二.开始实验

1.为了消除 numpy 版本过高造成的输出中有太多WARING 警告日志,此处先降低 numpy 版本

!pip install -U numpy==1.15.0

运行结果如下:

image-20221226183624191

2.版本检查,看一下刚刚降低 numpy 版本的版本号。

import numpy as np
print(np.__version__)

image-20221226183713331

2.1TensorFlow的基本使用

1.打印出 Hello tensorflow

import tensorflow as tf
hello = tf.constant('Hello, TensorFlow!')
sess = tf.Session()      # 建立一个session
print (sess.run(hello))  # 通过session里面的run来运行结果
sess.close()

运行结果如下:

image-20221226184546711

2.基本运算

a = tf.constant(3)   # 定义常量3
b = tf.constant(4)   # 定义常量4
 
with tf.Session() as sess:  # 建立session
    print ("相加: %i" % sess.run(a+b))  # 计算输出两个变量相加的值
    print( "相乘: %i" % sess.run(a*b))  # 计算输出两个变量相乘的值

运行结果如下:

image-20221226184610059

3.定义变量

var1 = tf.Variable(10.0 , name="varname")     
var2 = tf.Variable(11.0 , name="varname")     
var3 = tf.Variable(12.0 )     
var4 = tf.Variable(13.0 )   
with tf.variable_scope("test1" ): 
    var5 = tf.get_variable("varname",shape=[2],dtype=tf.float32)   
    
with tf.variable_scope("test2"):
    var6 = tf.get_variable("varname",shape=[2],dtype=tf.float32)  

print("var1:",var1.name)   
print("var2:",var2.name)  
print("var3:",var3.name)   
print("var4:",var4.name)   
print("var5:",var5.name)  
print("var6:",var6.name)  

image-20221226184641364

2.2基于全连接神经网络的手写数字识别

1.从OBS公共桶中下载MNIST数据集

import os
import moxing as mox

if not os.path.isdir("./MNIST_data"):
    mox.file.copy_parallel("obs://modelarts-labs-bj4/course/hwc_edu/python_module_framework/datasets/tensorflow_data/MNIST_data/","./MNIST_data")

2.导入数据集,使用TensorFlow可以直接进行本地数据加载

设置tensorflow日志级别 过滤WARNING

import tensorflow as tf
tf.logging.set_verbosity(tf.logging.ERROR)  
from tensorflow.examples.tutorials.mnist import input_data   

下载到本地的文件夹

data_folder="./MNIST_data"  

导入已经下载好的数据集,如果数据集不存在,会自动在线下载,可能比较耗时。

mnist = input_data.read_data_sets(data_folder, one_hot = True)  

运行结果如下:

image-20221226184849704

3.查看数据集的相关信息

# 训练数据集
print(mnist.train.images.shape, mnist.train.labels.shape)

# 测试数据集
print(mnist.test.images.shape, mnist.test.labels.shape)

# 验证数据集
print(mnist.validation.images.shape, mnist.validation.labels.shape)

运行结果如下:

image-20221226184930167

4.展示部分加载后的数据

import matplotlib.pyplot as plt
%matplotlib inline

plt.figure()
for i in range(9):
    plt.subplot(3,3,i+1)
    plt.imshow(mnist.train.images[i].reshape((28,28)))
plt.show()

运行结果如下:

image-20221226185006517

5.下载测试图片

if not os.path.exists("./num.png"):
    mox.file.copy("obs://modelarts-labs-bj4/course/hwc_edu/python_module_framework/datasets/tensorflow_data/num.png", "./num.png")

6.处理测试图片为网络支持的输入格式

import numpy as np
import cv2

def make_label(label_num):
    label = np.zeros((1,10),dtype='float32')
    label[:,label_num] = 1.0
    return label

label_test = make_label(3)

img_path = "./num.png"  # 图片路径
img_file=cv2.imread(img_path,0)
img_file=cv2.resize(img_file,(28,28))
plt.imshow(img_file,'gray')
plt.show()

data_test = img_file
data_test = np.float32(data_test.reshape(1, 28*28))
print(data_test.shape)

运行结果如下:

image-20221226185103198

7.设置网络中会用到的超参数。

# 参数
learning_rate = 0.1
num_steps = 600
batch_size = 128
display_step = 100

# 神经网络的参数
n_hidden_1 = 256  # 第1层神经元数
n_hidden_2 = 256  # 第2层神经元数
num_input = 784   # MNIST数据输入(img形状:28 * 28)
num_classes = 10  # MNIST总类(0-9位)

8.全连接网络

# 计算图的input
X = tf.placeholder("float", [None, num_input])
Y = tf.placeholder("float", [None, num_classes])

# 隐藏层1
W1 = tf.Variable(tf.truncated_normal([num_input, n_hidden_1],stddev=0.1))
B1 = tf.Variable(tf.zeros([n_hidden_1]))
hidden1 = tf.nn.relu(tf.matmul(X,W1) + B1)

# 隐藏层2
W2 = tf.Variable(tf.truncated_normal([n_hidden_1, n_hidden_2],stddev=0.1))
B2 = tf.Variable(tf.zeros([n_hidden_2]))
hidden2 = tf.nn.relu(tf.matmul(hidden1,W2) + B2)

#输出层
W3 = tf.Variable(tf.zeros([n_hidden_2,num_classes]))
B3 = tf.Variable(tf.zeros([num_classes]))
logits = tf.matmul(hidden2, W3) + B3

#softmax
y = tf.nn.softmax(logits)

9.优化器和损失函数

损失函数

loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=Y))

优化器:Adam

train_op = tf.train.AdagradOptimizer(learning_rate).minimize(loss_op)

模型评估指标和初始化变量

correct_pred = tf.equal(tf.argmax(y, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
init = tf.global_variables_initializer()

10.训练并验证

with tf.Session() as sess:
    sess.run(init)

    for step in range(1, num_steps+1):
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        # 运行优化器,反向传播
        sess.run(train_op, feed_dict={X: batch_x, Y: batch_y})
        if step % display_step == 0 or step == 1:
            # 计算loss和acc
            loss, acc = sess.run([loss_op, accuracy], feed_dict={X: batch_x, Y: batch_y})
            print("Step " + str(step) + ", Minibatch Loss= " + \
                  "{:.4f}".format(loss) + ", Training Accuracy= " + \
                  "{:.3f}".format(acc))

    print("Optimization Finished!")

    # 在MNIST test images上验证效果
    print("Testing Accuracy:", \
        sess.run(accuracy, feed_dict={X: mnist.test.images,
                                      Y: mnist.test.labels}))
    
    # 对自己手写的数字进行识别
    test_acc,test_value = sess.run([accuracy,y], feed_dict={X:data_test, Y:label_test})
    # 设置numpy矩阵显示3位小数
    np.set_printoptions(formatter={'float': '{: 0.3f}'.format})
    
    print(test_value)
    print("AI判断的数字是{}".format(list(test_value[0]).index(test_value[0].max())))

运行结果如下:

image-20221226185234379

2.3 结论

可以看到模型在测试集的精度为96%,并在我们给的手写数字上识别正确。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/116495.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

双周赛(三)

T1: 如果你仍然再用二指禅打字,那我建议你重新学习打字,这样你打字会更快,感觉更舒适和愉快。 有很多网站教授正确的打字。下图描述了基本原理: 用同一手指按压颜色相同的键。黄色键需要用小指按压,蓝色的用无名指&a…

最新 iOS 更新后 iPhone 联系人和数据丢失/丢失

我两天前做了最新的更新,现在我有丢失的联系人。帮助!! 许多 iPhone 用户犹豫是否将他们的设备升级到最新的 iOS 系统有一个(也许是几个)充分的理由。每次 iOS 发布后,总会有新功能震撼我们的世界。但是&am…

Unity 小积累

** Unity 学习小积累 ** 1.FindObjectsOfType和FindObjectOfType 前者返回所有个体(集合) 后者返回第一个个体 (单个) 2.uinty打包问题 打包webgl遇到了 实际上和py没有关系 c盘不够了 单纯 3.Unity 默认下载位置 坑 1.Uni…

【css样式】页面实现一侧固定一侧滚动的效果,到底部后一起滚动

文章目录position的定位类型position的定位类型 static:默认值,没有定位,遵循正常的文档流 fixed:固定定位,元素的位置是相对于浏览器窗口 relative:相对定位,相对于其正常的位置,移…

BOSS直聘自动投简历的实现过程

这两年疫情,公司业务越来越差,必须得准备后路了,每天睡前都会在直聘上打一遍招呼,一直到打哈欠有睡意为止...,这样持续了一周,发现很难坚持,身为一名资深蜘蛛侠,怎么能这样下去呢?于…

3D模型的生成式AI

生成式 AI 席卷了 2022 年,我们最近决定 Physna 不应错过这个热点。 因此,尽管生成 AI 并不是我们的商业模式—Physna 是一家 3D 搜索和分析公司,专注于 AR/VR 和制造中的工程和设计应用—我们还是决定为 3D 模型和场景生成 AI 构建一个非常基…

Node.js——初识Node.js与内置模块(一)

1.初识 Node.js 1.1 浏览器中的 JavaScript运行环境 1.浏览器中的 JavaScript 的组成部分 2.为什么 JavaScript 可以在浏览器中被执行 3.为什么 JavaScript 可以操作 DOM 和 BOM 4.浏览器中的 JavaScript 运行环境 Javascript可以借助node,js进行后端开发 1.2 Node.js 简介 …

virtio前端驱动通知机制分析

virtio前端驱动通知机制分析 virtio 前后端主要通过PCI配置空间的寄存器来完成通信,I/O 请求的数据地址存放于 vring 中,并通过共享vring这个区域来实现 I/O 请求数据的共享。 由上图可知,虚拟机与主机之间交互用到了两个结构体:p…

智能网联汽车行业发展

智能网 联汽车信息安全发展趋势 智能网联汽车行业发展 根据工信部发布的《国家车联网产业标准体系建设 指南(智能网联汽车)》的定义,智能网联汽车是指搭载先进的车载传感器 、控制器、执行器等装置,并融合现代通信与网络技术&a…

明道云联合思迈特打造会员管理应用可视化联合解决方案

背景介绍 明道云在协助企业数字化转型过程中,发现客户对利用业务数据形成企业级报表和数据可视化大屏的需求十分强烈。为了满足这种需求,企业通常需要成立专门的数据分析团队,但这需要巨大的人力和财力投入,时间周期也较长。为了…

信息数智化招采系统源码——信息数智化招采系统

信息数智化招采系统 服务框架:Spring Cloud、Spring Boot2、Mybatis、OAuth2、Security 前端架构:VUE、Uniapp、Layui、Bootstrap、H5、CSS3 涉及技术:Eureka、Config、Zuul、OAuth2、Security、OSS、Turbine、Zipkin、Feign、Monitor、Stre…

React 学习笔记总结(五)

文章目录1. React 嵌套路由(多级路由)2. params参数 与 query参数3. React路由组件 传递params参数数据4. React路由组件 传递search参数5. React路由组件 传递search参数6. React路由组件 特殊情况: 刷新页面7. React路由 的 push 和 replace8. React的 编程式路由9. React路由…

Transformer架构:位置编码

2017年,Google等人提出了Vaswani提出了一种新颖的纯注意力序列到序列架构,闻名学术界与工业界的Transformer架构横空出世,它的可并行化训练能力和优越的性能称为自然语言处理领域和计算机视觉领域研究人员的热门选择,本文将重点讨…

elasticsearch--Master选举

最近一直在学习elasticsear相关的东西,在这学习的过程中记录一下比较重要的学习内容。方便以后看的时候加深印象。 假如宕机的节点是Master节点 下面是Maste节点选举 的流程图 在findMaster的方法中每隔一段时间就会ping所有的节点,看看有没有哪个节…

java设计模式三

文章目录4) 创建IOC容器相关的类5) 自定义IOC容器测试6) 案例中使用到的设计模式7.2 剖析MyBatis框架中用到的经典设计模式7.2.1 MyBatis回顾7.2.1.1 MyBatis与ORM框架7.2.1.1 MyBatis的基础使用7.2.2 MyBatis中使用到的设计模式7.2.2.1 Builder模式7.2.2.2 工厂模式7.2.2.3 单…

基于Java开发(PC)小说自检测系统【100010061】

Java 语言与系统设计课程(小说自检测系统) 一、目的与要求 ​ 自行下载最喜爱的小说 1 部。存到服务器中,格式自定。一般存储为文本文档。要求长篇小说,20 万字以上。举例说明:下载《三国演义》保存在服务器端。 ​…

Secure Boot功能简析

在数据中心中,云服务器由各种处理设备和外围组件(如加速器和存储设备)组成,这些设备通常都运行着固件。对云平台服务商来说,为保证这些设备的安全可靠,需要一种或多种机制,来保证这些设备在测试…

XYplorer使用教程

XYplorer使用教程 XYplorer是Windows的文件管理器。它具有标签式浏览,强大的文件搜索功能,多功能预览,高度可定制的界面,可选的双窗格以及一系列独特的方法,可以有效地自动执行频繁重复的任务。它快速,轻便…

【DCDC转换器】BUCK电路的演进

本文将是对BUCK型DCDC转换器的起步介绍,从BUCK电路模型的建立出发,可以对转换器原理有更清晰的认识。其次对三种不同类型开关电源转换器的原理进行计算,转换器的原理基本是类似的,相同的分析方法可以套用在其他模型上。最后引入了…

PHP基本语法(1)

前言:PHP是什么呢?PHP是一种后端开发用的语言,简单点说制作的网页分为静态和动态,静态网页用户体验性差,动态网页用户可以进行交互,而这种交互就需要PHP了。所以PHP他就是一门用于后端开发的语言。 注意&a…