17- TensorFlow实现手写数字识别 (tensorflow系列) (项目十七)

news2024/9/29 19:22:08

项目要点

  • 模型创建: model = Sequential()
  • 添加卷积层: model.add(Dense(32, activation='relu', input_dim=100))  # 第一层需要 input_dim
  • 添加dropout: model.add(Dropout(0.2))
  • 添加第二次网络: model.add(Dense(512, activation='relu'))   # 除了first, 其他层不要输入shape
  • 添加输出层: model.add(Dense(num_classes, activation='softmax'))  # last 通常使用softmax
  • TensorFlow 中,使用 model.compile 方法来选择优化器和损失函数:
    • optimizer: 优化器: 主要有: tf.train.AdamOptimizer , tf.train.RMSPropOptimizer , or tf.train.GradientDescentOptimizer .

    • loss: 损失函数: 主要有:mean square error (mse, 回归), categorical_crossentropy (多分类) , and binary_crossentropy (二分类).

    • metrics: 算法的评估标准, 一般分类用accuracy.

  • model.fit(x_train, y_train, batch_size = 64, epochs = 20, validation_data = (x_test, y_test))    # 模型训练
  • score = model.evaluate(x_test, y_test, verbose=0)    两个返回值: [ 损失率 , 准确率 ]


1 实例演示Keras的使用 (手写数字识别)

1.1 导包

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.optimizers import rmsprop_v2

1.2 导入数据

# 导入手写数字数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)
'''(60000, 28, 28) (60000,) (10000, 28, 28) (10000,)'''
import matplotlib.pyplot as plt
plt.imshow(x_train[0], cmap = 'gray')

 1.3 数据初步处理

# 对数据进行初步处理
x_train = x_train.reshape(60000, 784)
x_test = x_test.reshape(10000, 784)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print(x_train.shape, 'train samples')  # (60000, 784) train samples
print(x_test.shape, 'test samples')    # (10000, 784) test samples

1.4 数据初步处理

  • 独热编码
import tensorflow
# 将标记结果转化为独热编码
num_classes = 10
y_train = tensorflow.keras.utils.to_categorical(y_train, num_classes)
y_test = tensorflow.keras.utils.to_categorical(y_test, num_classes)
y_train

  1.5 创建模型

# 创建顺序模型
model = Sequential()
# 添加第一层网络, 512个神经元, 激活函数为relu
model.add(Dense(512, activation='relu', input_shape=(784,)))
# 添加Dropout
model.add(Dropout(0.2))
# 第二层网络
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.2))
# 输出层
model.add(Dense(num_classes, activation='softmax'))
# 打印神经网络参数情况
model.summary()

 1.6 模型训练

# 编译
model.compile(loss='categorical_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])

batch_size = 128
epochs = 20
# 训练并打印中间过程
history = model.fit(x_train, y_train,
                    batch_size=batch_size,
                    epochs=epochs,
                    verbose=1,
                    validation_data=(x_test, y_test))
# 计算预测数据的准确率
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])  # Test loss: 0.14742641150951385
print('Test accuracy:', score[1])   # Test accuracy: 0.9815000295639038

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/375275.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

这是一款值得开发人员认真研究的软件,数据库优化,应用服务器安全优化...

1.查询数据库死锁相关信息2.查看数据库的链接情况3.当前实例上的所有用户4.创建数据库独立密码5.查看数据库使用的端口号6.当前数据库设置的最大连接数7.当前数据库最大的理论可连接数8.当前数据库实例的连接数9.当前数据库连接数10.当前数据库连接超时设置11.当前sqlserver 超…

SMART系统—考试监控及阅卷模块的设计与开发

技术:Java、JSP等摘要:Smart在线考试评估系统(简称“SMART系统”),是基于Browser/Server(简称B/S)结构的数据库访问模式,采用Struts Spring Hibernate作为平台搭建的框架开发的一套新型智能的远程教育软件…

伺服三环控制深层原理解析

我们平时使用的工业伺服,通常是成套伺服,即驱动器和电机型号存在配对关系。 但有些时候,我们要用电机定转子和编码器制作非成套电机,这种时候,我们需要对驱动器进行各种设置才能驱动电机。 此篇文章将通过介绍伺服控制的三环控制原理入手来说明我们调试非成套伺服时需要…

2023年微信小程序获取手机号授权登录注册详细教程,包含服务端教程

前言 小程序中有很多地方都会用到用户的手机号,比如登陆注册,填写收货地址等等。有了这个组件可以快速获取微信绑定手机号码,无须用户填写。网上大多数教程还是往年的,而微信官方的api已做了修改。本篇文章将使用最新的方法获取手…

【unity学习记录】Canvas Group组件

💗 未来的游戏开发程序媛,现在的努力学习菜鸡 💦本专栏是我关于游戏开发的学习笔记 🈶本篇是unity的Canvas Group组件 Canvas Group画布组介绍详解1. Alpha2. Interactable3. Blocks Raycasts4. Ignore Parent Groups介绍 画布组…

用反射模拟IOC模拟getBean

IOC就是spring的核心思想之一:控制反转。这里不再赘述,看我的文章即可了解:spring基础思想IOC其次就是java的反射,反射机制是spring的重要实现核心,今天我看spring的三级缓存解决循坏引用的问题时,发现一个…

机器学习——模型评估

在学习得到的模型投放使用之前,通常需要对其进行性能评估。为此,需使用一个“测试集”(testing set)来测试模型对新样本的泛化能力,然后以测试集上的“测试误差( tootino error)作为泛化误差的近似。我们假设测试集是从样本真实分…

ShardingSphere水平、垂直分库、分表和公共表

目录一、ShardingSphere简介二、ShardingSphere-分库分表1、垂直拆分(1)垂直分库(2)垂直分表2、水平拆分(1)水平分库(2)水平分表三、水平分库操作1、创建数据库和表2、配置分片的规则…

中级嵌入式系统设计师2016下半年上午试题及答案解析

中级嵌入式系统设计师2016下半年上午试题 单项选择题 1、(1)用来区分在存储器中以二进制编码形式存放的指令和数据。 A. 指令周期的不同阶段 B. 指令和数据的寻址方式 C. 指令操作码的译码结果 D. 指令和数据所在的存储单元 2、计算机在一个指令周期的过程中,为从…

web服务器(1)

阻塞和非阻塞、同步和异步 网络IO阶段一:数据就绪 操作系统,tcp接受缓冲区 阻塞:调用IO方法的线程进入阻塞状态 非阻塞:不会改变线程的状态,通过返回值判断 网络IO阶段二:数据读写 应用程序 同步…

接口自动化框架---升级版(Pytest+request+Allure)

目录:导读 一、简单介绍 二、目录介绍 三、代码分析 写在最后 接口自动化是指模拟程序接口层面的自动化,由于接口不易变更,维护成本更小,所以深受各大公司的喜爱。 第一版入口:接口自动化框架(PytestrequestAllure…

[Android Studio] Android Studio使用keytool工具读取Debug 调试版数字证书以及release 发布版数字证书

🟧🟨🟩🟦🟪 Android Debug🟧🟨🟩🟦🟪 Topic 发布安卓学习过程中遇到问题解决过程,希望我的解决方案可以对小伙伴们有帮助。 📋笔记目…

学生宿舍管理系统

技术:Java、JSP等摘要:管理信息系统在现代社会已深入到各行各业,由于计算机技术的迅速发展和普及,信息管理系统MIS事实上已成为计算机管理信息系统,大学生宿舍管理系统就是一个典型的管理信息系统,它可以让宿舍管理工作…

【算法题】最大矩形面积,单调栈解法

力扣:84. 柱状图中最大的矩形 给定 n 个非负整数,用来表示柱状图中各个柱子的高度。每个柱子彼此相邻,且宽度为 1 。 求在该柱状图中,能够勾勒出来的矩形的最大面积。 题意很简单,翻译一下就是:求该图中…

模拟银行存取钱-课后程序(JAVA基础案例教程-黑马程序员编著-第八章-课后作业)

【案例8-3】 模拟银行存取钱 【案例介绍】 1.任务描述 在银行办理业务时,通常银行会开多个窗口,客户排队等候,窗口办理完业务,会呼叫下一个用户办理业务。本案例要求编写一个程序模拟银行存取钱业务办理。假如有两个用户在存取…

【Linux】-- POSIX信号量

目录 POSIX信号量 sem_init - 初始化信号量 sem_destroy - 销毁信号量 sem_wait - 等待信号量(P操作) 基于环形队列的生产消费模型 数据结构 - 环形结构 实现原理 POSIX信号量 #问:什么是信号量? 1. 共享资源 -> 任何一…

2. 驱动开发--驱动开发环境搭建

文章目录前言一、Linux中配置编译环境1.1 linux下安装软件的方法1.2 交叉编译工具链的安装1.2.1 测试是否安装成功1.3 设置环境变量1.3.1 将工具链导出到环境变量1.4 为工具链创建arm-linux-xxx符号链接二、 搭建运行开发环境2.1 tftp网络方式加载内核和设备树文件2.2 nfs网络方…

大事很妙,跨境电商用Reddit做营销做测评真的很有用

最近呢,东哥在和一个叫 jens 的海外社媒大佬聊天,聊起了Reddit,其实 Reddit 可是个不错的流量平台,里面有不少宝藏,跟我们国内的贴吧差不多啦。 作为美国热度排名前五的社交网站,流量如此不错的平台&#…

3、Improved Denoising Diffusion Probabilistic Models#

简介论文发现通过一些简单的修改,ddpm也可以在保持高样本质量的同时实现竞争对数可能性,反向扩散过程的学习方差允许以更少的正向传递数量级进行采样,而样本质量的差异可以忽略不计,这对于这些模型的实际部署非常重要。 github链接…

AOF:redis宕机,如何避免数据丢失

由于redis是基于内存的数据库,一旦宕机,数据就会丢失?如何解决? 目前,Redis 的持久化主要有两大机制,即 AOF(Append Only File)日志和 RDB(Redis DataBase) 快照。 AO…