卷积神经网络(CNN)模型 CIFAR-10 数据集 例子

news2024/12/26 3:53:21

使用 TensorFlow 构建一个简单的卷积神经网络(CNN)模型,完成对 CIFAR-10 数据集的图像分类任务。
使用自动编码器作为特征提取器,先通过自动编码器对图像数据进行降维,将图像从高维映射到低维特征空间,然后将提取的特征传入到 CNN 进行分类。
对比在不使用自动编码器特征提取的情况下,直接使用 CNN 进行分类的模型性能。

# 导入必要的库
import tensorflow as tf
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Conv2D, MaxPooling2D, Flatten

# 加载CIFAR - 10数据集
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# 预处理数据
x_train = x_train / 255.0
x_test = x_test / 255.0

# 构建自动编码器
input_img = Input(shape=(32, 32, 3))
# 编码器
encoded = Conv2D(16, (3, 3), activation='relu', padding='same')(input_img)
encoded = MaxPooling2D((2, 2), padding='same')(encoded)
# 解码器
decoded = Conv2D(16, (3, 3), activation='relu', padding='same')(encoded)
decoded = tf.keras.layers.UpSampling2D((2, 2))(decoded)
decoded = Conv2D(3, (3, 3), activation='sigmoid', padding='same')(decoded)
autoencoder = Model(input_img, decoded)

# 编译自动编码器
autoencoder.compile(optimizer='adam', loss='binary_crossentropy')

# 训练自动编码器
autoencoder.fit(x_train, x_train,
                epochs=10,
                batch_size=128,
                validation_data=(x_test, x_test))

# 获取编码器部分
encoder = Model(input_img, encoded)

# 使用编码器提取特征
x_train_encoded = encoder.predict(x_train)
x_test_encoded = encoder.predict(x_test)

# 构建CNN分类器
input_features = Input(shape=x_train_encoded.shape[1:])
flatten = Flatten()(input_features)
dense1 = Dense(128, activation='relu')(flatten)
output = Dense(10, activation='softmax')(dense1)
classifier = Model(input_features, output)

# 编译CNN分类器
classifier.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# 训练CNN分类器
classifier.fit(x_train_encoded, y_train,
               epochs=10,
               batch_size=128,
               validation_data=(x_test_encoded, y_test))

# 预测第1000个数据的类别(假设x_test是测试数据)
prediction = classifier.predict(x_test_encoded[999:1000])
predicted_class = tf.argmax(prediction, axis=1).numpy()[0]

Epoch 1/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 16ms/step - loss: 0.6075 - val_loss: 0.5594
Epoch 2/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 16ms/step - loss: 0.5575 - val_loss: 0.5567
Epoch 3/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - loss: 0.5555 - val_loss: 0.5559
Epoch 4/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - loss: 0.5547 - val_loss: 0.5551
Epoch 5/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - loss: 0.5544 - val_loss: 0.5547
Epoch 6/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - loss: 0.5538 - val_loss: 0.5534
Epoch 7/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 16ms/step - loss: 0.5520 - val_loss: 0.5527
Epoch 8/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - loss: 0.5519 - val_loss: 0.5525
Epoch 9/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - loss: 0.5517 - val_loss: 0.5523
Epoch 10/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - loss: 0.5506 - val_loss: 0.5522
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 1ms/step
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 821us/step
Epoch 1/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - accuracy: 0.2917 - loss: 1.9827 - val_accuracy: 0.4199 - val_loss: 1.6372
Epoch 2/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - accuracy: 0.4255 - loss: 1.6262 - val_accuracy: 0.4405 - val_loss: 1.5675
Epoch 3/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - accuracy: 0.4576 - loss: 1.5361 - val_accuracy: 0.4749 - val_loss: 1.5088
Epoch 4/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - accuracy: 0.4828 - loss: 1.4765 - val_accuracy: 0.5027 - val_loss: 1.4356
Epoch 5/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - accuracy: 0.4963 - loss: 1.4287 - val_accuracy: 0.5055 - val_loss: 1.4277
Epoch 6/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - accuracy: 0.5055 - loss: 1.3981 - val_accuracy: 0.5067 - val_loss: 1.4073
Epoch 7/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - accuracy: 0.5203 - loss: 1.3623 - val_accuracy: 0.5194 - val_loss: 1.3617
Epoch 8/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - accuracy: 0.5338 - loss: 1.3217 - val_accuracy: 0.5246 - val_loss: 1.3555
Epoch 9/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - accuracy: 0.5352 - loss: 1.3199 - val_accuracy: 0.5352 - val_loss: 1.3252
Epoch 10/10
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step - accuracy: 0.5379 - loss: 1.3143 - val_accuracy: 0.5144 - val_loss: 1.3900
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 49ms/step

predicted_class

8

以下是使用 Python 语言结合 TensorFlow 库构建卷积神经网络(CNN)对 CIFAR-10 数据集进行图像分类,并获取第 1000 个数据预测类别的示例代码:

import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import numpy as np

# 加载CIFAR-10数据集
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

# 归一化像素值到0-1范围
train_images, test_images = train_images / 255.0, test_images / 255.0

# 构建简单的CNN模型
model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10))  # 输出层,对应10个类别(CIFAR-10有10类)

# 编译模型
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# 训练模型
model.fit(train_images, train_labels, epochs=10, 
          validation_data=(test_images, test_labels))

# 对测试集进行预测
predictions = model.predict(test_images)
# 获取预测的类别(取概率最大的类别索引作为预测类别)
predicted_classes = np.argmax(predictions, axis=1)

# 获取第1000个数据的预测类别
print(predicted_classes[999])

Epoch 1/10
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 6ms/step - accuracy: 0.3424 - loss: 1.7848 - val_accuracy: 0.5325 - val_loss: 1.2917
Epoch 2/10
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 6ms/step - accuracy: 0.5668 - loss: 1.2166 - val_accuracy: 0.6098 - val_loss: 1.1129
Epoch 3/10
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 6ms/step - accuracy: 0.6391 - loss: 1.0275 - val_accuracy: 0.6399 - val_loss: 1.0059
Epoch 4/10
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 6ms/step - accuracy: 0.6676 - loss: 0.9352 - val_accuracy: 0.6690 - val_loss: 0.9239
Epoch 5/10
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 6ms/step - accuracy: 0.7065 - loss: 0.8364 - val_accuracy: 0.6935 - val_loss: 0.8983
Epoch 6/10
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 6ms/step - accuracy: 0.7272 - loss: 0.7781 - val_accuracy: 0.6914 - val_loss: 0.8845
Epoch 7/10
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 6ms/step - accuracy: 0.7507 - loss: 0.7195 - val_accuracy: 0.6843 - val_loss: 0.9197
Epoch 8/10
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 6ms/step - accuracy: 0.7609 - loss: 0.6772 - val_accuracy: 0.7024 - val_loss: 0.8741
Epoch 9/10
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 6ms/step - accuracy: 0.7773 - loss: 0.6337 - val_accuracy: 0.7055 - val_loss: 0.8704
Epoch 10/10
[1m1563/1563[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 6ms/step - accuracy: 0.7888 - loss: 0.5993 - val_accuracy: 0.7074 - val_loss: 0.8714
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step
8

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2265608.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SpringBoot状态机

Spring Boot 状态机(State Machine)是 Spring Framework 提供的一种用于实现复杂业务逻辑的状态管理工具。它基于有限状态机(Finite State Machine, FSM)的概念,允许开发者定义一组状态、事件以及它们之间的转换规则。…

Redis基础知识分享(含5种数据类型介绍+增删改查操作)

一、redis基本介绍 1.redis的启动 服务端启动 pythonubuntu:~$ redis-server客户端启动 pythonubuntu:~$ redis-cli <127.0.0.1:6379> exit pythonubuntu:~$ redis-cli --raw //(支持中文的启动方式) <127.0.0.1:6379> exit2.redis基本操作 ping发送给服务器…

Pytorch注意力机制应用到具体网络方法(闭眼都会版)

文章目录 以YoloV4-tiny为例要加入的注意力机制代码模型中插入注意力机制 以YoloV4-tiny为例 解释一下各个部分&#xff1a; 最左边这部分为主干提取网络&#xff0c;功能为特征提取中间这边部分为FPN&#xff0c;功能是加强特征提取最后一部分为yolo head&#xff0c;功能为获…

交通控制系统中的 Prompt工程:引导LLMs实现高效交叉口管理 !

本研究提出了一种新型的交通控制系统方法&#xff0c;通过使用大型语言模型&#xff08;LLMs&#xff09;作为交通控制器。该研究利用它们的逻辑推理、场景理解和决策能力&#xff0c;实时优化通行能力并提供基于交通状况的反馈。LLMs将传统的分散式交通控制过程集中化&#xf…

产品升级!Science子刊同款ARGs-HOST分析,get!

凌恩生物明星chanpin 抗性宏基因-宿主分析 Science子刊同款分析 数据挖掘更进一步&#xff01; 抗生素的大量使用与滥用使微生物体内编码抗生素抗性的基因在环境中选择性富集&#xff0c;致病菌通过基因突变或者水平基因转移获得抗生素抗性基因后&#xff0c;导致抗生素治疗…

Python8-写一些小作业

记录python学习&#xff0c;直到学会基本的爬虫&#xff0c;使用python搭建接口自动化测试就算学会了&#xff0c;在进阶webui自动化&#xff0c;app自动化 python基础8-灵活运用顺序、选择、循环结构 写一些小练习题目1、给一个半径&#xff0c;求圆的面积和周长&#xff0c;…

四相机设计实现全向视觉感知的开源空中机器人无人机

开源空中机器人 基于深度学习的OmniNxt全向视觉算法OAK-4p-New 全景硬件同步相机 机器人的纯视觉避障定位建图一直是个难题&#xff1a; 系统实现复杂 纯视觉稳定性不高 很难选到实用的视觉传感器 为此多数厂家还是采用激光雷达的定位方案。 OAK-4p-New 为了弥合这一差距…

Diagramming AI: 使用自然语言来生成各种工具图

前言 在画一些工具图时&#xff08;流程图、UML图、架构图&#xff09;&#xff0c;你还在往画布上一个个的拖拽组件来进行绘制么&#xff1f;今天介绍一款更有效率的画图工具&#xff0c;它能够通过简单的自然语言描述来完成一个个复杂的图。 首页 进入官网之后&#xff0c;我…

springboot启动不了 因一个spring-boot-starter-web底下的tomcat-embed-core依赖丢失

这个包丢失了 启动不了 起因是pom中加入了 <tomcat.version></tomcat.version>版本指定&#xff0c;然后idea自动编译后&#xff0c;包丢了&#xff0c;删除这个配置后再也找不回来&#xff0c; 这个包正常在 <dependency><groupId>org.springframe…

“笃威尔数字技术”受邀出席2024 H-Tech Data创新情报论坛!

​ 2024年12月20日&#xff0c;以“创新情报 向新而行”为主题的2024 H-Tech Data创新情报论坛暨创新情报专业委员会成立仪式在深圳成功举办。本次大会由中国科学技术情报学会主办&#xff0c;由深圳国家高新技术产业创新中心牵头承办&#xff0c;旨在围绕技术赋能、场景应用、…

Android Studio 的革命性更新:Project Quartz 和 Gemini,开启 AI 开发新时代!

&#x1f31f; Android Studio 的革命性更新&#xff1a;Project Quartz 和 Gemini&#xff0c;开启 AI 开发新时代&#xff01; 在这个技术飞速发展的时代&#xff0c;Android 开发者们迎来了两项重大更新&#xff1a;Project Quartz 和 Gemini。这不仅仅是更新&#xff0c;而…

kkfileview代理配置,Vue对接kkfileview实现图片word、excel、pdf预览

kkfileview部署 官网&#xff1a;https://kkfileview.keking.cn/zh-cn/docs/production.html 这个是官网部署网址&#xff0c;这里推荐大家使用docker镜像部署&#xff0c;因为我是直接找运维部署的&#xff0c;所以这里我就不多说明了&#xff0c;主要说下nginx代理配置&am…

RT-DETR学习笔记(2)

七、IOU-aware query selection 下图是原始DETR。content query 是初始化为0的label embedding, position query 是通过nn.Embedding初始化的一个嵌入矩阵&#xff0c;这两部分没有任何的先验信息&#xff0c;导致DETR的收敛慢。 RT-DETR则提出要给这两部分&#xff08;conten…

iOS 苹果开发者账号: 查看和添加设备UUID 及设备数量

参考链接&#xff1a;苹果开发者账号下添加新设备UUID - 简书 如果要添加新设备到 Profiles 证书里&#xff1a; 1.登录开发者中心 Sign In - Apple 2.找到证书设置&#xff1a; Certificate&#xff0c;Identifiers&Profiles > Profiles > 选择对应证书 edit &g…

汽车IVI中控开发入门及进阶(47):CarPlay开发

概述: 车载信息娱乐(IVI)系统已经从仅仅播放音乐的设备发展成为现代车辆的核心部件。除了播放音乐,IVI系统还为驾驶员提供导航、通信、空调、电源配置、油耗性能、剩余行驶里程、节能建议和许多其他功能。 ​ 驾驶座逐渐变成了你家和工作场所之外的额外生活空间。2014年,…

Oracle、ACCSEE与TDMS的区别

Oracle、ACCSEE和TDMS都是不同类型的数据管理和存储工具&#xff0c;它们各自有独特的用途、结构和复杂性。Oracle是一个功能强大的关系型数据库管理系统&#xff0c;适用于大规模企业级应用&#xff0c;支持复杂查询和事务管理。ACCSEE主要应用于实时数据采集和过程监控&#…

商场消防电气控制系统设计(论文+源码)

1系统的功能及方案设计 如图2.1所示为本次设计的整体框图&#xff0c;其中单片机部分采用ST89C52来负责协调各个模块&#xff1b;液晶选择LCD1602液晶屏来显示信息;温度传感器选择PT1000进行温度的检测&#xff1b;烟雾传检测选择MQ2烟雾传感器&#xff1b;CO2检测选择CCS811模…

7. petalinux 根文件系统配置(package group)

根文件系统配置&#xff08;Petalinux package group&#xff09; 当使能某个软件包组的时候&#xff0c;依赖的包也会相应被使能&#xff0c;解决依赖问题&#xff0c;在配置页面的help选项可以查看需要安装的包 每个软件包组的功能: packagegroup-petalinux-audio包含与音…

2024年12月一区SCI-加权平均优化算法Weighted average algorithm-附Matlab免费代码

引言 本期介绍了一种基于加权平均位置概念的元启发式优化算法&#xff0c;称为加权平均优化算法Weighted average algorithm&#xff0c;WAA。该成果于2024年12月最新发表在中JCR1区、 中科院1区 SCI期刊 Knowledge-Based Systems。 在WAA算法中&#xff0c;加权平均位置代表当…

操作系统(23)外存的存储空间的管理

一、外存的基本概念与特点 定义&#xff1a;外存&#xff0c;也称为辅助存储器&#xff0c;是计算机系统中用于长期存储数据的设备&#xff0c;如硬盘、光盘、U盘等。与内存相比&#xff0c;外存的存储容量大、成本低&#xff0c;但访问速度相对较慢。特点&#xff1a;外存能够…