深度学习(五)—— 卷积神经网络(CNN)

news2024/11/26 0:57:08

卷积神经网络(CNN)

  • 1 CNN的组成
  • 2 卷积层
    • 2.1 卷积的计算
    • 2.2 多通道卷积
    • 2.3 多卷积核卷积
    • 2.4 特征图大小
    • 2.5 卷积层 api 实现
  • 3 池化层
    • 3.1 最大池化
    • 3.2 平均池化
  • 4 全连接层
  • 5 CNN的构建
    • 5.1 数据加载
    • 5.2 数据处理
    • 5.3 模型搭建
    • 5.4 模型编译
    • 5.5 模型训练
    • 5.6 模型评估

1 CNN的组成

CNN网络受人类视觉神经系统的启发,人类的视觉原理:从原始信号摄入开始(瞳孔摄入像素 Pixels),接着做初步处理(大脑皮层某些细胞发现边缘和方向),然后抽象(大脑判定,眼前的物体的形状,是圆形的),然后进一步抽象(大脑进一步判定该物体是只人脸)。下面是人脑进行人脸识别的一个示例:

在这里插入图片描述

CNN网络主要有三部分构成:

  • 卷积层
    • 提取图像中的局部特征
  • 池化层
    • 大幅降低参数量级(降维)
  • 全连接层
    • 类似人工神经网络的部分,用来输出想要的结果

整个CNN网络结构如下图所示:

在这里插入图片描述

2 卷积层

卷积层是卷积神经网络中的核心模块,卷积层的目的是提取输入特征图的特征,如下图所示,卷积核可以提取图像中的边缘信息。

在这里插入图片描述

2.1 卷积的计算

卷积运算本质上就是在滤波器和输入数据的局部区域间做点积。

在这里插入图片描述
Output 左上角点的计算方法:
在这里插入图片描述

同理可以计算其他各点,得到最终的卷积结果:

在这里插入图片描述
最后一点的计算方法是:

在这里插入图片描述

在上述卷积过程中,特征图比原始图减小了很多,我们可以在原图像的周围进行 padding,来保证在卷积过程中特征图大小不变。

在这里插入图片描述
按照步长为1来移动卷积核,计算特征图如下所示:

在这里插入图片描述

如果我们把stride增大,比如设为2,也是可以提取特征图的,如下图所示:

在这里插入图片描述

2.2 多通道卷积

实际中的图像都是多个通道组成的,我们怎么计算卷积呢?

在这里插入图片描述
计算方法如下:当输入有多个通道(channel)时(例如图片可以有 RGB 三个通道),卷积核需要拥有相同的channel数,每个卷积核 channel 与输入层的对应 channel 进行卷积,将每个 channel 的卷积结果按位相加得到最终的 Feature Map

在这里插入图片描述
输入层的通道数等于每个卷积核的通道数

2.3 多卷积核卷积

如果有多个卷积核时怎么计算呢?当有多个卷积核时,每个卷积核学习到不同的特征,对应产生包含多个 channel 的 Feature Map,例如下图有两个 filter,所以 output 有两个 channel。

在这里插入图片描述

卷积核的个数等于输出的通道数

2.4 特征图大小

输出特征图的大小与以下参数息息相关:

  • size:卷积核/过滤器大小,一般会选择为奇数,比如有1 * 1, 3 * 3, 5 * 5
  • padding:零填充的方式
  • stride:步长

计算方法如下图所示:

在这里插入图片描述
例如,输入特征图为5x5,卷积核为3x3,外加padding 为1,则其输出尺寸为:

在这里插入图片描述
如下图所示:

在这里插入图片描述

2.5 卷积层 api 实现

tf.keras.layers.Conv2D(
    filters, kernel_size, strides=(1, 1), padding='valid', 
     activation=None
)

主要参数说明如下:

在这里插入图片描述

3 池化层

池化层降低了后续网络层的输入维度,缩减模型大小,提高计算速度,并提高了Feature Map 的鲁棒性,防止过拟合,它主要对卷积层学习到的特征图进行下采样(subsampling)处理,主要由两种。

3.1 最大池化

Max Pooling,取窗口内的最大值作为输出,这种方式使用较广泛。

在这里插入图片描述

池化层最大池化的 api 实现如下:

tf.keras.layers.MaxPool2D(
    pool_size=(2, 2), strides=None, padding='valid'
)

参数:

  • pool_size:池化窗口的大小
  • strides:窗口移动的步长,默认为1
  • padding:是否进行填充,默认是不进行填充的

3.2 平均池化

Avg Pooling,取窗口内的所有值的均值作为输出

在这里插入图片描述
池化层平均池化的 api 实现如下:

tf.keras.layers.AveragePooling2D(
    pool_size=(2, 2), strides=None, padding='valid'
)

4 全连接层

全连接层位于CNN网络的末端,经过卷积层的特征提取与池化层的降维后,将特征图转换成 一维向量 送入到全连接层中进行分类或回归的操作。

在这里插入图片描述

在 tf.keras 中全连接层使用 tf.keras.dense 实现

5 CNN的构建

我们构建卷积神经网络在mnist数据集上进行处理,如下图所示:LeNet-5 是一个较简单的卷积神经网络, 输入的二维图像,先经过两次卷积层,池化层,再经过全连接层,最后使用softmax分类作为输出层。

在这里插入图片描述

5.1 数据加载

导入工具包:

import tensorflow as tf
# 数据集
from tensorflow.keras.datasets import mnist

加载数据集:

(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

5.2 数据处理

卷积神经网络的输入要求是:N H W C ,分别是图片数量,图片高度,图片宽度和图片的通道,因为是灰度图,通道为1

# 数据处理:n,h,w,c
# 训练集数据
train_images = tf.reshape(train_images, (train_images.shape[0],train_images.shape[1],train_images.shape[2], 1))
print(train_images.shape) # (60000, 28, 28, 1)
# 测试集数据
test_images = tf.reshape(test_images, (test_images.shape[0],test_images.shape[1],test_images.shape[2], 1))

5.3 模型搭建

Lenet-5模型输入的二维图像,先经过两次卷积层、池化层,再经过全连接层,最后使用 softmax 分类作为输出层,模型构建如下:

# 模型构建
net = tf.keras.models.Sequential([
    # 卷积层:6个5*5的卷积核,激活是sigmoid
    tf.keras.layers.Conv2D(filters=6,kernel_size=5,activation='sigmoid',input_shape=  (28,28,1)),
    # 最大池化
    tf.keras.layers.MaxPool2D(pool_size=2, strides=2),
    # 卷积层:16个5*5的卷积核,激活是sigmoid
    tf.keras.layers.Conv2D(filters=16,kernel_size=5,activation='sigmoid'),
    # 最大池化
    tf.keras.layers.MaxPool2D(pool_size=2, strides=2),
    # 维度调整为1维数据
    tf.keras.layers.Flatten(),
    # 全卷积层,激活sigmoid
    tf.keras.layers.Dense(120,activation='sigmoid'),
    # 全卷积层,激活sigmoid
    tf.keras.layers.Dense(84,activation='sigmoid'),
    # 全卷积层,激活softmax
    tf.keras.layers.Dense(10,activation='softmax')

])

我们通过 net.summary() 查看网络结构:

Model: "sequential_11"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_4 (Conv2D)            (None, 24, 24, 6)         156
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (None, 12, 12, 6)         0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 8, 8, 16)          2416      
_________________________________________________________________
max_pooling2d_5 (MaxPooling2 (None, 4, 4, 16)          0         
_________________________________________________________________
flatten_2 (Flatten)          (None, 256)               0         
_________________________________________________________________
dense_25 (Dense)             (None, 120)               30840     
_________________________________________________________________
dense_26 (Dense)             (None, 84)                10164     

dense_27 (Dense)             (None, 10)                850       
=================================================================
Total params: 44,426
Trainable params: 44,426
Non-trainable params: 0
______________________________________________________________

关于参数量计算:

  • conv1中的卷积核为5x5x1,卷积核个数为6,每个卷积核有一个bias,所以参数量为:5x5x1x6+6=156
  • conv2中的卷积核为5x5x6,卷积核个数为16,每个卷积核有一个bias,所以参数量为:5x5x6x16+16 = 2416

绘制模型结构图:

tf.keras.utils.plot_model(net)

在这里插入图片描述

5.4 模型编译

设置优化器和损失函数:

# 优化器
optimizer = tf.keras.optimizers.SGD(learning_rate=0.9)
# 模型编译:损失函数,优化器和评价指标
net.compile(optimizer=optimizer,
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

5.5 模型训练

# 模型训练
net.fit(train_images, train_labels, epochs=5, validation_split=0.1)# 省略verbose参数,默认verbose=1

训练流程:

Epoch 1/5
1688/1688 [==============================] - 10s 6ms/step - loss: 0.8255 - accuracy: 0.6990 - val_loss: 0.1458 - val_accuracy: 0.9543
Epoch 2/5
1688/1688 [==============================] - 10s 6ms/step - loss: 0.1268 - accuracy: 0.9606 - val_loss: 0.0878 - val_accuracy: 0.9717
Epoch 3/5
1688/1688 [==============================] - 10s 6ms/step - loss: 0.1054 - accuracy: 0.9664 - val_loss: 0.1025 - val_accuracy: 0.9688
Epoch 4/5
1688/1688 [==============================] - 11s 6ms/step - loss: 0.0810 - accuracy: 0.9742 - val_loss: 0.0656 - val_accuracy: 0.9807
Epoch 5/5
1688/1688 [==============================] - 11s 6ms/step - loss: 0.0732 - accuracy: 0.9765 - val_loss: 0.0702 - val_accuracy: 0.9807

5.6 模型评估

# 模型评估
score = net.evaluate(test_images, test_labels, verbose=1)
print('Test accuracy:', score[1])

输出为:

313/313 [==============================] - 1s 2ms/step - loss: 0.0689 - accuracy: 0.9780
Test accuracy:  0.9779999852180481

与使用全连接网络相比,准确度提高了很多。

verbose是日志显示,有三个参数可选择,分别为0,1和2

  • 当verbose=0时,简单说就是不输出日志信息 ,进度条、loss、acc这些都不输出
  • 当verbose=1(默认)时,带进度条的输出日志信息
  • 当verbose=2时,为每个epoch输出一行记录,和1的区别就是没有进度条
  • 训练和评估时都默认取值为1,但在评估时参数只有0和1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/701376.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Go语言使用net/http实现简单登录验证和文件上传功能

最近再看Go语言web编程,使用net/http模块编写了一个简单的登录验证和文件上传的功能,在此做个简单记录。 目录 1.文件目录结构 2.编译运行 3.用户登录 4.文件上传 5.mime/multipart模拟form表单上传文件 代码如下: package mainimport …

【C语言】递归实战,通过几个例子带你深入走进递归算法

君兮_的个人主页 勤时当勉励 岁月不待人 C/C 游戏开发 Hello,这里是君兮_,今天给大家带来一篇递归的实战教学文章,由于递归算法不仅对于初学者十分不易理解并且在我们以后的数据结构中也非常重要。我们今天就通过几个应用递归的实际例子来给…

Apache Doris 在头部票务平台的应用实践:报表开发提速数十倍、毫秒级查询响应

作者|国内某头部票务平台 大数据开发工程师 刘振伟 本文导读: 随着在线平台的发展,票务行业逐渐实现了数字化经营,企业可以通过在线销售、数字营销和数据分析等方式提升运营效率与用户体验。基于此,国内某头部票务平…

【Java】Java核心 81:Git 教程(4)差异比较 版本回退

文章目录 06.GIT本地操作-差异比较目标内容小结 07.GIT本地操作-版本回退目标内容小结 在Git中,可以使用差异比较命令和版本回退命令来查看文件之间的差异并回退到早期的版本。 以下是对这些操作的简要解释: 差异比较:你可以使用git diff命…

本地Linux 部署 Dashy 并远程访问

文章目录 简介1. 安装Dashy2. 安装cpolar3.配置公网访问地址4. 固定域名访问 转载自cpolar极点云文章:本地Linux 部署 Dashy 并远程访问 简介 Dashy 是一个开源的自托管的导航页配置服务,具有易于使用的可视化编辑器、状态检查、小工具和主题等功能。你…

video-04-videojs配置及使用

videojs是一种轻框架,可以帮我们快速开发一个video视频组件 目录 一、参考资料 二、引入videojs 三、简单了解使用 四、配置项和事件 4.1 常用配置项 4.2 常用事件 4.3 常用方法 4.4 网络状态 4.5 播放状态 4.6 视频控制 五、实例(可直接复制…

升级iOS 17测试版后如何降级?iOS17降级教程

对于已经升级到 iOS 17 测试版的用户,如果在体验过程中,感觉到并不是那么稳定,例如出现应用程序不适配、电池续航下降、功能无法正常启用等问题,想要进行降级操作,可以参考本教程。 降级前注意事项: 1.由于…

Android 自定义手写签字板,签署姓名,签名

各位大佬好又来记笔记了~ 今天要做的是签字板,实现客户签名功能,直接看效果: 逐个进行签字,可以避免连笔导致识别不清问题。就是想要客户一个一个写,认真写~~。 下面方框显示的“王某才” 其实是三张图片,…

【算法题】动态规划中级阶段之不同路径、最小路径和

动态规划中级阶段 前言一、不同路径1.1、思路1.2、代码实现 二、不同路径 II2.1、思路2.2、代码实现 三、最小路径和3.1、思路3.3、代码实现 总结 前言 动态规划(Dynamic Programming,简称 DP)是一种解决多阶段决策过程最优化问题的方法。它…

卸载及安装docker的教程-ubuntu

一、前言 万地高楼平地起~ 二、环境 OS:Ubuntu 20.04 64 bit 显卡:NVidia GTX 2080 Ti CUDA:11.2 三、卸载docker 1、删除docker及安装时自动安装的所有包 apt-get autoremove docker docker-ce docker-engine docker-ce-*for pkg in …

linux -信号量semphore分析

linux -信号量分析 1 struct semaphore和sema_init1.1 struct semaphore1.2 sema_init 2 down3 up4 down_interruptible5 down_killable6 down_timeout7 down_trylock 基于linux-5.15分析,信号量在使用是是基于spin lock封装实现的。 1 struct semaphore和sema_ini…

爬虫入门指南:如何使用正则表达式进行数据提取和处理

文章目录 正则表达式正则表达式中常用的元字符和特殊序列案例 使用正则表达式提取数据案例存储数据到文件或数据库使用SQLite数据库存储数据的示例代码SQLite基本语法创建表格:插入数据:查询数据:更新数据:删除数据:条…

【雕爷学编程】Arduino动手做(137)---MT8870语音解码

37款传感器与执行器的提法,在网络上广泛流传,其实Arduino能够兼容的传感器模块肯定是不止这37种的。鉴于本人手头积累了一些传感器和执行器模块,依照实践出真知(一定要动手做)的理念,以学习和交流为目的&am…

【uview calendar日历 】如何选择今天之前的数据

在日常工作中,使用uniappuview的ui组件,使用日历组件默认是无法选择当前之前的日期,现在讲下解决的方法 设置 最小的可选日期minDate,最大可选日期maxDate, 默认选中的日期,mode为multiple或range是必须为数…

自定义选项卡组件,选项可插槽html

文件夹xxtabs 四个文件 index暴露 render vue添加虚拟节点到插槽&#xff08;自定义标签结构&#xff09; tabs选项卡整体 abpaneq切换区 tabs.vue <template><div class"gnip-tab"><div class"gnip-tab-nav"><divv-for"(item,…

“sudo”组不存在”或“用户不在 sudoers 文件中。此事将被报告”

解决方法: 使用命令&#xff1a;usermod -a -G sudo tom (换成其他的用户名&#xff0c;也是一个道理)&#xff0c;不过还是不行。 实际解决还是要执行 sudo visudo &#xff0c;在这个文件中去添加用户 这样修改之后&#xff0c;保存并退出&#xff0c;亲测有效&#xff01; …

【FFmpeg实战】AAC编码介绍

AAC&#xff08;Advanced Audio Coding&#xff0c;译为&#xff1a;高级音频编码&#xff09;&#xff0c;是由Fraunhofer IIS、杜比实验室、AT&T、Sony、Nokia等公司共同开发的有损音频编码和文件格式。 对比MP3 AAC被设计为MP3格式的后继产品&#xff0c;通常在相同的比…

训练自己的ChatGPT 语言模型(一).md

0x00 Background 为什么研究这个&#xff1f; ChatGPT在国内外都受到了广泛关注&#xff0c;很多高校、研究机构和企业都计划推出类似的模型。然而&#xff0c;ChatGPT并没有开源&#xff0c;且复现难度非常大&#xff0c;即使到现在&#xff0c;没有任何单位或企业能够完全复…

Atlas200 DK A2与Arduino进行UART串口通信

我们在做一些人工智能的应用开发时往往使用人工智能开发板作为上位机&#xff08;比如我们的小滕&#xff09;&#xff0c;Arduino、stm32等作为下位机控制板&#xff0c;通过上位机进行人工智能模型的推理之后进而给下位机传输对应的控制命令实现智能控制。那么如何实现两者的…

简化交互体验——探索Gradio的ClearButton模块

❤️觉得内容不错的话&#xff0c;欢迎点赞收藏加关注&#x1f60a;&#x1f60a;&#x1f60a;&#xff0c;后续会继续输入更多优质内容❤️ &#x1f449;有问题欢迎大家加关注私戳或者评论&#xff08;包括但不限于NLP算法相关&#xff0c;linux学习相关&#xff0c;读研读博…