CIFAR-10数据集详析:使用卷积神经网络训练图像分类模型

news2024/11/24 8:41:31

1.数据集介绍

CIFAR-10 数据集由 10 个类的 60000 张 32x32 彩色图像组成,每类 6000 张图像。有 50000 张训练图像和 10000 张测试图像。
数据集分为5个训练批次和1个测试批次,每个批次有10000张图像。测试批次正好包含从每个类中随机选择的 1000 张图像。训练批次以随机顺序包含剩余的图像,但某些训练批次可能包含来自一个类的图像多于另一个类的图像。在它们之间,训练批次正好包含来自每个类的 5000 张图像。

总结:

Size(大小): 32×32 RGB图像 ,数据集本身是 BGR 通道
Num(数量): 训练集 50000 和 测试集 10000,一共60000张图片
Classes(十种类别): plane(飞机), car(汽车),bird(鸟),cat(猫),deer(鹿),dog(狗),frog(蛙类),horse(马),ship(船),truck(卡车)
在这里插入图片描述

下载链接

来自博主(Dream是个帅哥)的分享:
链接: https://pan.baidu.com/s/1gKazlkk108V_1nrc68VoSQ 提取码: 0213

数据集文件夹

在这里插入图片描述

CIFAR-100数据集(拓展)

这个数据集与CIFAR-10类似,只不过它有100个类,每个类包含600个图像。每个类有500个训练图像和100个测试图像。CIFAR-100中的100个子类被分为20个大类。每个图像都有一个“fine”标签(它所属的子类)和一个“coarse”标签(它所属的大类)。

CIFAR-10数据集与MNIST数据集对比

  • 维度不同:CIFAR-10数据集有4个维度,MNIST数据集有3个维度(CIRAR-10的四维: 一次的样本数量, 图片高, 图片宽, 图通道数 -> N H W C;MNIST的三维: 一次的样本数量, 图片高, 图片宽 -> N H W)
  • 图像类型不同:CIFAR-10数据集是RGB图像(有三个通道),MNIST数据集是灰度图像,这也是为什么CIFAR-10数据集比MNIST数据集多出一个维度的原因。
  • 图像内容不同:CIFAR-10数据集展示的是各种不同的物体(猫、狗、飞机、汽车…),MNIST数据集展示的是不同人的手写0~9数字。

2.数据集读取

读取数据集

选取data_batch_1可视化其中一张图:

def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict
dict = unpickle('D:\PycharmProjects\model-fuxian\CIFAR\cifar-10-batches-py\data_batch_1')
print(dict)

输出结果:
一批次的数据集中有4个字典键,我们需要用到的就是 数据标签 和 数据内容(10000×32×32×3,10000张32×32大小为rgb三通道的图片)
在这里插入图片描述
输出的是一个字典:

{
b’batch_label’: b’training batch 1 of 5’,
b’labels’: [6, 9 … 1,5],
b’data’: array([[ 59, 43, …, 84, 72],…[ 62, 61, 60, …, 130, 130, 131]], dtype=uint8),
b’filenames’: [b’leptodactylus_pentadactylus_s_000004.png’,…b’cur_s_000170.png’]

}

其中,各个代表的意思如下:
b’batch_label’ : 所属文件集
b’labels’ : 图片标签
b’data’ :图片数据
b’filename’ :图片名称

读取类型

print(type(dict[b'batch_label']))
print(type(dict[b'labels']))
print(type(dict[b'data']))
print(type(dict[b'filenames']))

输出结果:

<class ‘bytes’>
<class ‘list’>
<class ‘numpy.ndarray’>
<class ‘list’>

读取图片

img = dict[b'data']
print(img.shape)

输出结果:(10000, 3072),其中 3072 = 32 * 32 * 3 (图片 size)

3.数据集调用

TensorFlow 调用

from tensorflow.keras.datasets import cifar10

(x_train,y_train), (x_test, y_test) = cifar10.load_data()

本地调用

def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict
dict = unpickle('D:\PycharmProjects\model-fuxian\CIFAR\cifar-10-batches-py\data_batch_1')

4.卷积神经网络训练

此处参考:传送门

1.指定GPU

gpus = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(gpus[0],True)
#初始化
plt.rcParams['font.sans-serif'] = ['SimHei']

2.加载数据

cifar10 = tf.keras.datasets.cifar10
(train_x,train_y),(test_x,test_y) = cifar10.load_data()
print('\n train_x:%s, train_y:%s, test_x:%s, test_y:%s'%(train_x.shape,train_y.shape,test_x.shape,test_y.shape))

3.数据预处理

X_train,X_test = tf.cast(train_x/255.0,tf.float32),tf.cast(test_x/255.0,tf.float32)     #归一化
y_train,y_test = tf.cast(train_y,tf.int16),tf.cast(test_y,tf.int16)

4.建立模型

adam算法参数采用keras默认的公开参数,损失函数采用稀疏交叉熵损失函数,准确率采用稀疏分类准确率函数

model = tf.keras.Sequential()
##特征提取阶段
#第一层
model.add(tf.keras.layers.Conv2D(16,kernel_size=(3,3),padding='same',activation=tf.nn.relu,data_format='channels_last',input_shape=X_train.shape[1:]))  #卷积层,16个卷积核,大小(3,3),保持原图像大小,relu激活函数,输入形状(28,28,1)
model.add(tf.keras.layers.Conv2D(16,kernel_size=(3,3),padding='same',activation=tf.nn.relu))
model.add(tf.keras.layers.MaxPool2D(pool_size=(2,2)))   #池化层,最大值池化,卷积核(2,2)
#第二层
model.add(tf.keras.layers.Conv2D(32,kernel_size=(3,3),padding='same',activation=tf.nn.relu))
model.add(tf.keras.layers.Conv2D(32,kernel_size=(3,3),padding='same',activation=tf.nn.relu))
model.add(tf.keras.layers.MaxPool2D(pool_size=(2,2)))
##分类识别阶段
#第三层
model.add(tf.keras.layers.Flatten())    #改变输入形状
#第四层
model.add(tf.keras.layers.Dense(128,activation='relu'))     #全连接网络层,128个神经元,relu激活函数
model.add(tf.keras.layers.Dense(10,activation='softmax'))   #输出层,10个节点
print(model.summary())      #查看网络结构和参数信息

#配置模型训练方法
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['sparse_categorical_accuracy'])

5.训练模型

批量训练大小为64,迭代5次,测试集比例0.2(48000条训练集数据,12000条测试集数据)

history = model.fit(X_train,y_train,batch_size=64,epochs=5,validation_split=0.2)

6.评估模型

model.evaluate(X_test,y_test,verbose=2)     #每次迭代输出一条记录,来评价该模型是否有比较好的泛化能力

#保存整个模型
model.save('CIFAR10_CNN_weights.h5')

7.结果可视化

print(history.history)
loss = history.history['loss']          #训练集损失
val_loss = history.history['val_loss']  #测试集损失
acc = history.history['sparse_categorical_accuracy']            #训练集准确率
val_acc = history.history['val_sparse_categorical_accuracy']    #测试集准确率

plt.figure(figsize=(10,3))

plt.subplot(121)
plt.plot(loss,color='b',label='train')
plt.plot(val_loss,color='r',label='test')
plt.ylabel('loss')
plt.legend()

plt.subplot(122)
plt.plot(acc,color='b',label='train')
plt.plot(val_acc,color='r',label='test')
plt.ylabel('Accuracy')
plt.legend()

8.使用模型

plt.figure()
for i in range(10):
    num = np.random.randint(1,10000)
    plt.subplot(2,5,i+1)
    plt.axis('off')
    plt.imshow(test_x[num],cmap='gray')
    demo = tf.reshape(X_test[num],(1,32,32,3))
    y_pred = np.argmax(model.predict(demo))
    plt.title('标签值:'+str(test_y[num])+'\n预测值:'+str(y_pred))
plt.show()

输出结果:
在这里插入图片描述
在这里插入图片描述
上面的内容分别是训练样本的损失函数值和准确率、测试样本的损失函数值和准确率,可以看到它每次训练迭代时损失函数和准确率的变化,从最后一次迭代结果上看,测试样本的损失函数值达到0.9123,准确率仅达到0.6839。
这个结果并不是很好,我尝试过增加迭代次数,发现训练样本的损失函数值可以达到0.04,准确率达到0.98;但实际上训练模型却产生了越来越大的泛化误差,这就是训练过度的现象,经过尝试泛化能力最好时是在迭代第5次的状态,故只能选择迭代5次。
在这里插入图片描述

训练好的模型文件——直接用

CIFAR10数据集介绍,并使用卷积神经网络训练图像分类模型——附完整代码训练好的模型文件——直接用:https://download.csdn.net/download/weixin_51390582/88788820

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1415705.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

GitHub 一周热点汇总第7期(2024/01/21-01/27)

GitHub一周热点汇总第7期 (2024/01/21-01/27) &#xff0c;梳理每周热门的GitHub项目&#xff0c;离春节越来越近了&#xff0c;不知道大家都买好回家的票没有&#xff0c;希望大家都能顺利买到票&#xff0c;一起来看看这周的项目吧。 #1 rustdesk 项目名称&#xff1a;rust…

Redis 面试题 | 15.精选Redis高频面试题

&#x1f90d; 前端开发工程师、技术日更博主、已过CET6 &#x1f368; 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 &#x1f560; 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 &#x1f35a; 蓝桥云课签约作者、上架课程《Vue.js 和 E…

make: *** No rule to make target ‘clean‘. Stop.

项目场景&#xff1a; 在Ubuntu下编写makefile文件编译的时候,出现make: *** No rule to make target ‘clean’. Stop. 问题描述 make: *** No rule to make target ‘clean’. Stop. 解决方案&#xff1a; 原本我makefile文件的名字是MakeFile , 把它改为makefile以后完美运…

再学http

HTTP状态码 1xx 信息性状态码 websocket upgrade 2xx 成功状态码 200 服务器已成功处理了请求204(没有响应体)206(范围请求 暂停继续下载) 3xx 重定向状态码 301(永久) &#xff1a;请求的页面已永久跳转到新的url302(临时) &#xff1a;允许各种各样的重定向&#xff0c;一般…

FlashInternImage实战:使用 FlashInternImage实现图像分类任务(二)

文章目录 训练部分导入项目使用的库设置随机因子设置全局参数图像预处理与增强读取数据设置Loss设置模型设置优化器和学习率调整策略设置混合精度&#xff0c;DP多卡&#xff0c;EMA定义训练和验证函数训练函数验证函数调用训练和验证方法 运行以及结果查看测试完整的代码 在上…

利用STM32CubeMX和Keil模拟器,3天入门FreeRTOS(5.3) ——递归锁

前言 &#xff08;1&#xff09;FreeRTOS是我一天过完的&#xff0c;由此回忆并且记录一下。个人认为&#xff0c;如果只是入门&#xff0c;利用STM32CubeMX是一个非常好的选择。学习完本系列课程之后&#xff0c;再去学习网上的一些其他课程也许会简单很多。 &#xff08;2&am…

物联网协议Coap之C#基于Mozi的CoapClient调用解析

目录 前言 一、CoapClient相关类介绍 1、CoapClient类图 2、CoapClient的设计与实现 3、SendMessage解析 二、Client调用分析 1、创建CoapClient对象 2、实际发送请求 3、Server端请求响应 4、控制器寻址 总结 前言 在之前的博客内容中&#xff0c;关于在ASP.Net Co…

循序渐进,学会用pyecharts绘制桑基图

循序渐进&#xff0c;学会用pyecharts绘制桑基图 桑基图介绍 桑基图是比较冷门的可视化图形&#xff0c;知道的人不多&#xff0c;但它的可视化效果很惊艳&#xff0c;以后肯定会有越来越多的人使用&#xff0c;我平时使用桑基图&#xff0c;主要是用其绘制可视化图形做PPT。…

签到业务流程

1.技术选型 Redis主写入查询&#xff0c;Mysql辅助查询&#xff0c;传统签到多数都是直接采用mysql为存储DB,在大数据的情况下数据库的压力较大.查询速率也会随着数据量增大而增加.所以在需求定稿以后查阅了很多签到实现方式,发现用redis做签到会有很大的优势.本功能主要用到r…

JVM系列——基础知识

Java运行区域 程序计数器&#xff08;Program Counter Register&#xff09; 程序计数器是一块较小的内存空间&#xff0c;它可以看作是当前线程所执行的字节码的行号指示器。在Java虚拟机的概念模型里[1]&#xff0c;字节码解释器工作时就是通过改变这个计数器的值来选取下一…

STM正点mini-跑马灯

一.库函数版 1.硬件连接 &#xff27;&#xff30;&#xff29;&#xff2f;的输出方式&#xff1a;推挽输出 &#xff29;&#xff2f;口输出为高电平时&#xff0c;P-MOS置高&#xff0c;输出为&#xff11;&#xff0c;LED对应引脚处为高电平&#xff0c;而二极管正&#…

[Tomcat] [从安装到关闭] MAC部署方式

安装Tomcat 官网下载&#xff1a;Apache Tomcat - Apache Tomcat 9 Software Downloads 配置Tomcat 1、输入cd空格&#xff0c;打开Tomca目录&#xff0c;把bin文件夹直接拖拉到终端 2、授权bin目录下的所有操作&#xff1a;终端输入[sudo chmod 755 *.sh]&#xff0c;回车 …

HCS-华为云Stack-FusionSphere

HCS-华为云Stack-FusionSphere FusionSphere是华为面向多行业客户推出的云操作系统解决方案。 FusionSphere基于开放的OpenStack架构&#xff0c;并针对企业云计算数据中心场景进行设计和优化&#xff0c;提供了强大的虚拟化功能和资源池管理能力、丰富的云基础服务组件和工具…

C++类和对象——深拷贝与浅拷贝详解

目录 1.深拷贝和浅拷贝是什么 2.案例分析 完整代码 1.深拷贝和浅拷贝是什么 看不懂没关系&#xff0c;下面有案例分析 2.案例分析 浅拷贝可能会导致堆区的内存重复释放 一个名为person的类里面有年龄和指向身高的指针这两个成员。 当我们执行到person p2&#xff08;p1&am…

翻译: GPT-4 with Vision 升级 Streamlit 应用程序的 7 种方式一

随着 OpenAI 在多模态方面的最新进展&#xff0c;想象一下将这种能力与视觉理解相结合。 现在&#xff0c;您可以在 Streamlit 应用程序中使用 GPT-4 和 Vision&#xff0c;以&#xff1a; 从草图和静态图像构建 Streamlit 应用程序。帮助你优化应用的用户体验&#xff0c;包…

深度学习-使用Labelimg数据标注

数据标注是计算机视觉和机器学习项目中至关重要的一步&#xff0c;而使用工具进行标注是提高效率的关键。本文介绍了LabelImg&#xff0c;一款常用的开源图像标注工具。用户可以在图像中方便而准确地标注目标区域&#xff0c;为训练机器学习模型提供高质量的标注数据。LabelImg…

【VB测绘程序设计】案例8——IF选择结构练习排序(附源代码)

【VB测绘程序设计】案例6——IF选择结构练习排序(附源代码) 文章目录 前言一、界面显示二、程序说明三、程序代码四、数据演示总结前言 本文主要掌握Val()函数转换,inputBox函数、IF条件句的练习,输入3个数,按大到小排序并打印。 一、界面显示 二、程序说明 利用inpu…

node.js Redis SETNX命令实现分布式锁解决超卖/定时任务重复执行问题

Redis SETNX 特性 当然&#xff0c;让我们通过一个简单的例子&#xff0c;使用 Redis CLI&#xff08;命令行界面&#xff09;来模拟获取锁和释放锁的过程。 在此示例中&#xff0c;我将使用键“lock:tcaccount_[pk]”和“status:tcaccount_[pk]”分别表示锁定键和状态键。 获…

搞定App关键词和评论

从关键词优化的三大基本概念走起&#xff01; 关联性 优化师一般如何选择关联性高的关键词呢&#xff1f; 主要思路如下&#xff1a;品牌词-关联词-竞品词-竞品关键词&#xff0c;优先级从前到后依次降低&#xff0c;通过ASO优化工具筛选出合适的关键词。做ASO有一个好处就是…

vusui css 使用,简单明了 适合后端人员 已解决

vusui-cssopen in new window 免除开发者繁复的手写 CSS 样式&#xff0c;让 WEB 前端开发更简单、灵活、便捷&#xff01;如果喜欢就点个 ★Staropen in new window 吧。 移动设备优先&#xff1a; vusui-css 包含了贯穿于整个库的移动设备优先的样式。浏览器支持&#xff1a…