MLP手写数字识别(2)-模型构建、训练与识别(tensorflow)

news2024/11/22 19:40:47

查看tensorflow版本

import tensorflow as tf

print('Tensorflow Version:{}'.format(tf.__version__))
print(tf.config.list_physical_devices())

在这里插入图片描述

1.MNIST的数据集下载与预处理

import tensorflow as tf
from keras.datasets import mnist
from keras.utils import to_categorical

(train_x,train_y),(test_x,test_y) = mnist.load_data()
X_train,X_test = tf.cast(train_x/255.0,tf.float32),tf.cast(test_x/255.0,tf.float32) # 归一化
y_train,y_test = to_categorical(train_y),to_categorical(test_y) # onehot
print(X_train[:5])
print(y_train[:5])

2.搭建MLP模型

from keras import Sequential
from keras.layers import Flatten,Dense
from keras import Input

model = Sequential()
model.add(Input(shape=(28,28)))
model.add(Flatten())
model.add(Dense(units=256,kernel_initializer='normal',activation='relu'))
model.add(Dense(units=10,kernel_initializer='normal',activation='softmax'))
model.summary()

在这里插入图片描述

3.模型训练

3.1 调用model.compile()函数对训练模型进行设置

model.compile(optimizer='adam',
			  loss='categorical_crossentropy',
              metrics=['accuracy'])
  • loss=‘categorical_crossentropy’: 损失函数设置为交叉熵损失函数,在深度学习中用交叉熵模式训练效果会比较好。
  • optimizer=‘adam’: 优化器设置为adam, 在深度学习中可以让训练更快收敛,并提高准确率。
  • metrics=[‘accuracy’]:评估模式设置为准确度评估模式。

loss参数常用的损失函数

  • binary_crossentropy: 亦称作对数损失,logloss
  • categorical_crossentropy: 交叉熵损失函数,亦称作多类的对数损失,注意使用该目标函数时,需要将标签转化为onehot形式
  • sparse_categorical_crossentropy:稀疏交叉熵损失函数。
  • kullback_leibler_divergence: 从预测值概率分布Q到真值概率分布P的信息增益,用以度量两个分布的差异
  • poisson: 即(pred-target*log(pred))的均值
  • cosine_proximity:预测值与真实标签的余弦距离平均值的相反数

优化器

  • SGD
  • RMSprop
  • Adagrad
  • Adadelta
  • Adam
  • Adamax
  • Nadam
  • TFOptimizer

评估模式

  • binary_accuracy: 对二分类问题,计算在所有预测值上的平均正确率
  • categorical_accuracy: 对多分类问题,计算在所有预测值上的平均正确率
  • sparse_categorical_accuracy:与categorical_accuracy相同,在对稀疏的目标值预测时有用
  • top_k_categorical_accuracy: 计算top-k正确率,当预测值的前K个值中存在目标类别即认为预测正确
  • sparse_top_k_categorical_accuracy: 与top_k_categorical_accuracy作用相同,但适用于稀疏情况

3.2 调用model.fit()配置训练参数,开始训练,并保存训练结果。

H = model.fit(x=X_train,
			  y=y_train,
			  validation_split=0.2,
			  epochs=20,
		      batch_size=128,
			  verbose=1)

在这里插入图片描述

4.显示模型准确率和误差

import matplotlib.pyplot as plt

def show_train(history,train,validation):
    plt.plot(history.epoch, history.history[train],label=train)
    plt.plot(history.epoch, history.history[validation],label=validation)
    plt.title(train)
    plt.legend()
    plt.show()
    
show_train(H,'loss','val_loss')
show_train(H,'accuracy','val_accuracy')

在这里插入图片描述

5.使用测试数据进行识别

import numpy as np
import matplotlib.pyplot as plt

def pred_plot_images_lables(images,labels,start_idx,num=5):
    # 预测
    res = model.predict(images[start_idx:start_idx+num])
    res = np.argmax(res,axis=1)

    # 画图
    fig = plt.gcf()
    fig.set_size_inches(12,14)
    for i in range(num):
        ax = plt.subplot(1,num,1+i)
        ax.imshow(images[start_idx+i],cmap='binary')
        title = 'label=' + str(labels[start_idx+i]) + ', pred=' + str(res[i])
        ax.set_title(title,fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()

pred_plot_images_lables(X_test,test_y,0,5)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1641086.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

排序算法--直接选择排序

前提: 选择排序:选择排序(Selection sort)是一种比较简单的排序算法。它的算法思想是每一次从待排序的数据元素中选出最小(或最大)的一个元素,存放在序列的起始位置,直到全部待排序的数据元素排完。 话不多说,直接放图…

Xamarin.Android项目显示Properties

在 Visual Studio 2022 中,如果您需要调出“Properties”(属性)窗口,您可以使用以下几种方法: 快捷键: 您可以按 F4 快速打开当前选择项的“Properties”窗口。

postman中百度preview无法加载的解决方案

问题 在使用postman关联时,百度接口与天气接口已使用glb_city关联,但在百度接口发送请求时,发现preview无法加载 解决方案 1、进入百度 百度全球领先的中文搜索引擎、致力于让网民更便捷地获取信息,找到所求。百度超过千亿的中…

【云原生】Docker 实践(二):什么是 Docker 的镜像

【Docker 实践】系列共包含以下几篇文章: Docker 实践(一):在 Docker 中部署第一个应用Docker 实践(二):什么是 Docker 的镜像Docker 实践(三):使用 Dockerf…

在M1芯片安装鸿蒙闪退解决方法

在M1芯片安装鸿蒙闪退解决方法 前言下载鸿蒙系统安装完成后,在M1 Macos14上打开闪退解决办法接下来就是按照提示一步一步安装。 前言 重新安装macos系统后,再次下载鸿蒙开发软件,竟然发现打不开。 下载鸿蒙系统 下载地址:http…

eNSP-抓包解析HTTP、FTP、DNS协议

一、环境搭建 1.http服务器搭建 2.FTP服务器搭建 3.DNS服务器搭建 二、抓包 三、http协议 1.HTTP协议,建立在FTP协议之上 2.http请求 3.http响应 请求响应报文参考:https://it-chengzi.blog.csdn.net/article/details/113809803 4.浏览器开发者工具抓包…

堆栈打印跟踪Activity的启动过程(基于Android10.0.0-r41),framework修改,去除第三方app的倒计时页面

文章目录 堆栈打印跟踪Activity的启动过程(基于Android10.0.0-r41),framework修改,去除第三方app的倒计时页面1.打印异常堆栈2.去除第三方app的倒计时页面3.模拟点击事件跳过首页进入主页 堆栈打印跟踪Activity的启动过程(基于Android10.0.0-r41)&#x…

SpringCloud微服务项目创建流程

为了模拟微服务场景,学习中为了方便,先创建一个父工程,后续的工程都以这个工程为准,实用maven聚合和继承,统一管理子工程的版本和配置。 后续使用中只需要只有配置和版本需要自己规定之外没有其它区别。 微服务中分为…

MLP实现fashion_mnist数据集分类(2)-函数式API构建模型(tensorflow)

使用函数式API构建模型,使得模型可以处理多输入多输出。 1、查看tensorflow版本 import tensorflow as tfprint(Tensorflow Version:{}.format(tf.__version__)) print(tf.config.list_physical_devices())2、fashion_mnist数据集分类模型 2.1 使用Sequential构建…

4 Spring AOP

目录 AOP 简介 传统开发模式 先来看一个需求 解决方案 AOP 图示 Spring 启用 AspectJ 基于 xml 配置 创建 pom.xml 创建 UserService 借口和 UserServiceImpl实现类 创建 LogAdvice 日志通知 创建 log4j.properties 重点:创建 spring-context-xml.xml 配…

MyBatis:mybatis基础操作

MyBatis基础操作 新增 接口方法 Insert() insert();删除 接口方法 Delete() delete();Delete("delete from emp where id #{id}") public abstract void delete(Integer id) //如果只传了一个形参,括号内可以随意写修改 接口方法 Update() update();查询 接…

Jupyter Notebook魔术命令

Jupyter Notebook是一个基于网页的交互式笔记本,支持运行多种编程语言。 Jupyter Notebook 的本质式一个Web应用程序,便于创建和共享文学化程序文档,支持实现代码,数学方程,可视化和markdown。用途包括:数据…

MATLAB中自定义栅格数据地理坐标R,利用geotifwrite写入tif

场景描述: 有时候将nc格式的数据转成tiff,或者是将一个矩阵输出成带有地理坐标信息tiff数据时,常常涉及到空间参考的定义和geotiffwrite()函数。 问题描述: 以全球数据为例,今天发现在matlab中对矩阵进行显示后&…

【大模型学习】大模型相关概念

知识库 Embeding 嵌入,又称向量化、矢量化。 Prompt engineer 提示词工程 提示工程技巧 RAG 检索增强生成,提高文本的准确性和丰富性。 Fine tuning 微调,优化已有人工智能模型以适应特定任务的技术。 AI agent AI代理人&…

华为机考入门python3--(19)牛客19- 简单错误记录

分类:字符串 知识点: 分割字符串 my_str.split(\\) 字符串只保留最后16位字符 my_str[-16:] 列表可以作为队列、栈 添加元素到第一个位置 my_list.insert(0, elem) 增加元素到最后一个位置 my_list.append(elem) 删除第一个 my_list.pop(0)…

Python中的数据可视化:阶梯图matplotlib.pyplot.step()

【小白从小学Python、C、Java】 【计算机等考500强证书考研】 【Python-数据分析】 Python中的数据可视化: 阶梯图 matplotlib.pyplot.step() [太阳]选择题 matplotlib.pyplot.step()的功能是? import matplotlib.pyplot as plt import numpy as…

基于Springboot的旅游管理系统(有报告)。Javaee项目,springboot项目。

演示视频: 基于Springboot的旅游管理系统(有报告)。Javaee项目,springboot项目。 项目介绍: 采用M(model)V(view)C(controller)三层体系结构&…

[1702]java旅游资源网上填报系统Myeclipse开发mysql数据库web结构java编程计算机网页项目

一、源码特点 java旅游资源网上填报系统是一套完善的java web信息管理系统,对理解JSP java编程开发语言有帮助,系统具有完整的源代码和数据库,系统主要采用B/S模式开发。开发环境为 TOMCAT7.0,Myeclipse8.5开发,数据库为Mysql…

机器视觉系统-条形光源安装位置计算

使用条形光对反光材质物体打光时,常常出现强烈的光斑反射,影响图像处理。如果不想图像中出现光源的光斑,可以通过计 算得出条形光源的安装范围。 检则PCB板上的二维码字符,使用两个条形光打光的效果图 以及等效模型: …

CSS 鼠标经过放大元素 不影响其他元素

效果 .item:hover{transform: scale(1.1); /* 鼠标悬停时将元素放大 1.1 倍 */ }.item{transition: transform 0.3s ease; /* 添加过渡效果,使过渡更加平滑 */ }