基于CNN对彩色图像数据集CIFAR-10实现图像分类--keras框架实现

news2025/2/27 5:32:18

项目地址(kaggle):基于CNN对彩色图像数据集CIFAR-10实现图像分类--keras | Kaggle

项目地址(Colab):https://colab.research.google.com/drive/1gjzglPBfQKuhfyT3RlltCLUPgfccT_G9

 导入依赖

在tensorflow-keras-gpu环境中导入下面依赖:

from keras.datasets import cifar10

from keras import regularizers
from keras.callbacks import ModelCheckpoint
from keras.layers import Conv2D, Activation, BatchNormalization, MaxPooling2D, Dropout, Flatten, Dense
from keras.models import Sequential
from keras.preprocessing.image import ImageDataGenerator
from matplotlib import pyplot
from keras import optimizers
import numpy as np

 准备训练数据

本次实验使用的是keras提供的CIFAER-10数据集,这些数据集是经过预处理,基本可以当作神经网络的输入直接使用,其中包含5000张32x32大小的彩色训练图像和超过10个类别的标注,以及10000张测试图像。

打印数据集
Keras提供的CIFAR-10数据集已被划分为训练集和测试集,并打印测试集和训练集的形状。

# download and split the data
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')

print("training data = ", x_train.shape)
print("testing data = ", x_test.shape)

 数据归一化处理

要对图像的像素值进行归一化处理,应将每个像素减去平均值并以所得结果除以标准差。

# Normalize the data to speed up training
mean = np.mean(x_train)
std = np.std(x_train)
x_train = (x_train-mean)/(std+1e-7)
x_test = (x_test-mean)/(std+1e-7)

# let's look at the normalized values of a sample image
x_train[0]

对标签进行one-hot编码 

# one-hot encode the labels in train and test datasets
# we use “to_categorical” function in keras
from keras.utils import to_categorical
num_classes = 10
y_train = to_categorical(y_train,num_classes)
y_test = to_categorical(y_test,num_classes)

# let's display one of the one-hot encoded labels
y_train[0]

构建模型架构 

模型的网络结构配置如下:

(1)之前在一个卷积层后面加一个池化层,而在全新的架构中,将在每两个卷积层后面加一个池化层,这个想法是受到VGGNet的启发

(2)这里的卷积层的dilation_rate设置为3x3,并将池化层的pool_size设置为2x2。

(3)每隔一个卷积层就添加dropout层,舍弃率p的取值为0.2-0.4

(4)在Keras中,L2正则化被添加到卷积层中

# build the model

# number of hidden units variable
# we are declaring this variable here and use it in our CONV layers to make it easier to update from one place
base_hidden_units = 32

# l2 regularization hyperparameter
weight_decay = 1e-4

# instantiate an empty sequential model
model = Sequential()

# CONV1
# notice that we defined the input_shape here because this is the first CONV layer.
# we don’t need to do that for the remaining layers
model.add(Conv2D(base_hidden_units, (3,3), padding='same', kernel_regularizer=regularizers.l2(weight_decay), input_shape=x_train.shape[1:]))
model.add(Activation('relu'))
model.add(BatchNormalization())

# CONV2
model.add(Conv2D(base_hidden_units, (3,3), padding='same', kernel_regularizer=regularizers.l2(weight_decay)))
model.add(Activation('relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Dropout(0.2))

# CONV3
model.add(Conv2D(2*base_hidden_units, (3,3), padding='same', kernel_regularizer=regularizers.l2(weight_decay)))
model.add(Activation('relu'))
model.add(BatchNormalization())

# CONV4
model.add(Conv2D(2*base_hidden_units, (3,3), padding='same', kernel_regularizer=regularizers.l2(weight_decay)))
model.add(Activation('relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Dropout(0.3))

# CONV5
model.add(Conv2D(4*base_hidden_units, (3,3), padding='same', kernel_regularizer=regularizers.l2(weight_decay)))
model.add(Activation('relu'))
model.add(BatchNormalization())

# CONV6
model.add(Conv2D(4*base_hidden_units, (3,3), padding='same', kernel_regularizer=regularizers.l2(weight_decay)))
model.add(Activation('relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Dropout(0.4))

# FC7
model.add(Flatten())
model.add(Dense(num_classes, activation='softmax'))

# print model summary
model.summary()

模型摘要如下:

 数据增强

本实验将随意采用旋转、高度、和宽度变换、水平翻转等数据增强技术。处理问题时,请检查看网络没有进行分类或分类结果较差的图像,并尝试理解网络在这些图像上表现不佳的原因,然后提出改进假设并进行试验。分析、试验、评估并重复这个过程,通过纯粹的数据分析和对网络性能的理解来做出决定

# data augmentation
datagen = ImageDataGenerator(
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    vertical_flip=False
    )

# compute the data augmentation on the training set
datagen.fit(x_train)

训练模型 

训练模型之前先讨论一些超参数的设置策略。

(1)batch_size:batch_size越大,算法学习的越快。可将初始值设置为64,然后将该值翻倍来加速训练。

(2)epochs:开始时将值设为50,但是发现网络仍在改进,所以不断则更加训练轮数并观察训练结果

(3)optimizer:本实验实验了Adam优化器。因新版本的keras很多优化器找不到配置文件的问题,最终解决Adam优化器配置的问题。


# training
from tensorflow.keras.optimizers import legacy
batch_size = 128
epochs=200

checkpointer = ModelCheckpoint(filepath='model.125epochs.hdf5', verbose=1, save_best_only=True)

# you can try any of these optimizers by uncommenting the line
#optimizer = rmsprop(lr=0.001,decay=1e-6)
optimizer = legacy.Adam(learning_rate=0.0001,decay=1e-6)

#optimizer =keras.optimizers.rmsprop(lr=0.0003,decay=1e-6)
model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
history = model.fit(datagen.flow(x_train, y_train, batch_size=batch_size), callbacks=[checkpointer],
                steps_per_epoch=x_train.shape[0] // batch_size, epochs=epochs,verbose=2,
                validation_data=(x_test,y_test))

评估模型 

调用Keras的evalute函数来评估模型并打印结果

# evaluating the model
scores = model.evaluate(x_test, y_test, batch_size=128, verbose=1)
print('\nTest result: %.3f loss: %.3f' % (scores[1]*100,scores[0]))

打印学习曲线,分析训练性能

# plot learning curves of model accuracy

pyplot.plot(history.history['accuracy'], label='train')
pyplot.plot(history.history['val_accuracy'], label='test')
pyplot.legend()
pyplot.show()

 

调参

为了提升模型的结果,需要对模型进一步改进:

(1)增加训练的轮数:通过上述效果可以得出模型在125轮之前一直在增加,可将模型的训练轮数进行进一步增加。

(2)使用更深的网络结构:尝试添加更多层来提升模型的复杂度,以增强其学习能力。

(3)降低学习率:通过降低学习率learning_rate的方式使其模型使用更长的时间去学习。

(4)使用不同的CNN架构。

最终我们经过多次调参得到如下结果

序号

batch_size

epochs

learning_rate

Test result

学习曲线

1

128

125

0.0001

86.560

2

128

200

0.0001

87.360

3

256

200

0.001

86.930

4

256

200

0.0001

88.120

5

256

200

0.0003

87.820

 我们经过了五次实验发现当batch_size=256,epochs=200,learning_rate=0.0001的时候,Test result最高,分类效果最好,当然可以继续尝试添加更多层来提升模型的复杂度,以增强其学习能力。

异常问题与解决方案

1、报错:Failed to get convolution algorithm. cudnn failed to initialize

解决办法:在模型前面加上这几句话,意思大概也是运行内存增加

physical_devices = tf.config.experimental.list_physical_devices('GPU')

if len(physical_devices) > 0:

     for k in range(len(physical_devices)):

         tf.config.experimental.set_memory_growth(physical_devices[k], True)

     print('memory growth:', tf.config.experimental.get_memory_growth(physical_devices[k]))

else:

     print("Not enough GPU hardware devices available")

2、报错:lr "参数已被弃用,请使用 "learning_rate "参数。 super().__init__(name, **kwargs)

解决办法:将lr换为learning_rate

3、报错:`Model.fit_generator` is deprecated and will be removed in a future version. Please use `Model.fit`, which supports generators.

解决办法:将Model.fit_generator改为Model.fit

4、报错:No module named ‘adam’

解决办法:将keras.optimizers.adam改为legacy.Adam,并重新导入legacy包

5、报错:Image transformations require SciPy. Install SciPy.

解决办法:重新安装SciPy

6、报错----> 3 pyplot.plot(history.history['acc'], label='train')

      4 pyplot.plot(history.history['val_acc'], label='test')

      5 pyplot.legend()

KeyError: 'acc'

解决办法:将acc替换为accuracy;val_acc替换为val_accuracy

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1281697.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

若依的基本使用

演示使用网址:若依管理系统 网站:RuoYi 若依官方网站 |后台管理系统|权限管理系统|快速开发框架|企业管理系统|开源框架|微服务框架|前后端分离框架|开源后台系统|RuoYi|RuoYi-Vue|RuoYi-Cloud|RuoYi框架|RuoYi开源|RuoYi视频|若依视频|RuoYi开发文档|若依开发文档|Java开源框架…

绝地求生在steam叫什么?

绝地求生在Steam的全名是《PlayerUnknowns Battlegrounds》,简称为PUBG。作为一款风靡全球的多人在线游戏,PUBG于2017年3月23日正式上线Steam平台,并迅速成为一部热门游戏。 PUBG以生存竞技为核心玩法,玩家将被投放到一个辽阔的荒…

基于 ESP32 的带触摸显示屏的 RFID 读取器

如何设计一款基于 ESP32 且具有 ILI9341 触摸屏显示屏且适合壁挂式安装的美观 RFID 读取器。 本项目中用到的东西 硬件组件 ESP32 开发套件 C 1 AZ-Touch ESP 套件 1 RFID-RC522 IC卡读写器 1 ​编辑 电线、绕包线 1 详细设计流程 …

结构体实现位段

一.什么是位段 位段的声明和结构是类似的,有两个不同: 位段的成员必须是 int、unsigned int 或 signed int ,在C99中位段成员的类型也可以 选择其他类型。 位段的成员名后边有⼀个冒号和⼀个数字 struct A {int a : 5;int b : 4;int c : 2…

SpringBoot的配置加载优先级

目录 一、背景分析 二、学习资源 三、具体使用 四、一些小技巧 方式一 方式二 一、背景分析 SpringBoot项目在打包之后&#xff0c;其配置文件就在jar包内&#xff0c;如果没有<配置文件优先级>这个机制&#xff0c;那么项目打成jar包之后&#xff0c;如果启动项目…

avue-crud中时间范围选择默认应该是0点却变成了12点

文章目录 一、问题二、解决三、最后 一、问题 在avue-crud中时间范围选择&#xff0c;正常默认应该是0点&#xff0c;但是不知道怎么的了&#xff0c;选完之后就是一直是12点。具体问题如下动图所示&#xff1a; <template><avue-crud :option"option" /&g…

数据结构入门————树(C语言/零基础/小白/新手+模拟实现+例题讲解)

目录 1. 树的概念及其结构 1.1 树的概念&#xff1a; 1.2 树的相关概念&#xff1a; 1.3 树的表示方法&#xff1a; ​编辑 1.4 树的应用&#xff1a; 2. 二叉树的概念及其结构 2.1 概念: 2.2 特点&#xff1a; 2.3 特殊二叉树&#xff1a; 2.4 二叉树的性质&#xf…

shell编程-awk命令详解(超详细)

文章目录 前言一、awk命令介绍1. awk命令简介2. awk命令的基本语法3. 常用的awk命令选项4. 常用的awk内置变量 二、awk命令示例用法1. 打印整行2. 打印特定字段3. 根据条件筛选行4. 自定义分隔符5. 从文件中读取awk脚本 总结 前言 awk命令是一种强大的文本处理工具&#xff0c…

Qt OpenCV 学习(一):环境搭建

对应版本 Qt 5.15.2OpenCV 3.4.9MinGW 8.1.0 32-bit 1. OpenCV 下载 确保安装 Qt 时勾选了 MinGW 编译器 本文使用 MinGW 编译好的 OpenCV 库&#xff0c;无需自行编译 确保下载的 MinGW 和上述安装 Qt 时勾选的 MinGW 编译器位数一致&#xff0c;此处均为 x86/32-bit下载地址…

LongAddr

目录 1. 引言 2. AtomicInteger的局限性 3. AtomicInteger与LongAdder 的性能差异 4.LongAdder 的结构 LongAddr架构 Striped64中重要的属性 Striped64中一些变量或者方法的定义 Cell类 5. 分散热点的原理 具体流程图 6. 在实际项目中的应用 7. 总结 1. 引言 在这一…

【C++练级之路】【Lv.2】类和对象(上)(类的定义,访问限定符,类的作用域,类的实例化,类的对象大小,this指针)

目录 一、面向过程和面向对象初步认识二、类的引入三、类的定义四、类的访问限定符及封装4.1 访问限定符4.2 封装 五、类的作用域六、类的实例化七、类的对象大小的计算7.1 类对象的存储方式猜测7.2 如何计算类对象的大小 八、类成员函数的this指针8.1 this指针的引出8.2 this指…

课题学习(十四)----三轴加速度计+三轴陀螺仪传感器-ICM20602

本篇博客对ICM20602芯片进行学习&#xff0c;目的是后续设计一个电路板&#xff0c;采集ICM20602的数据&#xff0c;通过这些数据对各种姿态解算的方法进行仿真学习。 一、 ICM20602介绍 1.1 初识芯片 3轴陀螺仪&#xff1a;可编程全刻度范围(FSR)为250 dps&#xff0c;500 d…

Proteus仿真--基于ADC0832设计的两路电压表

本文介绍基于ADC0832实现的双路电压表采集设计&#xff08;完整仿真源文件及代码见文末链接&#xff09; 仿真图如下 采集芯片选用ADC0832&#xff0c;电压显示在LCD1602液晶显示屏上 仿真运行视频 Proteus仿真--基于ADC0832设计的两路电压表 附完整Proteus仿真资料代码资料…

自动驾驶学习笔记(十四)——感知算法

#Apollo开发者# 学习课程的传送门如下&#xff0c;当您也准备学习自动驾驶时&#xff0c;可以和我一同前往&#xff1a; 《自动驾驶新人之旅》免费课程—> 传送门 《Apollo Beta宣讲和线下沙龙》免费报名—>传送门 文章目录 前言 感知算法 开发过程 测试和评价 前言…

【Linux】cp 命令使用

cp 命令 cp&#xff08;英文全拼&#xff1a;copy file&#xff09;命令主要用于复制文件或目录。 著者 由Torbjorn Granlund、David MacKenzie和Jim Meyering撰写。 语法 cp [选项]... [-T] 源文件 目标文件或&#xff1a;cp [选项]... 源文件... 目录或&#xff1a;cp [选…

【Docker】容器数据持久化及容器互联

一、Docker容器的数据管理 1.1、什么是数据卷 数据卷是经过特殊设计的目录&#xff0c;可以绕过联合文件系统&#xff08;UFS&#xff09;&#xff0c;为一个或者多个容器提供访问&#xff0c;数据卷设计的目的&#xff0c;在于数据的永久存储&#xff0c;它完全独立于容器的…

深入理解GMP模型

1、GMP模型的设计思想 1&#xff09;、GMP模型 GMP分别代表&#xff1a; G&#xff1a;goroutine&#xff0c;Go协程&#xff0c;是参与调度与执行的最小单位M&#xff1a;machine&#xff0c;系统级线程P&#xff1a;processor&#xff0c;包含了运行goroutine的资源&#…

canvas绘制单个圆和同心圆

代码实现&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><meta name"viewport" content"widthdev…

CCKS2023-面向上市公司主营业务的实体链接评测-亚军方案

赛题分析 大赛地址 https://tianchi.aliyun.com/competition/entrance/532097/information 任务描述 本次任务主要针对上市公司的主营业务进行产品实体链接。需要获得主营业务中的产品实体&#xff0c;将该实体链接到产品数据库中的某一个标准产品实体。产品数据库将发布在竞赛…

HCIP —— 双点重发布 + 路由策略 实验

目录 实验拓扑&#xff1a; 实验要求&#xff1a; 实验配置&#xff1a; 1.配置IP地址 2.配置动态路由协议 —— RIP 、 OSPF R1 RIP R4 OSPF R2 配置RIP、OSPF 双向重发布 R3配置RIP、OSPF 双向重发布 3.查询路由表学习情况 4.使用路由策略控制选路 R2 R3 5.检…