第T2周:彩色图片分类

news2025/1/11 20:53:26
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

👉 要求:

  • 学习如何编写一个完整的深度学习程序
  • 了解分类彩色图片会灰度图片有什么区别
  • 测试集accuracy到达72%

🦾我的环境:

  • 语言环境:Python3.8
  • 编译器:Jupyter Lab
  • 深度学习环境:
    • TensorFlow2

一、 前期准备

1.1. 设置GPU

  • 如果设备上支持GPU就使用GPU,否则使用CPU
  • Mac上的GPU使用mps
import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPU
    tf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpu0],"GPU")

gpu0
PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')

1.2. 导入数据

使用dataset下载MNIST数据集,并划分好训练集与测试集

import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt

(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
A local file was found, but it seems to be incomplete or outdated because the auto file hash does not match the original value of 6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce so we will re-download the data.
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170498071/170498071 [==============================] - 8500s 50us/step

1.3. 归一化

数据归一化作用

● 使不同量纲的特征处于同一数值量级,减少方差大的特征的影响,使模型更准确。
● 加快学习算法的收敛速度。

更详解的介绍请参考文章:🔗归一化与标准化

# 将像素的值标准化至0到1的区间内。(对于灰度图片来说,每个像素最大值是255,每个像素最小值是0,也就是直接除以255就可以完成归一化。)
train_images, test_images = train_images / 255.0, test_images / 255.0

# 查看数据维数信息
train_images.shape,test_images.shape,train_labels.shape,test_labels.shape
((50000, 32, 32, 3), (10000, 32, 32, 3), (50000, 1), (10000, 1))

1.4. 可视化图片

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck']

plt.figure(figsize=(20,10))
for i in range(20):
    plt.subplot(5,10,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i], cmap=plt.cm.binary)
    plt.xlabel(class_names[train_labels[i][0]])
plt.show()

在这里插入图片描述

二、构建简单的CNN网络

⭐池化层

池化层对提取到的特征信息进行降维,一方面使特征图变小,简化网络计算复杂度;另一方面进行特征压缩,提取主要特征,增加平移不变性,减少过拟合风险。但其实池化更多程度上是一种计算性能的一个妥协,强硬地压缩特征的同时也损失了一部分信息,所以现在的网络比较少用池化层或者使用优化后的如SoftPool。

池化层包括最大池化层(MaxPooling)和平均池化层(AveragePooling),均值池化对背景保留更好,最大池化对纹理提取更好)。同卷积计算,池化层计算窗口内的平均值或者最大值。例如通过一个 2*2 的最大池化层,其计算方式如下:
在这里插入图片描述

我们即将构建模型的结构图,我以分别二维和三维的形式展示出来方便大家理解。

  • 平面结构图
    在这里插入图片描述

  • 立体结构图
    在这里插入图片描述

model = models.Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)), #卷积层1,卷积核3*3
    layers.MaxPooling2D((2, 2)),                   #池化层1,2*2采样
    layers.Conv2D(64, (3, 3), activation='relu'),  #卷积层2,卷积核3*3
    layers.MaxPooling2D((2, 2)),                   #池化层2,2*2采样
    layers.Conv2D(64, (3, 3), activation='relu'),  #卷积层3,卷积核3*3
    
    layers.Flatten(),                      #Flatten层,连接卷积层与全连接层
    layers.Dense(64, activation='relu'),   #全连接层,特征进一步提取
    layers.Dense(10)                       #输出层,输出预期结果
])

model.summary()  # 打印网络结构
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d (Conv2D)             (None, 30, 30, 32)        896       
                                                                 
 max_pooling2d (MaxPooling2  (None, 15, 15, 32)        0         
 D)                                                              
                                                                 
 conv2d_1 (Conv2D)           (None, 13, 13, 64)        18496     
                                                                 
 max_pooling2d_1 (MaxPoolin  (None, 6, 6, 64)          0         
 g2D)                                                            
                                                                 
 conv2d_2 (Conv2D)           (None, 4, 4, 64)          36928     
                                                                 
 flatten (Flatten)           (None, 1024)              0         
                                                                 
 dense (Dense)               (None, 64)                65600     
                                                                 
 dense_1 (Dense)             (None, 10)                650       
                                                                 
=================================================================
Total params: 122570 (478.79 KB)
Trainable params: 122570 (478.79 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


2024-06-23 22:16:01.054779: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M2
2024-06-23 22:16:01.054802: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 16.00 GB
2024-06-23 22:16:01.054811: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 5.33 GB
2024-06-23 22:16:01.054984: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:303] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-06-23 22:16:01.055316: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:269] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)

三、编译模型

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

四、训练模型

history = model.fit(train_images, train_labels, epochs=10, 
                    validation_data=(test_images, test_labels))
Epoch 1/10


2024-06-23 22:16:41.825293: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


1563/1563 [==============================] - ETA: 0s - loss: 1.5781 - accuracy: 0.4242

2024-06-23 22:16:54.304550: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


1563/1563 [==============================] - 13s 8ms/step - loss: 1.5781 - accuracy: 0.4242 - val_loss: 1.3528 - val_accuracy: 0.5133
Epoch 2/10
1563/1563 [==============================] - 12s 8ms/step - loss: 1.2892 - accuracy: 0.5464 - val_loss: 1.2880 - val_accuracy: 0.5617
Epoch 3/10
1563/1563 [==============================] - 12s 8ms/step - loss: 1.3585 - accuracy: 0.5521 - val_loss: 1.6484 - val_accuracy: 0.5155
Epoch 4/10
1563/1563 [==============================] - 12s 8ms/step - loss: 2.0448 - accuracy: 0.5044 - val_loss: 3.0545 - val_accuracy: 0.4380
Epoch 5/10
1563/1563 [==============================] - 12s 8ms/step - loss: 5.7139 - accuracy: 0.4563 - val_loss: 20.7035 - val_accuracy: 0.2908
Epoch 6/10
1563/1563 [==============================] - 12s 8ms/step - loss: 45.9029 - accuracy: 0.3672 - val_loss: 109.2576 - val_accuracy: 0.3624
Epoch 7/10
1563/1563 [==============================] - 12s 8ms/step - loss: 504.0281 - accuracy: 0.2838 - val_loss: 1375.9681 - val_accuracy: 0.2399
Epoch 8/10
1563/1563 [==============================] - 12s 8ms/step - loss: 3719.2263 - accuracy: 0.2359 - val_loss: 6212.4688 - val_accuracy: 0.2268
Epoch 9/10
1563/1563 [==============================] - 12s 8ms/step - loss: 11472.0957 - accuracy: 0.2238 - val_loss: 20005.8828 - val_accuracy: 0.1773
Epoch 10/10
1563/1563 [==============================] - 12s 8ms/step - loss: 25618.4004 - accuracy: 0.2182 - val_loss: 31095.4336 - val_accuracy: 0.2160

五、预测

通过模型进行预测得到的是每一个类别的概率,数字越大该图片为该类别的可能性越大

plt.imshow(test_images[1])

在这里插入图片描述

输出测试集中第一张图片的预测结果

import numpy as np

pre = model.predict(test_images)
print(class_names[np.argmax(pre[1])])
 75/313 [======>.......................] - ETA: 0s

2024-06-23 22:20:12.257425: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


313/313 [==============================] - 1s 3ms/step
ship

六、模型评估

import matplotlib.pyplot as plt

plt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['val_accuracy'], label = 'val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.ylim([0.5, 1])
plt.legend(loc='lower right')
plt.show()

test_loss, test_acc = model.evaluate(test_images,  test_labels, verbose=2)

在这里插入图片描述

print(test_acc)
0.6845156432345124

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1854507.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

QT事件处理系统之五:自定义事件的发送案例 sendEvent和postEvent接口

1、案例 双击窗口,会发送 自定义事件,然后在事件过滤中心进行拦截处理自定义事件。 2、核心代码 /*解释:双击窗口时,将产生双击事件,然后该事件被包裹成一个对象,随后将会被发往event事件中心,然后进行事件的处理(Widget对象);因为m_lineEdit开启了事件过滤机制,所…

【UML用户指南】-21-对基本行为建模-活动图

目录 1、概念 2、组成结构 2.1、动作 2.2、活动节点 2.3、控制流 2.4、分支 2.5、分岔和汇合 2.6、泳道 2.7、对象流 2.8、扩展区域 3、一般用法 3.1、对工作流建模 3.2、对操作建模 一个活动图从本质上说是一个流程图&#xff0c;展现从活动到活动的控制流 活动图…

图像编辑技术的新篇章:基于扩散模型的综述

在人工智能的浪潮中&#xff0c;图像编辑技术正经历着前所未有的变革。随着数字媒体、广告、娱乐和科学研究等领域对高质量图像编辑需求的不断增长&#xff0c;传统的图像编辑方法已逐渐无法满足日益复杂的视觉内容创作需求。尤其是在AI生成内容&#xff08;AIGC&#xff09;的…

【论文复现|智能算法改进】一种基于多策略改进的鲸鱼算法

目录 1.算法原理2.改进点3.结果展示4.参考文献5.代码获取 1.算法原理 SCI二区|鲸鱼优化算法&#xff08;WOA&#xff09;原理及实现【附完整Matlab代码】 2.改进点 混沌反向学习策略 将混沌映射和反向学习策略结合&#xff0c;形成混沌反向学习方法&#xff0c;通过该方 法…

三十八篇:架构大师之路:探索软件设计的无限可能

架构大师之路&#xff1a;探索软件设计的无限可能 1. 引言&#xff1a;架构的艺术与科学 在软件工程的广阔天地中&#xff0c;系统架构不仅是设计的骨架&#xff0c;更是灵魂所在。它如同建筑师手中的蓝图&#xff0c;决定了系统的结构、性能、可维护性以及未来的扩展性。本节…

LSSS算法实现,基于eigen和pbc密码库【一文搞懂LSSS,原理+代码】

文章目录 一. LSSS简介1.1 概述1.2 线性秘密分享方案&#xff08;LSSS&#xff09;与 Shamir的秘密分享方案对比LSSS1.2.1 Shamir的秘密分享方案1.2.2 线性秘密分享方案&#xff08;LSSS&#xff09;1.2.3 主要区别 二. 基于矩阵的LSSS加解密原理分析2.1 LSSS矩阵构造2.1.1 定义…

【python】python基于微博互动数据的用户类型预测(随机森林与支持向量机的比较分析)(源码+数据集+课程论文)【独一无二】

&#x1f449;博__主&#x1f448;&#xff1a;米码收割机 &#x1f449;技__能&#x1f448;&#xff1a;C/Python语言 &#x1f449;公众号&#x1f448;&#xff1a;测试开发自动化【获取源码商业合作】 &#x1f449;荣__誉&#x1f448;&#xff1a;阿里云博客专家博主、5…

群体优化算法---电磁共振优化算法(EROA)介绍包含示例滤波器设计

介绍 电磁共振优化算法&#xff08;Electromagnetic Resonance Optimization Algorithm, EROA&#xff09;是一种新型的元启发式优化算法&#xff0c;其灵感来源于电磁共振现象。电磁共振是一种物理现象&#xff0c;当一个系统在特定频率下响应最大时&#xff0c;这个频率被称…

【算法】优先级队列-基础与应用

优先级队列&#xff08;Priority Queue&#xff09;是一种特殊的队列类型&#xff0c;它允许在其元素中分配优先级。与传统的先进先出&#xff08;FIFO&#xff09;队列不同&#xff0c;优先级队列中元素的出队顺序取决于它们的优先级。优先级较高的元素会被优先处理&#xff0…

LLaMA:挑战大模型Scaling Law的性能突破

实际问题 在大模型的研发中,通常会有下面一些需求: 计划训练一个10B的模型,想知道至少需要多大的数据?收集到了1T的数据,想知道能训练一个多大的模型?老板准备1个月后开发布会,给的资源是100张A100,应该用多少数据训多大的模型效果最好?老板对现在10B的模型不满意,想…

完美解决找不到steam_api64.dll无法执行代码问题

游戏缺失steam_api64.dll通常意味着该游戏依赖于Steam平台的一些功能或服务&#xff0c;而这个DLL文件是Steam客户端的一部分&#xff0c;用于游戏与Steam平台之间的交互。如果游戏中缺失这个文件&#xff0c;可能会出现无法启动、崩溃或其他问题。 一&#xff0c;详细了解stea…

数据结构8---查找

一、静态查找和动态查找 &#xff08;一&#xff09;静态查找 静态查找:数据集合稳定&#xff0c;不需要添加,删除元素的查找操作。 对于静态查找来说&#xff0c;我们不妨可以用线性表结构组织数据&#xff0c;这样可以使用顺序查找算法&#xff0c;如果我们再对关键字进行…

XSS跨站攻击漏洞

XSS跨站攻击漏洞 一 概述 1 XSS概述 xss全称为&#xff1a;Cross Site Scripting&#xff0c;指跨站攻击脚本&#xff0c;XSS漏洞发生在前端&#xff0c;攻击的是浏览器的解析引擎&#xff0c;XSS就是让攻击者的JavaScript代码在受害者的浏览器上执行。 XSS攻击者的目的就是…

【从0实现React18】 (三) 初探reconciler 带你初步探寻React的核心逻辑

Reconciler 使React核心逻辑所在的模块&#xff0c;中文名叫协调器&#xff0c;协调(reconciler)就是diff算法的意思 reconciler有什么用&#xff1f; 在前端框架出现之前&#xff0c;通常会使用 jQuery 这样的库来开发页面。jQuery 是一个过程驱动的库&#xff0c;开发者需要…

SLAM Paper Reading和代码解析

最近对VINS、LIO-SAM等重新进行了Paper Reading和代码解析。这两篇paper和代码大约在三年前就读过&#xff0c;如今重新读起来&#xff0c;仍觉得十分经典&#xff0c;对SLAM算法研发具有十分重要的借鉴和指导意义。重新来读&#xff0c;对其中的一些关键计算过程也获得了更新清…

java的输出流File OutputStream

一、字节输出流FileOutput Stream 1、定义 使用OutputStream类的FileOutput Stream子类向文本文件写入的数据。 2.常用构造方法 3.创建文件输出流对象的常用方式 二、输出流FileOutputStream类的应用示例 1.示例 2、实现步骤 今天的总结就到此结束啦&#xff0c;拜拜&#x…

JVM专题七:JVM垃圾回收机制

JVM专题六&#xff1a;JVM的内存模型中&#xff0c;我们介绍了JVM内存主要分哪些区域&#xff0c;这些区域分别是干什么的&#xff0c;同时也举了个例子&#xff0c;在运行过程种各个区域数据是怎样流转的。细心的小伙伴可能发现一个问题&#xff0c;在介绍完方法弹栈以后就没有…

SQL找出所有员工当前薪水salary情况

系列文章目录 文章目录 系列文章目录前言 前言 前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。点击跳转到网站&#xff0c;这篇文章男女通用&#xff0c;看懂了就去分享给你的码吧。 描述 有一个薪水表…

13.爬虫---PyMongo安装与使用

13.PyMongo安装与使用 1.安装 PyMongo2.使用PyMongo2.1连接数据库和集合2.2增加数据2.3修改数据2.4查询数据2.5删除数据 3.总结 MongoDB 安装可以看这篇文章MongoDB安装配置教程&#xff08;详细版&#xff09; 1.安装 PyMongo PyMongo 是Python中用于连接MongoDB数据库的库&a…