训练 CNN 对 CIFAR-10 数据中的图像进行分类-keras实现

news2025/1/10 23:32:23

1. 加载 CIFAR-10 数据库

import keras
from keras.datasets import cifar10

# 加载预先处理的训练数据和测试数据
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

2. 可视化前 24 个训练图像

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

fig = plt.figure(figsize=(20,5))
for i in range(36):
    ax = fig.add_subplot(3, 12, i + 1, xticks=[], yticks=[])
    ax.imshow(np.squeeze(x_train[i]))

3. 通过将每幅图像中的每个像素除以 255 来调整图像比例

事实上,代价函数的形状是一个碗,但如果特征的比例非常不同,它也可能是一个拉长的碗。下图显示了梯度下降法在特征 1 和特征 2 比例相同的训练集上的应用(左图),以及在特征 1 的值远小于特征 2 的训练集上的应用(右图)。

Tips : 使用梯度下降法时,应确保所有特征的比例相似,以加快训练速度,否则收敛时间会更长。

# rescale [0,255] --> [0,1]
x_train = x_train.astype('float32')/255
x_test = x_test.astype('float32')/255

 4. 将数据集分为训练集、测试集和验证集

from keras.utils import to_categorical

# 对标签进行一次热编码
num_classes = len(np.unique(y_train))
y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)

# 将训练集分为训练集和验证集
(x_train, x_valid) = x_train[5000:], x_train[:5000]
(y_train, y_valid) = y_train[5000:], y_train[:5000]

# 打印训练集的形状
print('x_train shape:', x_train.shape)

# 打印训练、验证和测试图像的数量
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
print(x_valid.shape[0], 'validation samples')

5. 定义模型架构 

from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout

model = Sequential()
model.add(Conv2D(filters=16, kernel_size=2, padding='same', activation='relu', 
                        input_shape=(32, 32, 3)))
model.add(MaxPooling2D(pool_size=2))
model.add(Conv2D(filters=32, kernel_size=2, padding='same', activation='relu'))
model.add(MaxPooling2D(pool_size=2))
model.add(Conv2D(filters=64, kernel_size=2, padding='same', activation='relu'))
model.add(MaxPooling2D(pool_size=2))
model.add(Dropout(0.3))
model.add(Flatten())
model.add(Dense(500, activation='relu'))
model.add(Dropout(0.4))
model.add(Dense(10, activation='softmax'))

model.summary()

6. 编译模型 

# compile the model
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])

7. 训练模型

from keras.callbacks import ModelCheckpoint   

# 训练模型
checkpointer = ModelCheckpoint(filepath='model.weights.best.hdf5', verbose=1, save_best_only=True)

hist = model.fit(x_train, y_train, batch_size=32, epochs=100,
          validation_data=(x_valid, y_valid), callbacks=[checkpointer], 
          verbose=2, shuffle=True)

8. 加载验证精度最高的模型

# 加载验证精度最高的权重
model.load_weights('model.weights.best.hdf5')

 9. 计算测试集的分类精度

# 评估和打印测试精度
score = model.evaluate(x_test, y_test, verbose=0)
print('\n', 'Test accuracy:', score[1])

10. 可视化一些预测

这可能会让你对网络错误分类某些对象的原因有所了解。

# 在测试集上得到预测
y_hat = model.predict(x_test)

# 定义文本标签 (source: https://www.cs.toronto.edu/~kriz/cifar.html)
cifar10_labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# 绘制测试图像的随机样本、预测标签和基本真实图像
fig = plt.figure(figsize=(20, 8))
for i, idx in enumerate(np.random.choice(x_test.shape[0], size=32, replace=False)):
    ax = fig.add_subplot(4, 8, i + 1, xticks=[], yticks=[])
    ax.imshow(np.squeeze(x_test[idx]))
    pred_idx = np.argmax(y_hat[idx])
    true_idx = np.argmax(y_test[idx])
    ax.set_title("{} ({})".format(cifar10_labels[pred_idx], cifar10_labels[true_idx]),
                 color=("green" if pred_idx == true_idx else "red"))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1275054.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

桶装水订水送水小程序具备以下主要功能

桶装水订水送水小程序具备以下主要功能: 对比传统的电话订水,订水小程序展现出显著的优势: 1. 便捷性:用户通过小程序就能轻松预订水桶,无需亲自出门,极大提升了生活的便捷度。 2. 即时性:送水…

element-ui表格滚动效果,el-table滚动条样式重置

项目首页需要展示一个表格滚动区域&#xff0c;特此来记录一下 HTML <div class"table-box" mouseenter"mouseenter" mouseleave"mouseleave"><el-table :data"tableList" border height"400px" v-loading"…

2023_Spark_实验二十四:Kafka集群环境搭建

Kafka集群环境搭建 一、环境说明 二、安装步骤 一、环境说明 目前的Kafka版本还是需要借助zookeeper来存储cluster、brokers、consumer等相关元信息&#xff0c;在当前版本即 在本案例中&#xff0c;我们采用了外部的zookeeper&#xff0c;即搭建了三节点的集群zookeeper环境…

Python Scrapy分布式爬虫

更多资料获取 &#x1f4da; 个人网站&#xff1a;ipengtao.com 在当今信息爆炸的时代&#xff0c;获取大规模数据对于许多应用至关重要。而分布式爬虫作为一种强大的工具&#xff0c;在处理大量数据采集和高效爬取方面展现了卓越的能力。 本文将深入探讨分布式爬虫的实际应用…

最新Graphviz python安装教程及使用

文章目录 Graphviz 安装python安装graphviz库 Graphviz 安装 Graphviz是一个独立的软件&#xff0c;在用python的pip下载之前&#xff0c;需要先下载软件。 网址&#xff1a;https://graphviz.org/download/ 找到合适的版本进行下载安装。记住自己的安装位置&#xff0c;完…

如何通过数据文化加速企业管理的转型升级?

#01 企业管理 更需要“转型升级” 中国企业管理在某种程度上来看&#xff0c;受到中国传统文化、社会价值观及现代化趋势等多方面影响的结果&#xff0c;比如说&#xff0c;中国传统文化强调长期思考和计划&#xff0c;这在企业管理中体现为对长期业务发展和可持续的关注。但…

小白备战蓝桥杯:Java常用API

一、什么是API 就是别人写好的一些类&#xff0c;给咱们程序员直接拿去调用即可解决问题的 我们之前接触过的Scanner和Random都是API 但java中提供的API很多&#xff0c;我们没有必要去学习所有的API&#xff0c;只需要知道一些常用的API&#xff0c;再借助帮助文档去使用AP…

蓝桥杯每日一题2023.12.1

题目描述 蓝桥杯大赛历届真题 - C 语言 B 组 - 蓝桥云课 (lanqiao.cn) 题目分析 对于此题目而言思路较为重要&#xff0c;实际可以转化为求两个数字对应的操作&#xff0c;输出最前面的数字即可 #include<bits/stdc.h> using namespace std; int main() {for(int i 1…

【前缀和]LeetCode1862:向下取整数对和

本文涉及的基础知识点 C算法&#xff1a;前缀和、前缀乘积、前缀异或的原理、源码及测试用例 包括课程视频 作者推荐 动态规划LeetCode2552&#xff1a;优化了6版的1324模式 题目 给你一个整数数组 nums &#xff0c;请你返回所有下标对 0 < i, j < nums.length 的 …

steam搬砖项目怎么赚钱?附详细拆解教程

互联网上的赚钱项目如林&#xff0c;为何他人能斩获成果&#xff0c;你却只能望洋兴叹&#xff1f;归根结底&#xff0c;还是因为你的心态过于急功近利。今天&#xff0c;我要为你揭示的Steam实体项目&#xff0c;是专门为热门竞技游戏CS:GO量身打造的。CS:GO&#xff0c;这款风…

自动机器学习AutoML

自动机器学习AutoML AutoMLAuto-SklearnAutoKerasAutoGluonGoogle AutoMLAzure自动机器学习 AutoML 模型的选择和超参数的调节等等任务对于机器学习算法的开发者来说是一件繁琐的工作&#xff0c;为了使得机器可以自动地设计模型并调优&#xff0c;自动机器学习AutoML便应运而…

打破语言障碍:跨境电商中的多语言营销策略

随着全球市场的不断扩大&#xff0c;跨境电商成为企业拓展国际业务的重要途径。然而&#xff0c;语言障碍往往成为企业在跨境电商中面临的挑战之一。为了打破这一障碍&#xff0c;实现全球市场的可持续发展&#xff0c;多语言营销策略变得至关重要。 多语言市场的挑战 在跨境电…

春秋云境:CVE-2022-32991(sql注入)

靶标介绍&#xff1a; 该CMS的welcome.php中存在SQL注入攻击。 获取登录地址 http://eci-2zeb0096que0556y47vq.cloudeci1.ichunqiu.com:80 登录注册 注册成功登录进入注册接口 参数接口一 发现接口参数q http://eci-2zeb0096que0556y47vq.cloudeci1.ichunqiu.com/welcome.p…

使用visual Studio MFC 平台实现对灰度图添加椒盐噪声,并进行均值滤波与中值滤波

平滑处理–滤波 本文使用visual Studio MFC 平台实现对灰度图添加椒盐噪声&#xff0c;并进行均值滤波与中值滤波 关于其他MFC单文档工程可参考 01-Visual Studio 使用MFC 单文档工程绘制单一颜色直线和绘制渐变颜色的直线 02-visual Studio MFC 绘制单一颜色三角形、渐变颜色边…

【openssl】Window系统如何编译openssl

本文主要记录如何编译出windows版本的openss的lib库 1.下载openssl&#xff0c;获得openssl-master.zip。 a.可以通过github&#xff08;网址在下方&#xff09;上下载最新的代码、今天是2023.12.1我用的master版本&#xff0c;下载之后恭喜大侠获得《openssl-master.zip》 …

iPhone苹果手机如何将词令网页添加到苹果iPhone手机桌面快捷打开?

iPhone苹果手机如何将词令网页添加到苹果iPhone手机桌面快捷打开&#xff1f; 1、在iPhone苹果手机上找到「Safari浏览器」,并点击打开&#xff1b; 2、打开Safari浏览器后&#xff0c;输入词令官方网站地址&#xff1a;ciling.cn ; 3、打开词令官网后&#xff0c;点击Safari…

特殊二叉树——堆

&#x1f308;一、堆的基本概念 1.堆&#xff1a;非线性结构&#xff0c;是完全二叉树 2.堆分为大堆和小堆。 大堆&#xff1a;树中任意一个父亲都大于等于孩子&#xff0c;根节点值大于等于其所有子孙节点的值。 小堆&#xff1a;树中任意一个父亲都小于等于孩子&#xff0c;…

IO / day01 作业。

1.使用fgets统计一个文件的行号 //使用fgets统计一个文件的行号#include <string.h> #include <stdlib.h> #include <stdio.h>int main(int argc, const char *argv[]) {if(argc<2) //获取文件名{printf("input error\n!");printf("usage…

Flutter页面刷新失败?看看是不是这个原因

文章目录 问题描述解决办法在控件A中定义回调函数在页面中使用控件A 原因分析回顾问题原因分析 setState使用注意事项上下文正确性异步更新避免深层嵌套避免频繁调用避免在 build 方法中调用避免在 dispose 方法中调用 问题描述 我用flutter开发了一个页面&#xff0c;页面上有…

苍穹外卖——删除购物车信息

1. 删除购物车中一个商品 1.1 产品原型 1.2 接口设计 1.3 数据模型 shopping_cart表&#xff1a; -- auto-generated definition create table shopping_cart (id bigint auto_increment comment 主键primary key,name varchar(32) null comment 商品名称…