TensorFlow深度学习实战——构建卷积神经网络实现CIFAR-10图像分类

news2025/2/21 5:46:43

TensorFlow深度学习实战——构建卷积神经网络实现CIFAR-10图像分类

    • 0. 前言
    • 1. CIFAR-10 数据集介绍
    • 2. CIFAR-10 图像分类
    • 3. 提升模型性能
      • 3.1 增加网络深度
      • 3.2 数据增强
    • 4. 模型测试
    • 相关链接

0. 前言

我们已经学习了卷积神经网络 (Convolutional Neural Network, CNN) 的基本概念,并使用 TensorFlow 构建 CNN 实现了 MNIST 手写数字分类问题,本节将进一步介绍如何利用 CNNCIFAR-10 数据集进行图像分类。首先构建一个简单的 CNN 模型来进行 CIFAR-10 图像的分类,接着,通过模型优化技术提高分类准确率。介绍了一个完整的从数据准备到模型评估的图像分类实践过程。

1. CIFAR-10 数据集介绍

MNIST 数据集简单的灰度图像不同,CIFAR-10 数据集包含 60,000 张彩色图像,每张图像的尺寸为 32 x 32 像素,每张图像有有三个通道,分为 10 个类别。每个类别包含 6,000 张图像,训练集包含 50,000 张图像,测试集包含 10,000 张图像。随机查看 CIFAR-10 数据集中的数据样本:

数据样本

模型训练的目标是预测 CIFAR-10 数据集中在训练过程中未见过的图像的类别。

2. CIFAR-10 图像分类

(1) 首先,我们导入所需库,定义常量,并加载数据集:

import tensorflow as tf
from tensorflow.keras import datasets, layers, models, optimizers

# CIFAR_10 is a set of 60K images 32x32 pixels on 3 channels
IMG_CHANNELS = 3
IMG_ROWS = 32
IMG_COLS = 32

#constant
BATCH_SIZE = 128
EPOCHS = 20
CLASSES = 10
VERBOSE = 2
VALIDATION_SPLIT = 0.2
OPTIM = tf.keras.optimizers.RMSprop()

(2) 卷积层使用 32 个卷积核,每个卷积核尺寸为 3 x 3。输出尺寸与输入形状相同,为 32 x 32,激活函数采用 ReLU 函数,引入非线性。池化层使用 2 x 2MaxPooling 操作,并且使用概率为 25%Dropout 操作:

#define the convnet 
def build(input_shape, classes):
    model = models.Sequential() 
    model.add(layers.Convolution2D(32, (3, 3), activation='relu',
              input_shape=input_shape))
    model.add(layers.MaxPooling2D(pool_size=(2, 2)))
    model.add(layers.Dropout(0.25)) 

接下来,在展平特征图后添加全连接层,全连接层包含 512 个神经元和 ReLU 激活函数,随后使用概率为 50%Dropout 操作,最后添加一个使用 Softmax 激活函数的全连接层,此全连接层包含 10 个神经元对应 10 个类别,每个类别对应一个输出:

    model.add(layers.Flatten())
    model.add(layers.Dense(512, activation='relu'))
    model.add(layers.Dropout(0.5))
    model.add(layers.Dense(classes, activation='softmax'))
    return model

(3) 定义网络之后,训练模型。除了训练集和测试集之外,还将数据拆分得到一个验证集。训练集用于构建模型,验证集用于选择性能表现最佳的方法,而测试集用于检查最佳模型在未见过的数据上的表现:

# data: shuffled and split between train and test sets
(X_train, y_train), (X_test, y_test) = datasets.cifar10.load_data()
# normalize
X_train, X_test = X_train / 255.0, X_test / 255.0
# convert to categorical
# convert class vectors to binary class matrices
y_train = tf.keras.utils.to_categorical(y_train, CLASSES)
y_test = tf.keras.utils.to_categorical(y_test, CLASSES)

model=build((IMG_ROWS, IMG_COLS, IMG_CHANNELS), CLASSES)
model.summary()

# use TensorBoard, princess Aurora!
callbacks = [
    # Write TensorBoard logs to `./logs` directory
    tf.keras.callbacks.TensorBoard(log_dir='./logs')
]

# train
model.compile(loss='categorical_crossentropy', optimizer=OPTIM,
              metrics=['accuracy'])
 
model.fit(X_train, y_train, batch_size=BATCH_SIZE,
          epochs=EPOCHS, validation_split=VALIDATION_SPLIT, 
          verbose=VERBOSE, callbacks=callbacks) 
score = model.evaluate(X_test, y_test, batch_size=BATCH_SIZE, verbose=VERBOSE)
print("\nTest score:", score[0])
print('Test accuracy:', score[1])

运行代码。模型在进行 20epochs 的训练后在测试数据集上能够达到 67.8% 的准确率:

模型训练过程监测

模型在 CIFAR-10 数据集上训练过程的准确率和损失的变化情况如下:

模型训练过程监测

3. 提升模型性能

3.1 增加网络深度

提高性能的一种方式是使用多个卷积层定义更深的网络。接下来,我们将使用以下网络构建块:

1st block: CONV+CONV+MaxPool+Dropout 2nd block: CONV+CONV+MaxPool+Dropout 3rd block: CONV+CONV+MaxPool+Dropout \text{1st block: CONV+CONV+MaxPool+Dropout}\\ \text{2nd block: CONV+CONV+MaxPool+Dropout}\\ \text{3rd block: CONV+CONV+MaxPool+Dropout} 1st block: CONV+CONV+MaxPool+Dropout2nd block: CONV+CONV+MaxPool+Dropout3rd block: CONV+CONV+MaxPool+Dropout
之后是添加全连接层作为输出层。除了输出层外,所有网络层使用的激活函数都是 ReLU 函数。网络中添加 BatchNormalization() 层用于引入正则化:

import tensorflow as tf
from tensorflow.keras import datasets, layers, models, regularizers, optimizers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np

def build_model(): 
    model = models.Sequential()
    
    #1st block
    model.add(layers.Conv2D(32, (3,3), padding='same', 
        input_shape=x_train.shape[1:], activation='relu'))
    model.add(layers.BatchNormalization())
    model.add(layers.Conv2D(32, (3,3), padding='same', activation='relu'))
    model.add(layers.BatchNormalization())
    model.add(layers.MaxPooling2D(pool_size=(2,2)))
    model.add(layers.Dropout(0.2))

    #2nd block
    model.add(layers.Conv2D(64, (3,3), padding='same', activation='relu'))
    model.add(layers.BatchNormalization())
    model.add(layers.Conv2D(64, (3,3), padding='same', activation='relu'))
    model.add(layers.BatchNormalization())
    model.add(layers.MaxPooling2D(pool_size=(2,2)))
    model.add(layers.Dropout(0.3))

    #3d block 
    model.add(layers.Conv2D(128, (3,3), padding='same', activation='relu'))
    model.add(layers.BatchNormalization())
    model.add(layers.Conv2D(128, (3,3), padding='same', activation='relu'))
    model.add(layers.BatchNormalization())
    model.add(layers.MaxPooling2D(pool_size=(2,2)))
    model.add(layers.Dropout(0.4))

    #dense  
    model.add(layers.Flatten())
    model.add(layers.Dense(NUM_CLASSES, activation='softmax'))
    return model

加载并归一化数据集:

EPOCHS=50
NUM_CLASSES = 10
BATCH_SIZE = 128
    

def load_data():
    (x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()
    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')
 
    #normalize 
    mean = np.mean(x_train,axis=(0,1,2,3))
    std = np.std(x_train,axis=(0,1,2,3))
    x_train = (x_train-mean)/(std+1e-7)
    x_test = (x_test-mean)/(std+1e-7)
 
    y_train =  tf.keras.utils.to_categorical(y_train,NUM_CLASSES)
    y_test =  tf.keras.utils.to_categorical(y_test,NUM_CLASSES)

    return x_train, y_train, x_test, y_test

网络定义完成后,训练 50epochs,最后在测试数据集上能够达到 82.60% 的准确率:

(x_train, y_train, x_test, y_test) = load_data()
model = build_model()
model.compile(loss='categorical_crossentropy', 
              optimizer='RMSprop', 
              metrics=['accuracy'])

#train
batch_size = 64
model.fit(x_train, y_train, batch_size=batch_size,
          epochs=EPOCHS, validation_data=(x_test,y_test)) 
score = model.evaluate(x_test, y_test,
                       batch_size=BATCH_SIZE)
print("\nTest score:", score[0])
print('Test accuracy:', score[1])

相对于简单网络,模型性能取得了 14.74% 的提高。

3.2 数据增强

另一种提高性能的方法是使用更多数据训练模型。我们可以取标准的 CIFAR-10 训练集,并用多种类型的转换扩充训练集,包括旋转、水平或垂直翻转、缩放、平移等:

#image augmentation
datagen = ImageDataGenerator(
            rotation_range=30,
            width_shift_range=0.2,
            height_shift_range=0.2,
            horizontal_flip=True,
)
datagen.fit(x_train)

rotation_range 是一个在度数范围 (0-180) 内的值,用于随机旋转图片;width_shiftheight_shift 是用于随机在垂直或水平方向上平移图片的范围;zoom_range 用于随机缩放图片;horizontal_flip 用于随机水平翻转一半的图片;fill_mode 是用于填充因旋转或移位后可能产生的空白区域。
更正式的讲,数据增强是通过对训练数据应用各种变换(如旋转、缩放、翻转、裁剪等)来生成更多样化的数据,从而提高模型的泛化能力和鲁棒性。
数据增强完成后,能够根据标准的 CIFAR-10 集合中生成更多的训练图像:

数据增强结果

训练模型。使用上一小节定义的 ConvNet,使用数据增强生成更多图像,然后进行训练:

#train
batch_size = 64
model.fit_generator(datagen.flow(x_train, y_train, batch_size=batch_size),
                    epochs=EPOCHS,
                    verbose=2,validation_data=(x_test,y_test))
#save to disk
model_json = model.to_json()
with open('model.json', 'w') as json_file:
    json_file.write(model_json)
model.save_weights('model.h5') 

#test
scores = model.evaluate(x_test, y_test, batch_size=128, verbose=1)
print('\nTest result: %.3f loss: %.3f' % (scores[1]*100,scores[0])) 

训练过程需要更多时间,因为有更多的训练数据。因此,运行 50epochs 的训练,模型在测试及上能够达到 84% 的准确率:

训练过程

模型在 CIFAR-10 数据集上训练过程的准确率和损失的变化情况如下:模型性能监测

CIFAR-10 数据集上训练的一系列深度学习模型性能能够在线查看。截至目前,最佳模型的准确率为 96.53%

模型准确率

4. 模型测试

假设想要使用刚刚训练的 CIFAR-10 深度学习模型进行图像评估,由于我们保存了模型和权重,所以不需要每次都进行训练:

import numpy as np

from skimage.transform import resize
from imageio import imread

from tensorflow.keras.models import model_from_json
from tensorflow.keras.optimizers import SGD

model_architecture = "cifar10_architecture.json"
model_weights = "cifar10_weights.h5"
model = model_from_json(open(model_architecture).read())
model.load_weights(model_weights)

img_names = ["cat.jpg", "dog.jpg"]
imgs = [resize(imread(img_name), (32, 32)).astype("float32") for img_name in img_names]
imgs = np.array(imgs) / 255
print("imgs.shape:", imgs.shape)

optim = SGD()
model.compile(loss="categorical_crossentropy", optimizer=optim, metrics=["accuracy"])

predictions = np.argmax(model.predict(imgs), axis=1)
print("predictions:", predictions)

使用 imageioimread 加载图像,然后将它们调整为 32 × 32 像素,得到的图像张量的维度为 (32, 32, 3)。之后,将图像张量列表合并成一个单一的张量,并将其归一化到 01.0 之间,最后构建模型,加载训练完成的模型权重进行预测。
为以下两张图像分类,预期输出为类别 3 (猫)和 5 (狗)。

示例图像

相关链接

TensorFlow深度学习实战(1)——神经网络与模型训练过程详解
TensorFlow深度学习实战(2)——使用TensorFlow构建神经网络
TensorFlow深度学习实战(3)——深度学习中常用激活函数详解
TensorFlow深度学习实战(4)——正则化技术详解
TensorFlow深度学习实战(5)——神经网络性能优化技术详解
TensorFlow深度学习实战(6)——回归分析详解
TensorFlow深度学习实战(7)——分类任务详解
TensorFlow深度学习实战(8)——卷积神经网络

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2301729.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

服务器创建conda环境并安装使用jupyter

1.创建conda环境 conda create --name myenv python3.8 conda activate myenv其中 myenv 是您想要创建的环境名称,可以根据需要替换为其他名称。2.安装juypter conda install jupyter3.启动juypter jupyter notebook复制链接到浏览器打开 4.设置jupyter使用的 …

【HarmonyOS Next】鸿蒙监听手机按键

【HarmonyOS Next】鸿蒙监听手机按键 一、前言 应用开发中我们会遇到监听用户实体按键,或者扩展按键的需求。亦或者是在某些场景下,禁止用户按下某些按键的业务需求。 这两种需求,鸿蒙都提供了对应的监听事件进行处理。 onKeyEvent 默认的…

【Spring详解五】bean的加载

五、bean的加载 当我们显示或者隐式地调用 getBean() 时,则会触发加载 bean 阶段。示例代码如下: public class AppTest {Testpublic void MyTestBeanTest() {BeanFactory bf new XmlBeanFactory( new ClassPathResource("spring-config.xml"…

ThinkPHP(TP)如何做安全加固,防webshell、防篡改、防劫持、TP漏洞防护

ThinkPHP是一款非常知名的PHP框架,很多知名CMS系统都是采用TP框架进行二次开发而来,当然ThinkPHP本身也可以直接建站,开源免费、功能强大,深受广大用户喜欢。 虽然ThinkPHP非常优秀,但是为了保障网站安全,我…

MySQL(1)基础篇

执行一条 select 语句,期间发生了什么? | 小林coding 目录 1、连接MySQL服务器 2、查询缓存 3、解析SQL语句 4、执行SQL语句 5、MySQL一行记录的存储结构 Server 层负责建立连接、分析和执行 SQL存储引擎层负责数据的存储和提取。支持InnoDB、MyIS…

分裂栅结构对碳化硅MOSFET重复雪崩应力诱导退化的抑制作用

标题 Suppression Effect of Split-Gate Structure on Repetitive Avalanche Stress Induced Degradation for SiC MOSFETs(TED 24年) 文章的研究内容 这篇文章的研究探讨了重复雪崩应力对碳化硅(SiC)MOSFET器件退化的影响&am…

JavaEE基础之- xml

目录 一、xml概述 1.什么是xml 2.W3C组织 3.XML的作用 4.XML与HTML比较 5.XML和properties(属性文件)比较 二、XML语法概述 1.文档展示 2.XML文档的组成部分 3.xml文档声明 3.1 什么是xml文档声明 3.2 xml文档声明结构 4.xml元素 4.1 xml元素的格…

网络运维学习笔记 012网工初级(HCIA-Datacom与CCNA-EI)某机构新增:GRE隧道与EBGP实施

文章目录 GRE隧道(通用路由封装,Generic Routing Encapsulation)协议号47实验:思科:开始实施: 华为:开始实施: eBGP实施思科:华为: GRE隧道(通用路…

RK3568开发板/电脑/ubuntu处于同一网段互通

1.查看无线局域网适配器WLAN winR输入cmd打开电脑终端输入ipconfig或arp -a查看无线局域网IP地址,这就是WIFI. 这里的IPv4是192.168.0.147,默认网关是192.168.0.1,根据网关地址配以太网IP, ubuntu的IP,和rk3568的IP。 且IP范围为192.168.…

如何调用 DeepSeek API:详细教程与示例

目录 一、准备工作 二、DeepSeek API 调用步骤 1. 选择 API 端点 2. 构建 API 请求 3. 发送请求并处理响应 三、Python 示例:调用 DeepSeek API 1. 安装依赖 2. 编写代码 3. 运行代码 四、常见问题及解决方法 1. API 调用返回 401 错误 2. API 调用返回…

如何在Jenkins上查看Junit报告

在 Jenkins 上查看 JUnit 报告通常有以下几个步骤: 构建结果页面: 首先,确保你的 Jenkins 构建已经执行完毕,并且成功生成了 JUnit 报告。前往 Jenkins 主页面,点击进入相应的项目或流水线。 构建记录: 选择你想查看的特定构建记…

【深度学习】计算机视觉(CV)-图像生成-风格迁移(Style Transfer)

风格迁移(Style Transfer) 风格迁移是一种计算机视觉技术,可以将一张图像的内容和另一张图像的风格融合在一起,生成一张既保留原始内容,又带有目标风格的全新图像!这种方法常用于艺术创作、图像增强、甚至…

在项目中调用本地Deepseek(接入本地Deepseek)

前言 之前发表的文章已经讲了如何本地部署Deepseek模型,并且如何给Deepseek模型投喂数据、搭建本地知识库,但大部分人不知道怎么应用,让自己的项目接入AI模型。 文末有彩蛋哦!!! 要接入本地部署的deepsee…

JAVA中常用类型

一、包装类 1.1 包装类简介 java是面向对象的语言,但是八大基本数据类型不符合面向对象的特征。因此为了弥补这种缺点,为这八中基本数据类型专门设计了八中符合面向面向对象的特征的类型,这八种具有面向对象特征的类型,就叫做包…

使用 GPTQ 进行 4 位 LLM 量化

权重量化方面的最新进展使我们能够在消费级硬件上运行大量大型语言模型,例如 RTX 3090 GPU 上的 LLaMA-30B 模型。这要归功于性能下降最小的新型 4 位量化技术,例如GPTQ、GGML和NF4。 在本文中,我们将探索流行的 GPTQ 算法,以了解…

【黑马点评优化】2-Canel实现多级缓存(Redis+Caffeine)同步

【黑马点评优化】2-Canel实现多级缓存(RedisCaffeine)同步 0 背景1 配置MySQL1.1 开启MySQL的binlog功能1.1.1 找到mysql配置文件my.ini的位置1.1.2 开启binlog 1.2 创建canal用户 2 下载配置canal2.1 canal 1.1.5下载2.2 配置canal2.3 启动canal2.4 测试…

CPP集群聊天服务器开发实践(五):nginx负载均衡配置

1 负载均衡器的原理与功能 单台Chatserver可以容纳大约两万台客户端同时在线聊天,为了提升并发量最直观的办法需要水平扩展服务器的数量,三台服务器可以容纳六万左右的客户端。 负载均衡器的作用: 把client的请求按照负载均衡算法分发到具体…

百问网(100ask)的IMX6ULL开发板的以太网控制器(MAC)与物理层(PHY)芯片(LAN8720A)连接的原理图分析(包含各引脚说明以及工作原理)

前言 本博文承接博文 https://blog.csdn.net/wenhao_ir/article/details/145663029 。 本博文和博文 https://blog.csdn.net/wenhao_ir/article/details/145663029 的目录是找出百问网(100ask)的IMX6ULL开发板与NXP官方提供的公板MCIMX6ULL-EVK(imx6ull14x14evk)在以太网硬件…

组件库地址

react: https://react-vant.3lang.dev/components/dialoghttps://react-vant.3lang.dev/components/dialog vue用v2的 Vant 2 - Mobile UI Components built on Vue

2025.2.16机器学习笔记:TimeGan文献阅读

2025.2.9周报 一、文献阅读题目信息摘要Abstract创新点网络架构一、嵌入函数二、恢复函数三、序列生成器四、序列判别器损失函数 实验结论后续展望 一、文献阅读 题目信息 题目: Time-series Generative Adversarial Networks会议: Neural Information…