《Keras深度学习:入门、实战与进阶》CIFAR-10图像识别

news2024/11/16 19:49:45

本文摘自《Keras深度学习:入门、实战与进阶》。
https://item.jd.com/10038325202263.html
在这里插入图片描述
这个数据集由Alex Krizhevsky、Vinod Nair和Geoffrey Hinton收集整理,共包含了60000张32×32的彩色图像,50000张用于训练模型、10000张用于评估模型。可以从其主页(http://www.cs.toronto.edu/~kriz/cifar.html)下载。共有10个类别,它们是:飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船、卡车。每个分类有6000个图像。
在这里插入图片描述

1、加载CIFAR-10数据

Keras提供了dataset_cifar10()函数用于下载或读取CIFAR-10数据。第一次运行dataset_cifar10()时,程序会检查是否有cifar-10-batches-py.tar.gz文件,如果还没有,就会下载文件,并且解压下载的文件。第一次运行因为需要下载文件,所以运行时间可能会比较长,之后就可以直接从本地加载数据,用于神经网络模型的训练。
如果是Windows环境,文件将存放在C:\Users\用户名\Documents.keras\datasets中。我们来查看解压后的cifar-10-batches-py目录下的内容。

# 查看cifar-10目录下的文件
> file <- 'C:/Users/Daniel/Documents/.keras/datasets/cifar-10-batches-py'
> list.files(file) # 查看目录下文件
[1] "batches.meta" "data_batch_1" "data_batch_2" "data_batch_3" "data_batch_4"
[6] "data_batch_5" "readme.html"  "test_batch"

CIFAR-10数据集分为训练集和测试集两部分。训练集构成了5个训练批次(data_batch_1、data_batch_2、data_batch_3、data_batch_4、data_batch_5),每一批次10000张图。另外用于测试的10000张图单独构成一批(test_batch)。注意一个训练批次中的各类图像数量并不一定相同,总的训练样本包含来自每一类的5000张图。数据导入时,会直接被分割成训练集和测试集两部分,训练和测试数据又由图像数据和标签所组成。

> library(keras)
> c(c(x_train,y_train),c(x_test,y_test)) %<-% dataset_cifar10()
> # 查看数据维度
> dim(x_train);dim(x_test)
[1] 50000    32    32     3
[1] 10000    32    32     3
> dim(y_train);dim(y_test)
[1] 50000     1
[1] 10000     1

train训练数据集有50000项,test测试数据集10000项。x_train和x_test是四维数组,第一维是样本数,第二、三维是指图像大小为32×32,第四维是RGB三原色,所以是3。y_train和y_test是矩阵(二维数组),第一维是样本数,第二维是图像数据的实际真实值。每一个数字代表一种图像类别的名称:0:飞机(airplane)airplane、1:汽车(automobile)automobile、2:鸟(bird)bird、3:猫(cat)、43:鹿(deer)deer、4:dog、5:狗(dog)、6:青蛙(frog)frog、7:马(horse)horse、8:船(ship)ship、9:卡车(truck)truck。
运行以下程序代码,绘制train数据集中前10张图像

> # 绘制前10张图像
> label_dict <- data.frame('label' = 0:9,
+                          'name' = c("airplane","automobile","bird","cat","deer",
+                                     "dog","frog","horse","ship","truck"))
> 
> par(mfrow=c(2,5))
> for(i in 1:10){
+   plot(as.raster(x_train[i,,,],max=255))
+   title(main = paste0(i-1,",",
+                       label_dict[label_dict$label==y_train[i],2]))
+ }
> par(mfrow=c(1,1))

在这里插入图片描述

2、CIFAR-10数据预处理

为了将数据送入卷积神经网络模型进行训练与预测,必须进行数据的预处理。前面的维度分析可知,x_train和x_test的图像数据已经是四维数组,符合卷积神经网络模型的维度要求。

> x_train <- x_train / 255
> x_test <- x_test / 255
> min(x_train);max(x_train)
[1] 0
[1] 1
> min(x_test);max(x_test)
[1] 0
[1] 1

对于CIFAR-10数据集,我们希望预测图像的类型,例如“船”图像的label是8,经过独热编码(One-Hot Encoding)转换为0000000010,10个数字正好对应输出层10个神经元。可以利用to_categorical()函数进行转换。

> y_train_onehot <- to_categorical(y_train,num_classes = 10)
> y_test_onehot <- to_categorical(y_test, num_classes = 10)
> dim(y_train_onehot)
[1] 50000    10
> dim(y_test_onehot)
[1] 10000    10

3、构建简单卷积神经网络识别CIFAR-10图像

首先构建一个简单的卷积神经网络,来验证卷积神经网络在这个数据集上的性能,并以此为基础对网络进行优化,逐步提高模型的准确度。
这个简单的卷积神经网络具有两个卷积层、一个最大值池化层、一个Flatten层和一个全连接层,网络拓扑结构如下:
卷积层,具有32个特征图,卷积核大小为3×3,激活函数为Relu。
Dropout概率为20%的Dropout层。
卷积层,具有32个特征图,卷积核大小为3×3,激活函数为Relu。
Dropout概率为20%的Dropout层。
采样因子(pool_size)为2×2的最大值池化层。
Flatten层。
具有512个神经元和ReLU激活函数的全连接层。
Dropout概率为50%的Dropout层。
具有10个神经元的输出层,激活函数为softmax。
编译模型时,采用RMSProp优化器,categorical_crossentropy作为损失函数,同时采用准确率(accuracy)来评估模型的性能。
构建模型build_simple_cnn ()程序代码如下。

> build_simple_cnn <- function(X=trainx) {
+   model <- keras_model_sequential() %>%
+     layer_conv_2d(filters = 32, 
+                   kernel_size = c(3,3),
+                   activation = 'relu',
+                   input_shape = dim(X)[-1]) %>%
+     layer_dropout(rate = 0.2) %>%
+     layer_conv_2d(filters = 32, 
+                   kernel_size = c(3,3),
+                   activation = 'relu') %>%
+     layer_dropout(rate = 0.2) %>%
+     layer_max_pooling_2d(pool_size = c(2,2)) %>%
+     layer_flatten() %>%
+     layer_dense(units = 512, activation = 'relu') %>% 
+     layer_dropout(rate = 0.5) %>%
+     layer_dense(units = 10, activation = 'softmax')
+   # Compile
+   model %>% compile(
+     loss = 'categorical_crossentropy',
+     optimizer = optimizer_rmsprop(),
+     metrics = 'accuracy')
+   model
+ }

模型构建后,使用fit()函数进行模型训练。将训练周期参数epochs设置为25,batch_size参数为256,validation_split参数为0.2,说明从训练样本中抽取20%作为验证集。`

> simple_cnn_model <- build_simple_cnn(x_train)
> history <- simple_cnn_model %>%
+   fit(x_train,
+       y_train_onehot,
+       epochs = 25,
+       batch_size = 256,
+       validation_split = 0.2)
> plot(history)

在这里插入图片描述
经过30个训练周期后,训练集的准确率为93%,验证集的准确率为70%,出现过拟合现象。可使用当监测值不再改善时将终止训练的callback_early_stopping()回调函数来监控模型,防止出现过拟合现象。
利用训练好的简单卷积神经网络模型对测试进行预测,并查看混淆矩阵。

> pred <- simple_cnn_model %>% predict_classes(x_test)
> t <- table(Actual = y_test,Predicted = pred)
> t
      Predicted
Actual   0   1   2   3   4   5   6   7   8   9
     0 788  24  28  18  27   1  12  10  55  37
     1  23 817   5  13   3   2   8   4  27  98
     2 100  11 470  85 117  68  61  47  22  19
     3  37  18  46 525 102 137  42  38  22  33
     4  34   4  33  63 691  33  47  69  13  13
     5  23   9  38 226  63 535  18  57  14  17
     6  14  12  28  73  57  25 755   8  11  17
     7  27   4  25  43  78  40   4 739   8  32
     8  77  46   7  19   5   2   4   6 795  39
     9  44  98   7  14   4   1   6  13  24 789

模型对汽车(1:automobile)的预测能力最好,有817个样本被正确预测,准确率超过81%;其次是船(8:ship),有795个样本被正确预测。
最后,让我们绘制实际是鸟,但预测错误的50张图像

> ind <- which(as.vector(y_test)==2 & pred != 2) # 提取实际为2,但预测不为2的下标集
> # 绘制预测错误的图像
> par(mfrow=c(5,10)) 
> for(i in 1:50){
+   plot(as.raster(x_test[ind[i],,,]))
+   title(main = paste0(label_dict[label_dict$label==y_test[ind[i]],2],">>",
+                   label_dict[label_dict$label==pred[ind[i]],2]))
+ 
+ }
> par(mfrow=c(1,1))

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/339746.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JUC并发编程学习笔记(一)——知识补充(Threadlocal和引用类型)

强引用、弱引用、软引用、虚引用 Java执行 GC(垃圾回收)判断对象是否存活有两种方式&#xff0c;分别是引用计数法和引用链法(可达性分析法)。 **引用计数&#xff1a;**Java堆中给每个对象都有一个引用计数器&#xff0c;每当某个对象在其它地方被引用时&#xff0c;该对象的…

文献阅读:Scaling Instruction-Finetuned Language Models

文献阅读&#xff1a;Scaling Instruction-Finetuned Language Models 1. 文章简介2. 实验 1. 数据集 & 模型 1. 数据集考察2. 使用模型 2. scale up对模型效果的影响3. CoT对模型效果的影响4. 不同模型下Flan的影响5. 开放接口人工标注指标 3. 结论 文献链接&#xff1a;…

【C++】类和对象(一)

目录一、面向过程和面向对象初步认识二、类的引入三、类的定义四、类的访问限定符及封装4.1、访问限定符4.2、封装五、类的作用域六、类的实例化七、类对象的大小八、this指针8.1、this指针的引出8.2、this指针的特性8.3、C语言和C实现Stack的对比一、面向过程和面向对象初步认…

XSS漏洞,通过XSS实现网页挂马

**今天讲下通过XSS实现网页挂马~*&#xff0c;目的是了解安全方面知识&#xff0c;提升生活网络中辨别度 原理&#xff1a; 实验分为两部分&#xff1a; 1、通过Kali linux&#xff0c;利用MS14_064漏洞&#xff0c;制作一个木马服务器。存在该漏洞的用户一旦通过浏览器访问木…

C语言(C文件处理函数和文件指针)

C语言有很多文件操作函数&#xff0c;这里我们挑了一些重要的开始讲&#xff0c;首先说下这些函数都定义在stdio.h头文件中 目录 一.文件指针 二.文件处理函数 1.fopen&#xff08;打开文件&#xff09; 2.fclose(关闭文件) 3.getc和putc(从文件指针读取字符) 4.I/O工作…

「C++控制台生存游戏」暗黑体素 DarkVoxel 控制台版

“《只有作者能看懂的一款游戏》” 刚进高中前开始写的一款抽象的生存游戏 BUG很多请见谅 ###【点击此处&#xff0c;免费畅玩】### 类似泰拉瑞亚的一款游戏 『暗黑体素 DarkVoxel』 直接上图&#xff01; 用控制台写出如此奇葩的生存游戏&#xff0c;可谓世间少有。 操作…

2022黑马Redis跟学笔记.实战篇(二)

2022黑马Redis跟学笔记.实战篇 二实战篇Redis开篇导读4.1短信登录4.1.1. 搭建黑马点评项目一、导入黑马点评项目二、导入SQL三、有关当前模型四、导入后端项目相关依赖配置redis和mysql连接项目组成概述关闭Linux防火墙五、导入前端工程六、 运行前端项目4.1.2. 基于Session实现…

选购交换机的参数依据和主要的参数指标详解

如何选购交换机&#xff1f;用什么交换机&#xff1f;在选购交换机时交换机的优劣无疑十分的重要&#xff0c;而交换机的优劣要从总体构架、性能和功能三方面入手。交换机选购时。性能方面除了要满足RFC2544建议的基本标准&#xff0c;即吞吐量、时延、丢包率外&#xff0c;随着…

网络是怎么连接笔记(一)WEB浏览器

文章目录介绍生成HTTP请求消息向DNS服务器查询WEB服务的IP地址全世界DNS服务器的大接力委托协议栈发送消息介绍 互联网整个消息传递流程 生成HTTP请求消息向DNS服务器查询WEB服务的IP地址然后DNS服务器进行查询IP地址委托协议给对应IP发送消息 生成HTTP请求消息 整个网络发…

Spring面试重点(三)——AOP循环依赖

Spring面试重点 AOP 前置通知&#xff08;Before&#xff09;&#xff1a;在⽬标⽅法运行之前运行&#xff1b;后置通知&#xff08;After&#xff09;&#xff1a;在⽬标⽅法运行结束之后运行&#xff1b;返回通知&#xff08;AfterReturning&#xff09;&#xff1a;在⽬标…

2023年前端面试知识点总结(CSS篇)

近期整理了一下高频的前端面试题&#xff0c;分享给大家一起来学习。如有问题&#xff0c;欢迎指正&#xff01; 1. 对CSS盒模型的理解 CSS3的盒模型有两种盒子模型&#xff1a;标准盒子模型、IE盒子模型 盒模型都是由四个部分组成的&#xff0c;分别是content&#xff08;内容…

layui框架学习(6:基础菜单)

菜单是应用系统的必备元素&#xff0c;虽然网页中的导航也能作为菜单使用&#xff0c;但菜单和导航的样式和用途有所不同&#xff08;不同之处详见参考文献5&#xff09;。Layui中用不同的预设类定义菜单和导航的样式&#xff0c;同时二者依赖的模块也不一样。本文主要学习和记…

Vue (3)

文章目录1. 数据代理1.1 回顾1.2 开始2. 事件处理2.1 v-on:click 点击事件2.2 事件修饰符2.3 键盘事件3. 计算属性3.1 插值语法实现3.2 methods实现3.3 计算属性实现4. 监视属性4.1 深度监视4.2 监视属性的简写形式4.3 watch 与 computed 对比1. 数据代理 在学习 数据代理 时 先…

SQL数据查询——单表查询和排序

文章目录一、单表查询1.查询列1&#xff09;查询全部列指定列2&#xff09;查询经过计算的值3&#xff09;列的别名2.查询元组1&#xff09;消除取值重复的行(DISTINCT)2&#xff09;条件查询(WHERE)3.空值参与运算4.着重号二、排序(ORDER BY子句)一、单表查询 单表查询指仅涉及…

Webpack的知识要点

在前端开发中&#xff0c;一般情况下都使用 npm 和 webpack。   npm是一个非常流行的包管理工具&#xff0c;帮助开发者管理项目中使用的依赖库和工具。它可以方便地为项目安装第三方库&#xff0c;并在项目开发过程中进行版本控制。   webpack是一个模块打包工具&#xff…

C语言深度剖析之程序环境和预处理

1.程序的翻译环境和执行环境 第一种是翻译环境&#xff0c;在这个环境中源代码被转换为可执行的机器指令 第二种是执行环境&#xff0c;它用于实际执行代码 2.翻译环境 分为四个阶段 预编译阶段 &#xff0c;编译&#xff0c;汇编&#xff0c;链接 程序编译过程&#xff1a;多个…

使用vue3,vite,less,flask,python从零开始学习硅谷外卖(16-40集)

严正声明&#xff01; 重要的事情说一遍&#xff0c;本文章仅供分享&#xff0c;文章和代码都是开源的&#xff0c;严禁以此牟利&#xff0c;严禁侵犯尚硅谷原作视频的任何权益&#xff0c;我知道学习编程的人各种各样的心思都有&#xff0c;但这不是你对开源社区侵权的理由&am…

iptables防火墙之SNAT与DNAT

目录 1、SNAT策略概述 1.SNAT策略的典型应用环境 2.SNAT策略的原理 3.SNAT工作原理 4.SNAT转换前提条件 5.开启SNAT命令 6.SNAT转换 2.SNAT示例 1. 配置网关服务器 2.Xshell 连接192.168.100.100 3.DNAT策略及应用 1. DNAT策略概述 2.DNAT 策略的应用 3.DNAT转换前提条件…

看完这篇 教你玩转渗透测试靶机vulnhub——Hack Me Please: 1

Vulnhub靶机Hack Me Please: 1渗透测试详解Vulnhub靶机介绍&#xff1a;Vulnhub靶机下载&#xff1a;Vulnhub靶机安装&#xff1a;Vulnhub靶机漏洞详解&#xff1a;①&#xff1a;信息收集&#xff1a;②&#xff1a;漏洞利用③&#xff1a;获取反弹shell&#xff1a;④&#x…

how https works?https工作原理

简单一句话&#xff1a; https http TLShttps 工作原理&#xff1a;HTTPS (Hypertext Transfer Protocol Secure)是一种带有安全性的通信协议&#xff0c;用于在互联网上传输信息。它通过使用加密来保护数据的隐私和完整性。下面是 HTTPS 的工作原理&#xff1a;初始化安全会…