【Tensorflow学习三】神经网络搭建八股“六步法”编写手写数字识别训练模型

news2024/11/15 10:52:38

神经网络搭建八股“六步法”编写手写数字识别训练模型

  • Sequential用法
  • model.compile(optimizer=优化器,loss=损失函数,metrics=["准确率"])
  • model.fit
  • model.summery
  • 六步法搭建鸢尾花分类网络
  • class搭建具有非顺序网络结构
  • MNIST数据集
  • Fashion MNIST数据集

Tensorflow API:tf.keras搭建网络八股

六步法

1.import (import相关模块,比如import tensorflow as tf)
2.train test (喂入网络的训练集和测试集,指定训练集的输入特征x_train和训练集的标签y_train ,还可以指定测试集的输入特征x_test和测试集的标签y_test)
3.model=tf.keras.models.Sequential  (搭建网络结构,逐层描述每层网络)
4.model.compile (配置训练方法,告知训练时选择哪种优化器、损失函数、评测指标)
5.model.fit (执行训练过程,告知训练集和测试集的输入特征和标签,告知每个batch是多少、要迭代多少次数据集)
6.model.summary (打印网络的结构和参数统计)

Sequential用法

可以认为Sequential()是个容器,在这个容器中封装了一个神经网络结构

Sequential要描述从输入层到输出层每一层的网络结构

例如:

1.拉直层:tf.keras.layers.Flatten()

2.全连接层:tf.keras.layers.Dense(神经元个数,activation='激活函数',kernel_regularizer=哪种正则化)

activation可选:relu、softmax、sigmoid、tanh
kernel_regularizer可选:tf.keras.regularizers.l1()、tf.keras.regularizers.l2()

3.卷积层:tf.keras.Conv2D(filter=卷积核个数,kernel_size=卷积核尺寸,strides=卷积步长,padding="valid" or "same")

4.LSTM层:tf.keras.layers.LSTM()

model.compile(optimizer=优化器,loss=损失函数,metrics=[“准确率”])

optimizer可选:

#优化器,可以是字符串形式的名字,还可以是函数形式(函数形式可以设置学习率、动量等超参数)
"sgd" or tf.keras.optimizers.SGD(lr=学习率,momentum=动量参数)
"adagrad" or tf.keras.optimizers.Adagrad(lr=学习率)
"adadelta" or tf.keras.optimizers.Adadelta(lr=学习率)
"adam" or tf.keras.optimizers.Adam(lr=学习率,beta_1=0.9,beta_2=0.999)

loss可选:

#优损失函数,可以是字符串形式的名字,还可以是函数形式
"mse" or tf.keras.losses.MeanSquaredError()
"sparse_categorical_crossentropy" or tf.keras.SparseCategoricalCrossentropy(from_logits=False)#from_logits=False 询问是否是原始输出(指未经过softmax概率输出的,经过False,未经过True)

Metrics可选:

"accuracy":y_和y都是值
"categorical_accuracy":y与y_都是独热码
"sparse_categorical":y_是值,y是独热码

model.fit

modelfit(训练集输入特征,训练集的标签,
         batch_size= , epochs= ,
         validation_data=(测试集的输入特征,测试集的标签)
         vaildation_split=从训练集划分多少比例给测试集
         vaildation_freq=多少epoch测试一次)
#validation_data与vaildation_split二选一

model.summery

model.summery可以打印网络的结构和参数统计

以鸢尾花分类的网络为例

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Si0hqkbk-1670672738651)(C:\Users\98306\AppData\Roaming\Typora\typora-user-images\image-20221206171521697.png)]

六步法搭建鸢尾花分类网络

import tensorflow as tf
from sklearn import datasets
import numpy as np

x_train = datasets.load_iris().data
y_train = datasets.load_iris().target

np.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
tf.random.set_seed(116)

model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())
])

model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),
             loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)

model.summary()

class搭建具有非顺序网络结构

#想要搭建非顺序的网络结构,可以用class
1.import 
2.train、test
3.class MyModel(model) model=MyModel
4.model.compile
5.model.fit
6.model.sunmmary

class类来封装一个网络结构

class MyModel(Model)

class MyModel(Model):#MyModel为神经网络的名字,继承了Tensorflow的Model类
    def __init__(self):
        super(MyModel,self).__init__()
        定义网络结构模块
    def call(self,x):
        调用网络结构模块,实现前向传播
        return y

model=MyModel()
###
__init__() 定义所需的网络结构块
call() 写出前向传播
###

例子:

class IrisModel(Model):
    def __init__(self):
        supper(IrisModel,self).__init__()
        self.d1=Dens(3)
    
    def call(self,x):
        y=self.d1(x)
        return y

model=IrisModel()

用类实网络结构实现鸢尾花分类的代码:

import tensorflow as tf
from tensorflow.keras.layers import Dense#添加的部分
from tensorflow.keras import Model#添加的部分
import numpy as np
from sklearn import datasets

x=datasets.load_iris().data
y=datasets.load_iris().target

np.random.seed(116)
np.random.shuffle(x)
np.random.seed(116)
np.random.shuffle(y)

class IrisModel(Model):
    def __init__(self):
        super().__init__()
        self.d1=Dense(3,activation='sigmoid',kernel_regularizer=tf.keras.regularizers.l2())

    def call(self,x):
        y=self.d1(x)
        return y
model=IrisModel()
# model=tf.keras.models.Sequential([tf.keras.layers.Dense(3,activation="softmax",kernel_regularizer=tf.keras.regularizers.l2())])

model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.1),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])

model.fit(x,y,batch_size=32,epochs=500,validation_split=0.2,validation_freq=20)

model.summary()

结果为

Epoch 500/500
4/4 [==============================] - 0s 11ms/step
- loss: 0.4527 
- sparse_categorical_accuracy: 0.8500 
- val_loss: 1.0348 
- val_sparse_categorical_accuracy: 0.5333

Model: "iris_model"
____________________________________________________________
 Layer (type)                Output Shape              Param 
============================================================
 dense (Dense)               multiple                  15        
                                                             
============================================================
Total params: 15
Trainable params: 15
Non-trainable params: 0
____________________________________________________________

其中:

loss:训练集loss
val_loss:测试集loss
sparse_categorical_accuracy:训练集准确率
val_sparse_categorical_accuracy:测试集准确率

MNIST数据集

MNIST数据集
提供6w张 28*28像素的0~9手写数字图片和标签,用于训练
提供1w张 28*28像素的0~9手写数字图片和标签,用于测试

导入数据集:
minist=tf.keras.datasets.mnist
(x_train,y_train),(x_test,y_test)=mnist.load_data()
作为输入特征,输入神经网络时,将数据拉伸为一维
tf.keras.layers.Flatten()
import tensorflow as tf
from matplotlib import pyplot as plt

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data(r"这里填mnist数据集路径")#把minist数据集先下载到电脑中,再导入,直接下载容易出错

# 可视化训练集输入特征的第一个元素
plt.imshow(x_train[0], cmap='gray')  # 绘制灰度图
plt.show()

# 打印出训练集输入特征的第一个元素
print("x_train[0]:\n", x_train[0])
# 打印出训练集标签的第一个元素
print("y_train[0]:\n", y_train[0])

# 打印出整个训练集输入特征形状
print("x_train.shape:\n", x_train.shape)
# 打印出整个训练集标签的形状
print("y_train.shape:\n", y_train.shape)
# 打印出整个测试集输入特征的形状
print("x_test.shape:\n", x_test.shape)
# 打印出整个测试集标签的形状
print("y_test.shape:\n", y_test.shape)

用mnist数据集的训练代码

import tensorflow as tf

mnist=tf.keras.datasets.mnist

(x_train,y_train),(x_test,y_test)=mnist.load_data(r"这里填mnist数据集路径")

x_train,x_test=x_train/255.0,x_test/255.0#归一化到[0,1]

model=tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128,activation='relu'),
    tf.keras.layers.Dense(10,activation='softmax')
])

model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=["sparse_categorical_accuracy"])

model.fit(x_train,y_train,batch_size=32,epochs=5,validation_data=(x_test,y_test),validation_freq=1)#执行训练过程

model.summary()

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-B827fFda-1670672738653)(C:\Users\98306\AppData\Roaming\Typora\typora-user-images\image-20221210192009223.png)]

用类实现手写字识别

import tensorflow as tf
from tensorflow.keras.layers import Dense,Flatten
from tensorflow.keras import Model
mnist=tf.keras.datasets.mnist

(x_train,y_train),(x_test,y_test)=mnist.load_data(r"C:\Users\98306\Desktop\Tensorflow\中国大学MOOCTF笔记2.1共享给所有学习者\class3\mnist.npz")

x_train,x_test=x_train/255.0,x_test/255.0 #归一化到[0,1]

class MnistModel(Model):
    def __init__(self):
        super(MnistModel, self).__init__()
        self.flatten=Flatten()
        self.d1=Dense(128,activation='relu')
        self.d2=Dense(10,activation='softmax')

    def call(self,x):
        x=self.flatten(x)
        x=self.d1(x)
        y=self.d2(x)
        return y

model=MnistModel()
# model=tf.keras.models.Sequential([
#     tf.keras.layers.Flatten(),
#     tf.keras.layers.Dense(128,activation='relu'),
#     tf.keras.layers.Dense(10,activation='softmax')
# ])

model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=["sparse_categorical_accuracy"])

model.fit(x_train,y_train,batch_size=32,epochs=5,validation_data=(x_test,y_test),validation_freq=1)

model.summary()

Fashion MNIST数据集

Fashion MNIST数据集
提供6w张 28*28像素的0~9手写衣服裤子等图片和标签,用于训练
提供1w张 28*28像素的0~9手写衣服裤子等图片和标签,用于测试
一共十个分类

导入数据集:
minist=tf.keras.datasets.mnist
(x_train,y_train),(x_test,y_test)=mnist.load_data()
作为输入特征,输入神经网络时,将数据拉伸为一维
tf.keras.layers.Flatten()

Fashion MNIST数据集没法直接代码下载的话可以参考这篇博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/78894.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

java计算机毕业设计基于安卓Android的掌上酒店预订APP

项目介绍 网络的广泛应用给生活带来了十分的便利。所以把掌上酒店预订与现在网络相结合,利用java技术建设掌上酒店预订APP,实现掌上酒店预订的信息化。则对于进一步提高掌上酒店预订发展,丰富掌上酒店预订经验能起到不少的促进作用。 掌上酒店预订APP能够通过互联网得到广泛的…

基于风能转换系统的非线性优化跟踪控制(Matlab代码实现)

💥💥💥💞💞💞欢迎来到本博客❤️❤️❤️💥💥💥🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清…

一、CDD在诊断开发中的作用

本专栏将由浅入深的展开诊断实际开发与测试的数据库编辑,包含大量实际开发过程中的步骤、使用技巧与少量对Autosar标准的解读。希望能对大家有所帮助,与大家共同成长,早日成为一名车载诊断、通信全栈工程师。 本文介绍CDD在诊断开发中的作用,欢迎各位朋友订阅、评论,可以提…

如何评价模型的好坏?

回归: MSE(均方误差)—— 判定方法:值越小越好(真实值-预测值,平方之后求和平均)RMSE(均根方误差)—— 判定方法:值越小越好(MSE开根号&#xff…

Dijkstra最短路径算法

参考:(3条消息) Dijkstra算法图文详解_一叶执念的博客-CSDN博客_迪杰斯特拉算法 如图,假设图中共有n条路径(如D-C-E),根据路径长度进行小到大排序。 1、起点到达某终点的距离是无穷符号,表示该起点还需要借…

27岁到来之际,我在阿里实现了年薪40W+的小目标

顺着大佬的思路,我分析了自己的实际水平和状况: 1、技术不精不成体系:技术能力浮于表面,对底层逻辑和架构不了解,也不知道如何系统化进行学习; 2、遇到职场瓶颈期:站在3年职场的分水岭上,没有…

Linux网络原理及编程(8)——第十八节 数据链路层

目录 1、MAC地址 2、MAC帧 3、MAC帧协议 4、MTU 5、ARP请求和应答 各位好,博主新建了个公众号《自学编程村》,拉到底部即可看到,有情趣可以关注看看哈哈,关注后还可以加博主wx呦~~~(公众号拉到底部就能看到呦~~&a…

基于java+springmvc+mybatis+jsp+mysql的新冠肺炎疫苗接种管理系统

项目介绍 新冠疫苗接种管理系统,在网站首页可以查看首页,疫苗信息,疫苗资讯 ,个人中心,后台管理,在线客服等内容,并进行详细操作。管理员登录进入系统可以查看首页,个人中心&#x…

vue自定义keepalive组件的问题解析

前一阵来了一个新的需求,要在vue项目中实现一个多开tab页面的功能,本来心想,这不简单嘛就是一个增加按钮重定向吗?(当然如果这么简单我就不写这个文章了)。很快写完,提交测试。测试大哥很快就提…

一份奇奇怪怪的地图设计书

地图设计书 地图设计是通过研究实验制定新编地图的内容、表现形式及其生产工艺程序的工作,是地图制图学各种活动的中心,贯穿整个地图制图过程。本设计选择了福建省龙岩市作为研究区域,并结合相应区域的土地利用类型、水系、道路等数据&#…

儿童剧本杀行业是好生意吗?剧本杀门店管理系统

红楼梦、西游记、水浒传、三国演义是中国四大名著,几乎每个中国人上到70岁老人,下到十岁小学生都知道,同时还有花木兰、包青天、八仙过海等故事也都耳濡目染,小说描述的淋漓尽致,影视剧老戏骨们将每个角色刻画的深入人…

spring-aop源码分析(3)完结_执行流程分析

本文详细介绍Spring AOP的执行阶段流程。 Cglib代理的代理拦截逻辑在DynamicAdvisedInterceptor中,JDK代理的拦截逻辑在JdkDynamicAopProxy中,本文将从这两个类入手分析Spring AOP的执行阶段流程。 DynamicAdvisedInterceptor private static class D…

Modbus数据采集方案

目录 目标 Modbus协议简介 配置界面以及实例 概述 modbus协议应该是工业行业应用最广泛的协议,由于其协议简单、通讯标准、扩展性强的特点,被各个行业大量的应用。作为通讯网关机来说,设计一个便捷易懂的配置方式显得尤其重要。本方案基于…

多模态中的指令控制(InstructPix2Pix,SayCan)

InstructPix2Pix: Learning to Follow Image Editing Instructions 图像的语言指令生成。目的是遵循人工指令去编辑图像,即给定输入图像和一个如何编辑它的文本指令,模型尝试遵循这些指令来编辑图像。 这份论文与现有基于文本的图像编辑工作们最大的不同…

【JVM】方法区与永久代、元空间之间的关系

方法区与永久代、元空间之间的关系 方法区是JVM规范中定义的一块内存区域,用来存储类元数据、方法字节码、即时编译器需要的信息等 永久代是Hotspot虚拟机对JVM规范的实现(1.8之前) 元空间是Hotspot虚拟机对JVM规范的实现(1.8以后),使用本地…

java基于springboot高校学报论文在线投稿系统-计算机毕业设计

项目介绍 在新发展的时代,众多的软件被开发出来,给用户带来了很大的选择余地,而且人们越来越追求更个性的需求。在这种时代背景下,高校只能以工作人员为导向,以稿件的持续创新作为高校最重要的竞争手段。 系统采用了J…

Java AQS

AQS 是什么 AQS 的全称为 AbstractQueuedSynchronizer,翻译过来的意思就是抽象队列同步器,这个类在 java.util.concurrent.locks 包下面Java 中的大部分同步类(Lock、Semaphore、ReentrantLock等) 都是基于 AQS 实现的AQS 是一种提供了原子式管理同步状…

SpringBoot - 整合WebSocket时@ServerEndpoint修饰的类属性注入为null问题

SpringBoot - 整合WebSocket时ServerEndpoint修饰的类属性注入为null问题前言一. 问题复现1.1 原因分析二. 问题解决前言 最近在做一个直播弹幕系统,前期准备先用WebSocket来试试水。我们都知道,使用WebSocket只需要给对应的类加上注解ServerEndpoint即…

Linux之定时任务--crontab命令解析学习

Corntab定时任务学习 一、crond服务 在学习crontab,命令之前,我觉得有必要学习了解一下crond服务,因为要在linux系统下使用crontab命令需要crond的支持。Crond是Linux下要用来周期执行某种任务或者等待处理某些事件的一个守护进程。和Windo…

项目——员工管理系统

开发环境:vmware ubuntu18.04 实现功能:基本功能包括管理者和普通员工用户的登录,管理者拥有操作所有员工信息的最高权限,可以进行增删改 查等操作,普通用户仅拥有查看、修改个人部分信息的权限 具体功能详解&…