18- TensorFlow实现CIFAR10分类 (tensorflow系列) (项目十八)

news2024/9/23 22:31:23

项目要点

  • 导入cifar图片集: (train_image, train_label), (test_image, test_label) = cifar.load_data()    # cifar = keras.datasets.cifar10
  • 图片归一化处理: train_image = train_image / 255
  • 定义模型: model = keras.Sequential()
    • 输入层: model.add(layers.Conv2D(64, (3, 3), activation='relu', input_shape=(32, 32, 3))) 
    • 添加卷积层: model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    • Conv2D(filters, kernel_size, strides, padding, activation=‘relu’, input_shape)
      • filters: 过滤器数量
      • kernel_size: 指定(方形)卷积窗口的高和宽的数字
      • strides: 卷积步长, 默认为 1
      • padding: 卷积如何处理边缘。选项包括 ‘valid’ 和 ‘same’。默认为 ‘valid’
      • activation: 激活函数,通常设为 relu。如果未指定任何值,则不应用任何激活函数。强烈建议你向网络中的每个卷积层添加一个 ReLU 激活函数。
      • input_shape: 指定输入层的高度,宽度和深度的元组。当卷积层作为模型第一层时,必须提供此参数,否则不需要。
  • 添加BN层: model.add(layers.BatchNormalization())       # 1.最重要的作用是加快网络的训练收敛的速度; 2.控制梯度爆炸防止梯度消失; 3.防止过拟合
  • 添加池化层: model.add(layers.MaxPooling2D())     # 可以加快计算速度和防止过拟合作用
  • 添加dropout: model.add(layers.Dropout(0.25))      # 防止神经网络过拟合的手段。随机的拿掉网络中的部分神经元
  • 添加全局平均池化: model.add(layers.GlobalAveragePooling2D())  # 替代全连接层减少参数数量,减少计算量,减少过拟合
  • 添加输出层: model.add(layers.Dense(10, activation='softmax'))
  • 查看模型结构: model.summary()
  • 配置模型:
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['acc'])
  • 训练模型: history = model.fit(train_image, train_label, epochs=30, batch_size=128)
  • 模型评估: model.evaluate(test_image, test_label)


1 数据集简介

CIFAR-10 是由 Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理的一个用于识别普适物体的小型数据集。一共包含 10 个类别的 RGB 彩色图 片:飞机( airplane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。图片的尺寸为 32×32 ,数据集中一共有 50000 张训练图片和 10000 张测试图片。 CIFAR-10 的图片样例如图所示。

  • 下面这幅图就是列举了10各类,每一类展示了随机的10张图片:

 与 MNIST 数据集中目比, CIFAR-10 具有以下不同点:

  • CIFAR-103 通道的彩色 RGB 图像,而 MNIST 是灰度图像
  • CIFAR-10 的图片尺寸为 32×32, 而 MNIST 的图片尺寸为 28×28,比 MNIST 稍大。
  • 相比于手写字符, CIFAR-10 含有的是现实世界中真实的物体,不仅噪声很大,而且物体的比例、 特征都不尽相同,这为识别带来很大困难。 直接的线性模型如 Softmax 在 CIFAR-10 上表现得很差

2 导包

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Dense
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow as tf

cpu=tf.config.list_physical_devices("CPU")
tf.config.set_visible_devices(cpu)
print(tf.config.list_logical_devices())

3 数据导入

cifar = keras.datasets.cifar10
(train_image, train_label), (test_image, test_label) = cifar.load_data()
train_image.shape, test_image.shape   # ((50000, 32, 32, 3),(10000, 32, 32, 3))

# 归一化处理
train_image = train_image / 255
test_image = test_image / 255

4 建立模型

Conv2D 构建卷积层。用于从输入的高维数组中提取特征。卷积层的每个过滤器就是一个特征映射,用于提取某一个特征,过滤器的数量决定了卷积层输出特征个数,或者输出深度。因此,图片等高维数据每经过一个卷积层,深度都会增加,并且等于过滤器的数量。 

Conv2D(filters, kernel_size, strides, padding, activation=‘relu’, input_shape)

  • filters: 过滤器数量
  • kernel_size: 指定(方形)卷积窗口的高和宽的数字
  • strides: 卷积步长, 默认为 1
  • padding: 卷积如何处理边缘。选项包括 ‘valid’ 和 ‘same’。默认为 ‘valid’
  •  activation: 激活函数,通常设为 relu。如果未指定任何值,则不应用任何激活函数。强烈建议你向网络中的每个卷积层添加一个 ReLU 激活函数。
  • input_shape: 指定输入层的高度,宽度和深度的元组。当卷积层作为模型第一层时,必须提供此参数,否则不需要。

GlobalAveragePooling2D()作用:

全局平均池化 作用:如果要预测K个类别,在卷积特征抽取部分的最后一层卷积层,就会生成K个特征图,然后通过全局平均池化就可以得到 K个1×1的特征图,将这些1×1的特征图输入到softmax layer之后,每一个输出结果代表着这K个类别的概率(或置信度 confidence),起到取代全连接层的效果。
优点:

  • 和全连接层相比,使用全局平均池化技术,对于建立特征图和类别之间的关系,是一种更朴素的卷积结构选择。 # 替换全连接层
  • 全局平均池化层不需要参数,避免在该层产生过拟合
  • 全局平均池化对空间信息进行求和,对输入的空间变化的鲁棒性更强
# 定义神经网络
model = keras.Sequential()
model.add(layers.Conv2D(64, (3, 3), activation='relu', input_shape=(32, 32, 3)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.BatchNormalization())
model.add(layers.MaxPooling2D())
model.add(layers.Dropout(0.25))
model.add(layers.Conv2D(128, (3, 3), activation='relu'))
model.add(layers.Conv2D(128, (3, 3), activation='relu'))
model.add(layers.BatchNormalization())
model.add(layers.MaxPooling2D())
model.add(layers.Dropout(0.25))
model.add(layers.Conv2D(256, (3, 3), activation='relu'))
model.add(layers.Conv2D(256, (1, 1), activation='relu'))
model.add(layers.BatchNormalization())
model.add(layers.Dropout(0.25))          
model.add(layers.GlobalAveragePooling2D())
model.add(layers.Dense(128))
model.add(layers.BatchNormalization())  
model.add(layers.Dropout(0.5))         
model.add(layers.Dense(10, activation='softmax'))
 
model.summary()

5 模型训练

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['acc'])
 
history = model.fit(train_image, train_label, epochs=30, batch_size=128)

 

model.evaluate(test_image, test_label)  # [0.7276686429977417, 0.7903000116348267]

  •  存在一定的过拟合

6 一个简化的模型

  • 注意输出层前需要添加全连接层, 不然报错.  
# 定义神经网络
model = keras.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation = 'relu', input_shape = (32, 32, 3)))
model.add(layers.GlobalAveragePooling2D())        
model.add(layers.Dense(10, activation='softmax'))
# 配置网络
model.compile(optimizer = 'adam',
              loss = 'sparse_categorical_crossentropy',
              metrics = ['acc'])
# 训练模型
histroy = model.fit(train_image, train_label, epochs=30, batch_size=128)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/377041.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

HTML基础语法

一 前端简介构成语言说明结构HTML页面元素和内容表现CSS网页元素的外观和位置等页面样式(美化)行为JavaScript网页模型的定义和页面交互二 HTML1.简介HTML(Hyper Text Markup Language):超文本标记语言。网页结构整体&…

Kubernetes05: Pod

Kubernetes05: Pod 1、概述 1)最小部署的单元 2)K8s不会处理容器,而是Pod,Pod里边包含多个容器(一组容器的集合) 3)一个Pod中的容器共享一个网络命名空间 4) Pod是短暂存在的东西(重…

使用shiroshiro整合其他组件

什么是shiro? 一款apache公司出品的Java安全框架,主要用于设计针对应用程序的保护,使用shiro可以完成认证、授权、加密、会话管理等。保证系统稳定性、数据安全性 优势:易于使用、易于理解、兼容性强(可以与其他框架集…

SE-SSD论文阅读

摘要 本文提出了一种基于自集成单级目标检测器(SE-SSD)的室外点云三维目标检测方法。我们的重点是利用我们的公式约束开发软目标和硬目标来联合优化模型,而不引入额外的计算在推理中。具体来说,SE-SSD包含一对teacher 和student ssd,在其中我…

Mac 安装 Java 反编译工具 JD-GUI

Mac 安装 Java 反编译工具 JD-GUI JD-GUI 是一款 Java 反编译工具,可以方便的将编译好的 .class 文件反编译为 .java 源码文件,用于开发调试、源码学习等。 官网地址:http://java-decompiler.github.io Git 地址:https://github…

直播美颜sdk是什么?它是怎么让用户”变美“的?

如今,直播美颜sdk、手机摄影、短视频以及社交软件的盛行,让“拍照”成为人们日常生活中不可或缺的一部分。随着直播美颜sdk技术的不断升级,手机摄影的质量也越来越高。有统计数据显示,2018年中国智能手机用户已经达到了7亿人&…

美国最新调查显示 50% 企业已在用 ChatGPT,其中 48% 已让其代替员工,你怎么看?

美国企业开始使用ChatGPT,我认为这不是什么新闻。 如果美国的企业现在还不使用ChatGPT,那才是个大新闻。 据新闻源显示,已经使用chatGPT的企业中,48%已经让其代替员工工作。 ChatGPT的具体职责包括:客服、代码编写、招…

HTB-remote

HTB-remote信息搜集开机提权信息搜集 nmap 较为感兴趣的端口: 2180nfs 首先尝试21端口,可以看到并没有文件在ftp服务器里面,而且也无法上传文件。 80端口。 在contact里面找到了能够登录的网站。 经过简单的测试发现可能不存在sql注…

逆向、安全、工具集

0、安卓逆向环境 r0env 原味镜像介绍文章:https://mp.weixin.qq.com/s/gBdcaAx8EInRXPUGeJ5ljQ 原味镜像介绍视频:https://www.bilibili.com/video/BV1qQ4y1R7wW/ 百度盘:链接:https://pan.baidu.com/s/1anvG0Ol_qICt8u7q5_eQJw 提取码:3x2a …

【Spring源码】Spring AOP的核心概念

废话版什么是AOP关于什么是AOP,这里还是要简单介绍下AOP,Aspect Oriented Programming,面向切面编程,通过预编译和运行期间提供动态代理的方式实现程序功能的统一维护,使用AOP可以降低各个部分的耦合度,提高…

openfeign负载均衡策略 | Spring Cloud 5

一、Spring Cloud LoadBalancer介绍 Spring Cloud LoadBalancer是Spring Cloud官网提供的一个客户端负载均衡器,功能类似于Ribbon。在Spring Cloud Nacos 2021移除了中Ribbon组件,Spring Cloud在Spring Cloud Commons项目中,添加了Spring Cl…

华为OD机试题,用 Java 解【N 进制减法】问题

最近更新的博客 华为OD机试题,用 Java 解【停车场车辆统计】问题华为OD机试题,用 Java 解【字符串变换最小字符串】问题华为OD机试题,用 Java 解【计算最大乘积】问题华为OD机试题,用 Java 解【DNA 序列】问题华为OD机试 - 组成最大数(Java) | 机试题算法思路 【2023】使…

Linux | 分布式版本控制工具Git【版本管理 + 远程仓库克隆】

文章目录一、前言二、有关git的相关历史介绍三、Git版本管理1、感性理解 —— 大学生实验报告2、程序员与产品经理3、张三的CEO之路 —— 版本管理工具的诞生四、如何在Linux上使用Git1、创建仓库2、将仓库克隆到本地3、git三板斧① git add② git commit③ git push4、有关git…

yarn run serve报错Error: Cannot find module ‘@vue/cli-plugin-babel‘ 的解决办法

问题概述 关于这个问题,是在构建前端工程的时候遇到的,项目构建完成后,“yarn run serve”启动项目时,出现的问题:“ Error: Cannot find module ‘vue/cli-plugin-babel‘ ” 如下图: 具体信息如下&…

(24秋招笔试准备)回溯专题--代码随想录刷题记录

回溯算法理论基础回溯三部曲:编辑切换为居中添加图片注释,不超过 140 字(可选)组合问题https://mp.weixin.qq.com/s/OnBjbLzuipWz_u4QfmgcqQ组合总和https://mp.weixin.qq.com/s/HX7WW6ixbFZJASkRnCTC3whttps://mp.weixin.qq.com/…

Linux系统认知——驱动认知

文章目录一、驱动相关概念1.什么是驱动2.被驱动设备分类3.设备文件的主设备号和次设备号4.设备驱动整体调用过程二、基于框架编写驱动代码1.驱动代码框架2.驱动代码的编译和测试三、树莓派I/O口驱动的编写1.微机的总线地址、物理地址、虚拟地址介绍2.通过树莓派芯片手册确定需要…

zabbix部署

文章目录前言一、zabbix简介二、zabbix下载与部署三、部署完成、访问前端测试前言 一、zabbix简介 Zabbix 是一个企业级分布式开源监控解决方案。Zabbix 软件能够监控众多网络参数和服务器的健康度、完整性。Zabbix 使用灵活的告警机制,允许用户为几乎任何事件配置…

数据结构与算法——4时间复杂度分析(常见的大O阶)

这篇文章是时间复杂度分析的第二篇。在前一篇文章中,我们从0推导出了为什么要用时间复杂度,时间复杂度如何分析以及时间复杂度的表示三部分内容。这篇文章,是对一些常用的时间复杂度进行一个总结,相当于是一个小结论 1.常见的大O…

【LeetCode】剑指 Offer(11)

目录 题目:剑指 Offer 29. 顺时针打印矩阵 - 力扣(Leetcode) 题目的接口: 解题思路: 代码: 过啦!!! 写在最后: 题目:剑指 Offer 29. 顺时针…

西电计算机通信与网络(计网)简答题计算题核心考点汇总(期末真题+核心考点)

文章目录前言一、简答计算题真题概览二、网桥,交换机和路由器三、ARQ协议四、曼彻斯特编码和差分曼彻斯特编码五、CRC六、ARP协议七、LAN相关协议计算前言 主要针对西安电子科技大学《计算机通信与网络》的核心考点进行汇总,包含总共26章的核心简答。 【…