训练 CNN 对 CIFAR-10 数据中的图像进行分类

news2024/12/25 12:31:11

1. 加载 CIFAR-10 数据库

import keras
from keras.datasets import cifar10

# 加载预先处理的训练数据和测试数据
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

2. 可视化前 24 个训练图像

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

fig = plt.figure(figsize=(20,5))
for i in range(36):
    ax = fig.add_subplot(3, 12, i + 1, xticks=[], yticks=[])
    ax.imshow(np.squeeze(x_train[i]))

3. 通过将每幅图像中的每个像素除以 255 来调整图像比例

事实上,代价函数的形状是一个碗,但如果特征的比例非常不同,它也可能是一个拉长的碗。下图显示了梯度下降法在特征 1 和特征 2 比例相同的训练集上的应用(左图),以及在特征 1 的值远小于特征 2 的训练集上的应用(右图)。

Tips : 使用梯度下降法时,应确保所有特征的比例相似,以加快训练速度,否则收敛时间会更长。

# rescale [0,255] --> [0,1]
x_train = x_train.astype('float32')/255
x_test = x_test.astype('float32')/255

 4. 将数据集分为训练集、测试集和验证集

from keras.utils import to_categorical

# 对标签进行一次热编码
num_classes = len(np.unique(y_train))
y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)

# 将训练集分为训练集和验证集
(x_train, x_valid) = x_train[5000:], x_train[:5000]
(y_train, y_valid) = y_train[5000:], y_train[:5000]

# 打印训练集的形状
print('x_train shape:', x_train.shape)

# 打印训练、验证和测试图像的数量
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
print(x_valid.shape[0], 'validation samples')

5. 定义模型架构 

from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout

model = Sequential()
model.add(Conv2D(filters=16, kernel_size=2, padding='same', activation='relu', 
                        input_shape=(32, 32, 3)))
model.add(MaxPooling2D(pool_size=2))
model.add(Conv2D(filters=32, kernel_size=2, padding='same', activation='relu'))
model.add(MaxPooling2D(pool_size=2))
model.add(Conv2D(filters=64, kernel_size=2, padding='same', activation='relu'))
model.add(MaxPooling2D(pool_size=2))
model.add(Dropout(0.3))
model.add(Flatten())
model.add(Dense(500, activation='relu'))
model.add(Dropout(0.4))
model.add(Dense(10, activation='softmax'))

model.summary()

6. 编译模型 

# compile the model
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])

7. 训练模型

from keras.callbacks import ModelCheckpoint   

# 训练模型
checkpointer = ModelCheckpoint(filepath='model.weights.best.hdf5', verbose=1, save_best_only=True)

hist = model.fit(x_train, y_train, batch_size=32, epochs=100,
          validation_data=(x_valid, y_valid), callbacks=[checkpointer], 
          verbose=2, shuffle=True)

8. 加载验证精度最高的模型

# 加载验证精度最高的权重
model.load_weights('model.weights.best.hdf5')

 9. 计算测试集的分类精度

# 评估和打印测试精度
score = model.evaluate(x_test, y_test, verbose=0)
print('\n', 'Test accuracy:', score[1])

10. 可视化一些预测

这可能会让你对网络错误分类某些对象的原因有所了解。

# 在测试集上得到预测
y_hat = model.predict(x_test)

# 定义文本标签 (source: https://www.cs.toronto.edu/~kriz/cifar.html)
cifar10_labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# 绘制测试图像的随机样本、预测标签和基本真实图像
fig = plt.figure(figsize=(20, 8))
for i, idx in enumerate(np.random.choice(x_test.shape[0], size=32, replace=False)):
    ax = fig.add_subplot(4, 8, i + 1, xticks=[], yticks=[])
    ax.imshow(np.squeeze(x_test[idx]))
    pred_idx = np.argmax(y_hat[idx])
    true_idx = np.argmax(y_test[idx])
    ax.set_title("{} ({})".format(cifar10_labels[pred_idx], cifar10_labels[true_idx]),
                 color=("green" if pred_idx == true_idx else "red"))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1258898.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Java开发基础】intellij IDEA快速配置JDBC驱动连接MySQL数据库并查询数据,其实真的很简单,我5分钟就学会了!

🚀 个人主页 极客小俊 ✍🏻 作者简介:web开发者、设计师、技术分享博主 🐋 希望大家多多支持一下, 我们一起学习和进步!😄 🏅 如果文章对你有帮助的话,欢迎评论 💬点赞&a…

多功能音乐沙漏的设计与实现

【摘要】随着当今社会快节奏生活的发展,当代大学生越来忽视时间管理的重要性,在原本计划只看几个视频只玩几个游戏的碎片化娱乐中耗费了大量的时光,对于自己原本的学习生活产生了巨大的影响。为更加有效的反映时间的流逝,特设计该…

canvas基础:绘制线段,绘制多边形

canvas实例应用100 专栏提供canvas的基础知识,高级动画,相关应用扩展等信息。 canvas作为html的一部分,是图像图标地图可视化的一个重要的基础,学好了canvas,在其他的一些应用上将会起到非常重要的帮助。 文章目录 使用…

TDA4开发环境Docker化

文章目录 背景1. TDA4X Linux SDK编译环境镜像构建1.1 安装SDK1.2 验证制卡1.2.1 出现的问题:1.3 验证编译1.3.1 出现的问题2. TDA4X Linux-RT SDK编译环境镜像构建2.1 安装SDK2.2 出现的问题参考背景 开始阅读本篇前,假设你已经对docker有了一定了解,且有过docker换件搭建…

优思学院|5S不只是清洁,但却离不开清洁!

很多说5S不止是清洁和搞卫生那么简单,相信有正规地学习过5S的人都应该深切了解。 不过,5S之中的确包括了清理、清洁的步骤,5S,也被称为“五常法则”或“五常法”,它包含了: 整理(SEIRI&#x…

8.统一异常处理 + 统一记录日志

目录 1.统一异常处理 2.统一记录日志 1.统一异常处理 在 HomeController 类中添加请求方法(服务器发生异常之后需要统一处理异常,记录日志,然后转到 500 页面,需要人工处理重定向到 500 页面,提前把 500 页面请求访问…

Containerd Container管理功能解析

Containerd Container管理功能解析 container是containerd的一个核心功能,用于创建和管理容器的基本信息。 本篇containerd版本为v1.7.9。 更多文章访问 https://www.cyisme.top 本文从ctr c create命令出发,分析containerd的容器及镜像管理相关功能。 …

01 项目架构

关于我 曾经就职于蚂蚁金服,多年的后端开发经验,对微服务、架构这块研究颇深,同时也是一名热衷于技术分享、拥抱开源技术的博主。 个人技术公众号:码猿技术专栏个人博客:www.java-family.cn 前期一直在更新《Spring…

什么是美颜sdk?视频直播美颜sdk技术深度剖析

美颜sdk可以通过实时处理图像,提升主播或用户在视频直播中的外观。通过美颜sdk接口调用可以轻松实现美颜效果。美颜sdk的核心目标是在保持图像真实性的同时,为用户创造出最理想的美化效果。 一、美颜sdk的技术实现 1.面部识别技术:美颜sdk…

虚拟直播在文旅行业火爆发展!背后的“生意经”你抓住了吗?

自疫情结束以来,文化和旅游行业恢复势头强劲,各大旅游景点消费活跃度持续攀升。在这种情况下,“直播文旅”模式的深度融合对文旅行业的客流导入起到了极大的带动作用。 不过,当前的文旅直播也出现了一些问题&#xf…

流媒体播放器EasyPlayer播放H.265与H.264时进度条样式异常该如何解决?

H5无插件流媒体播放器EasyPlayer属于一款高效、精炼、稳定且免费的流媒体播放器,可支持多种流媒体协议播放,可支持H.264与H.265编码格式,性能稳定、播放流畅,能支持WebSocket-FLV、HTTP-FLV,HLS(m3u8&#…

北京 | 竹云与南方电网携手荣获“IDC 2023未来企业奖未来连接领导者”

11月22日-23日,2023第八届IDC中国数字化转型年度盛典在北京召开。本次大会以“竞放数字力量”为主题,汇聚超过1000位来自不同行业的大咖与伙伴共同参与此次盛会,从全球化视角出发,围绕本土化落地人工智能(大模型&#…

pytest系列——pytest_collection_modifyitems钩子函数修改测试用例执行顺序

前言 pytest默认执行用例是根据项目下的文件名称按ascii码去收集运行的;文件中的用例是从上往下按顺序执行的。 pytest_collection_modifyitems 这个函数顾名思义就是收集测试用例、改变用例的执行顺序的。 【严格意义上来说,我们在用例设计原则上用例…

NFTScan | 11.20~11.26 NFT 市场热点汇总

欢迎来到由 NFT 基础设施 NFTScan 出品的 NFT 生态热点事件每周汇总。 周期:2023.11.20~ 2023.11.26 NFT Hot News 01/ OKX Ordinals 市场 API 完成升级 11 月 21 日,OKX Ordinals 市场 API 现已完成升级,新增支持按币种单价查询、排序&…

TDA4VM EVM开发板调试笔记

文章目录 1. 前言2. 官网资料导读3. 安装 Linux SDK4. 制作SD 启动卡5. 验证启动1. 前言 TDA4作为一般经典的车规级SOC芯片,基于它的低阶智驾方案目前成为各家智驾方案公司的量产首选,这也使得基于TDA4的开发需求陡增,开发和使用TDA4既要熟悉Linux驱应用开发,还要熟悉传统…

uniapp基础-教程之HBuilderX配置篇-01

uniapp教程之HBuilderX配置篇-01 为什么要做这个教程的梳理,主要用于自己学习和总结,利于增加自己的积累和记忆。首先下载HBuilderX,并保证你的软件在C盘进行运行,最好使用英文或者拼音,这个操作是为了保证软件的稳定…

瑞芯微RK3588核心板在网络录像机产品中的应用与优势-迅为电子

你是否曾经想过,网络录像机是如何工作的?为什么它可以在没有电脑的情况下进行录像和监控?答案是,它依赖于RK3588核心板。那么,RK3588核心板到底在网络录像机产品中扮演着什么样的角色呢?它又具有哪些优势呢…

力扣:239. 滑动窗口最大值

题目&#xff1a; 给定一个数组 nums&#xff0c;有一个大小为 k 的滑动窗口从数组的最左侧移动到数组的最右侧。你只可以看到在滑动窗口内的 k 个数字。滑动窗口每次只向右移动一位。 返回滑动窗口中的最大值。 提示&#xff1a; 1 < nums.length < 10^5-10^4 < n…

使用UI Automation库用于UI自动化测试

&#x1f4e2;专注于分享软件测试干货内容&#xff0c;欢迎点赞 &#x1f44d; 收藏 ⭐留言 &#x1f4dd; 如有错误敬请指正&#xff01;&#x1f4e2;交流讨论&#xff1a;欢迎加入我们一起学习&#xff01;&#x1f4e2;资源分享&#xff1a;耗时200小时精选的「软件测试」资…

NX二次开发UF_CURVE_ask_spline_data 函数介绍

文章作者&#xff1a;里海 来源网站&#xff1a;https://blog.csdn.net/WangPaiFeiXingYuan UF_CURVE_ask_spline_data Defined in: uf_curve.h int UF_CURVE_ask_spline_data(tag_t spline_tag, UF_CURVE_spline_p_t spline_data ) overview 概述 Reads the spline data a…