经典CNN(一):ResNet-50算法实战与解析

news2025/1/11 22:46:20
  • 🍨 本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖 原作者:K同学啊|接辅导、项目定制

 1 ResNet理论

    深度残差网络ResNet(deep residual network)在2015年由何凯明等提出,因为它简单与实用并存,随后很多研究都是建立在ResNet-50或者ResNet-101基础上完成。

    ResNet主要解决深度卷积网络在深度加深时候的”退化“问题。在一般的卷积神经网络中,增大网络深度后带来的第一个问题就是梯度消失或梯度爆炸,这个问题Szegedy提出的BN层后被顺利解决。BN层能对各层的输出做归一化,这样梯度在反向层层传递后仍能保持大小稳定,不会出现过大或过小的情况。但是作者发现加了BN层后再加大深度仍然不容易收敛,其提到了第二个问题--准确率下降问题:层级大到一定程度时准确率就会饱和,然后迅速下降,这种下降既不是梯度消失引起的,也不是过拟合造成的,而是由于网络过于复杂,以至于光靠不加约束的放养式的训练很难达到理想的准确率。

    准确率下降问题不是网络结构本身的问题,而是现有的训练方式不够理想造成的。当前广泛使用的优化器,无论是SGD,还是RMSProp,或是Adam,都无法在网络深度变大后达到理论上最优的收敛结果。

    作者在文中证明了只要有合适的网络结构,更深的网络肯定会比较浅的网络效果好。证明过程也很简单:假设在一种网络A的后面添加几层形成新的网络B,如果增加的层级只是对A的输出做了个恒等映射(identity mapping),即A的输出经过新增的层级变成B的输出后没有发生变化,这样网络A和网络B的错误率就是相等的,也就证明了加深后的网络不会比加深前的网络效果差。

图1 残差模块​​​​

     何凯明提出了一种残差结构来实现上述恒等映射(如上图所示):整个模块除了正常的卷积层输出外,还有一个分支把输入直接连到输出上,该分支输出和卷积的输出做算术相加得到最终的输出,用公式表达就是H(x)=F(x)+x,其中x是输入,F(x)是卷积分支的输出,H(x)是整个结构的输出。可以证明如果F(x)分支中所有参数都是0,H(x)=x,即H(x)与x为恒等映射。残差结构是人为的制造了恒等映射,能让整个结构朝着恒等映射的方向去收敛,确保最终的错误率不会因为深度的变大而越来越差。如果一个网络通过简单的手工设置参数值就能达到想要的结果,那这种结构就很容易通过训练来收敛到该结果,这是一条设计复杂的网络时通用的规则。

图2 两种残差模块

    图2 左边的单元为ResNet两层的残差单元,两层的残差单元包含两个相同输出通道数的3*3卷积,只是用于较浅的ResNet网络,对较深的网络主要使用三层的残差单元。三层的残差单元又称为bottleneck结构,先用一个1*1卷积进行降维,最后用1*1升维恢复原有的维度。另外,如果有输入输出维度不同的情况,可以对输入做一个线性映射变换维度,再连接后面的层。三层的残差单元对于相同数量的层又减少了参数量,因此可以拓展更深的模型。通过残差单元的组合有经典的ResNet-50,ResNet-101等网络结构。

2 前期工作

2.1 开发环境

电脑系统:ubuntu16.04

编译器:Jupter Lab

语言环境:Python 3.7

深度学习环境:tensorflow

2.2 设置GPU

    如果设备上支持GPU就使用GPU,否则注释掉这部分代码

import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True) # 设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpus[0]], "GPU")

2.3 导入数据并查看数据

import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号

import os, PIL, pathlib
import numpy as np

from tensorflow import keras
from tensorflow.keras import layers,models

data_dir = "../data/bird_photos"
data_dir = pathlib.Path(data_dir)

image_count = len(list(data_dir.glob('*/*')))
print("图片总数为:", image_count)

3 数据预处理

    数据集中的种类分别为Bananaquit、Black Skimmer、Black Throated Bushtiti、Cockatoo,他们的数量分别为下表所示:

文件夹数量
Bananaquit166
Black Skimmer111
Black Throated Bushtiti122
Cockatoo166

3.1 加载数据

    使用image_dataset_from_directory方法将磁盘中的数据加载到tf.data.Dataset中。同时,我们可以通过class_names输出数据集的标签。标签将按字母顺序对应于目录名称。

batch_size = 8
img_height = 224
img_width = 224

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

class_Names = train_ds.class_names
print("class_Names:",class_Names)

    结果输出如下所示:

3.2 可视化数据

plt.figure(figsize=(10, 5)) # 图形的宽为10,高为5
plt.suptitle("imshow data")

for images,labels in train_ds.take(1):
    for i in range(8):
        ax = plt.subplot(2, 4, i+1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_Names[labels[i]])
        plt.axis("off")

    结果输出如下所示: 

 

     单独查看index为1的图像,结果如下所示,与上图结果一致。

plt.imshow(images[1].numpy().astype("uint8"))

 3.3 再次检查数据

for image_batch, lables_batch in train_ds:
    print(image_batch.shape)
    print(lables_batch.shape)
    break

     其中:

① Image_batch是形状的张量(8,224,224,3),这是一批形状为240*240*3的8张图片,最后一维的3是指彩色3通道RGB;

② label_batch是形状(8,)的张量,是这8张图片对应的标签。

3.4 配置数据集

AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

    其中:

① shuffle():打乱数据

② prefetch():预取数据,加速运行

③ cache():将数据集缓存到内存当中,加速运行

4 残差网络(ResNet)介绍

4.1 残差网络解决了什么

    残差网络是为了解决神经网络隐藏层过多时,而引起的网络退化问题。退化(degradation)问题是指:当网络隐藏层变多时,网络的准确度达到饱和,然后急剧退化,而且这个退化不是由于过拟合引起的。

拓展:深度神经网络的“两朵乌云”

  • 梯度弥散/爆炸

简单来讲就是网络太深了,会导致模型训练难以收敛。这个问题可以被标准初始化和中间层正规化的方法有效控制。

  • 网络退化

随着网络深度增加,网络的表现先是逐渐增加至饱和,然后迅速下降,这个退化不是由于过拟合而引起的。

4.2 ResNet-50介绍

     ResNet-50有两个基本的块,分别名为Conv_Block和Identity Block,其网络结果如下图所示,左边是ResNet-50的整体网络结构,中间是Conv Block的网络结构,右边是Identity Block的网络结构,ResNet-50中包含多个Conv Block和Identity Block的不同组合。

image.png

 5 构建ResNet-50网络模型

    此为本文重点,按照上图构建ResNet-50.

from keras import layers

from keras.layers import Input,Activation,BatchNormalization,Flatten
from keras.layers import Dense,Conv2D,MaxPooling2D,ZeroPadding2D,AveragePooling2D
from keras.models import Model

def identity_block(input_tensor, kernel_size, filters, stage, block):
    filters1, filters2, filters3 = filters
    
    name_base = str(stage) + block + '_identity_block_'
    
    x = Conv2D(filters1, (1, 1), name=name_base + 'conv1')(input_tensor)
    x = BatchNormalization(name=name_base+'bn1')(x)
    x = Activation('relu', name=name_base+'relu1')(x)
    
    x = Conv2D(filters2, kernel_size, padding='same', name=name_base + 'conv2')(x)
    x = BatchNormalization(name=name_base+'bn2')(x)
    x = Activation('relu', name=name_base+'relu2')(x)
    
    x = Conv2D(filters3, (1, 1), name=name_base + 'conv3')(x)
    x = BatchNormalization(name=name_base+'bn3')(x)
    
    x = layers.add([x, input_tensor], name=name_base+'add')
    x = Activation('relu', name=name_base+'relu4')(x)
    return x

def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2,2)):
    filters1, filters2, filters3 = filters
    
    res_name_base = str(stage) + block + '_conv_block_res_'
    name_base = str(stage) + block + '_conv_block_'
    
    x = Conv2D(filters1, (1, 1), strides=strides, name=name_base + 'conv1')(input_tensor)
    x = BatchNormalization(name=name_base+'bn1')(x)
    x = Activation('relu', name=name_base+'relu1')(x)
    
    x = Conv2D(filters2, kernel_size, padding='same', name=name_base + 'conv2')(x)
    x = BatchNormalization(name=name_base+'bn2')(x)
    x = Activation('relu', name=name_base+'relu2')(x)
    
    x = Conv2D(filters3, (1, 1), name=name_base + 'conv3')(x)
    x = BatchNormalization(name=name_base+'bn3')(x)
    
    shortcut = Conv2D(filters3, (1, 1), strides=strides, name=res_name_base + 'conv')(input_tensor)
    shortcut = BatchNormalization(name=res_name_base+'bn')(shortcut)
    
    x = layers.add([x, shortcut], name=name_base+'add')
    x = Activation('relu', name=name_base+'relu4')(x)
    return x

def ResNet50(input_shape=[224,224,3], classes=1000):
    img_input = Input(shape=input_shape)
    x = ZeroPadding2D((3,3))(img_input)
    
    x = Conv2D(64, (7, 7), strides=(2,2), name='conv1')(x)
    x = BatchNormalization(name='bn_conv1')(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((3,3), strides=(2,2))(x)
    
    x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1,1))
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
    
    x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
    
    x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
    
    x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
    
    x = AveragePooling2D((7, 7), name='avg_pooling')(x)    
    x = Flatten()(x)
    
    x = Dense(classes, activation='softmax', name='fc1000')(x)
    
    model = Model(img_input, x, name='resnet50')
    
    # 加载预训练模型
    model.load_weights("resnet50_weights_tf_dim_ordering_tf_kernels.h5")
    
    return model

model = ResNet50()
model.summary()

    运行结果如下所示(由于输出结果太长,只截取最前面和最后面部分内容):

(中间部分省略)

6 编译

    在对模型进行训练之前,还需要对其设置,包括:

  • 损失函数(loss):用于衡量模型在训练期间的准确率
  • 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
  • 指标(metrics):用于监控训练和测试步骤。下面的代码使用了准确率,即被正确分类的图像的比率。
# 设置优化器
opt = tf.keras.optimizers.Adam(learning_rate=1e-7)

model.compile(optimizer="adam",
             loss='sparse_categorical_crossentropy',
             metrics=['accuracy'])

 7 训练模型

epochs = 10

history = model.fit(
                train_ds,
                validation_data=val_ds,
                epochs=epochs)

    结果显示如下。设置epochs为10,训练集和测试集的准确率在第7个epoch效果最好,分别为99.34%和93.81%。

8 模型评估

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.suptitle("ResNet test")

plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation loss')
plt.legend(loc='upper right')
plt.title('Training and Validation loss')
plt.show()

    结果显示如下:

 9 预测

# 采用加载的模型(new_model)来看预测结果
plt.figure(figsize=(10, 5)) # 图形的宽为10,高为5
plt.suptitle('ResNet test')

for images, labels in val_ds.take(1):
    for i in range(8):
        ax = plt.subplot(2, 4, i+1)
        
        # 显示图片
        plt.imshow(images[i].numpy().astype("uint8"))
        
        # 需要给图片增加一个维度
        img_array = tf.expand_dims(images[i], 0)
        
        # 使用模型预测图片中的鸟类
        predictions = model.predict(img_array)
        plt.title(class_Names[np.argmax(predictions)])
        
        plt.axis("off")

    结果显示如下。由于训练的还不够,在测试的两个 Cockatoo被误判为了Black Skimmer。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/751702.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Hutool工具类 -集常用工具类为一体 - 工具类之大成

文章目录 说在前面的话简介gitee介绍项目介绍 网址gtiee 网址github 网址 安装pom依赖引入 :下载jar 文档中文文档中文备用文档参考API视频介绍 部分截图首页包含组件(总)IO流相关部分工具类(Util)集合类HTTP客户端 功能不再一一赘述和截图,具体请查看官…

详解TCP协议

TCP协议段格式 序号和确认序号:在真实服务器和客服端通信过程中请求是并行执行的,这会导致到达是乱序的,所以才会有序号这个东西,确认序号是对方应答时返回的,例如序号发送到1,确认序号会返回2,…

计算机网络 day6 arp病毒 - ICMP协议 - ping命令 - Linux手工配置IP地址

目录 arp协议 arp病毒\欺骗 arp病毒的运行原理 arp病毒产生的后果: 解决方法: ICMP协议 ICMP用在哪里? ICMP协议数据的封装过程 ​编辑 为什么icmp协议封装好数据后,还要加一个ip包头,再使用ip协议再次进…

springboot农机电招平台

本系统为了数据库结构的灵活性所以打算采用MySQL来设计数据库,而java技术,B/S架构则保证了较高的平台适应性。本文主要介绍了本系统的开发背景,所要完成的功能和开发的过程,主要说明了系统设计的重点、设计思想。 本系统主要是设…

关于java垃圾回收的小结

一、为什么要有垃圾回收 我们每次创建对象都需要在栈上开辟空间,堆上使用内存,如果我们只是开辟了这个空间,而不去释放他,那么再大的内存和空间也会有满的一天,所以我们在Java中引入了GC(垃圾回收机制&…

Foxit PDF ActiveX 5.9.8 Crack

Foxit PDF SDK ActiveX 即时添加PDF显示功能至Windows应用程序,快速投放市场,可视化编程组件功能强大且易于使用的PDF软件开发工具包 对于刚接触PDF或不愿投入过多精力学习PDF技术的产品管理者及开发者来说,Foxit PDF SDK ActiveX无疑是理想…

中国1km分辨率逐月平均气温数据集(1901-2022)

时间分辨率月空间分辨率1km - 10km共享方式开放获取数据大小9.71 GB数据时间范围 1901.1-2022.12 数据集摘要 该数据为中国逐月平均温度数据,空间分辨率为0.0083333(约1km),时间为1901.1-2022.12。数据格式为NETCDF,即.nc格式。数据单位为0.1 ℃。该数据集是根据CRU发布的…

对Vue组件化开发思想的一些理解

目录 组件的分类 为什么需要组件化开发 如何设计组件 组件间通信 组件系统是 Vue的一个重要概念,让我们可以用独立可复用的小组件来构建大型应用。几乎任意类型的应用的界面都可以抽象为一个组件树: 写一个 Vue 项目,其实就是在写一个个的…

接口测试 react+unittest+flask 接口自动化测试平台

目录 1 前言 2 框架 2-1 框架简介 2-2 框架介绍 2-3 框架结构 3 平台 3-1 平台组件图 1 新建用例 2 生成测试任务 3 执行并查看测试报告 3-2 用例管理 3-2-1 用例设计 3-3 任务管理 3-3-1 创建任务 3-3-2 执行任务 3-3-3 测试报告 3-3-4 邮件通知 1 前言 构建…

idea新建xml模板设置,例如:mybatis-config

在idea怎么新建mapper.xml文件&#xff0c;具体操作步骤和结果如下&#xff0c;其他文件也是可以自定义模板的流程和步骤一致&#xff01; 效果如下&#xff1a; 步骤如图&#xff1a; step1&#xff1a; step2&#xff1a; 文件内容&#xff1a; <?xml version"…

Android.mk 文件使用解析

和你一起终身学习&#xff0c;这里是程序员Android 经典好文推荐&#xff0c;通过阅读本文&#xff0c;您将收获以下知识点: 一、Android.mk 简介二、Android.mk 的基本格式三、Android.mk 深入学习一四、 Android.mk 深入学习二五、 Android.mk 深入学习三六、 Android.mk 判断…

C++【哈希表的模拟实现】

✨个人主页&#xff1a; 北 海 &#x1f389;所属专栏&#xff1a; C修行之路 &#x1f383;操作环境&#xff1a; Visual Studio 2019 版本 16.11.17 文章目录 &#x1f307;前言&#x1f3d9;️正文1、模拟实现哈希表&#xff08;闭散列&#xff09;1.1、存储数据结构的定义1…

MySQL函数以及存储过程

创建表并插入数据‘ 字段名 数据类型 主键 外键 非空 唯一 自增 id INT 是 否 是 是 否 name VARCHAR(50) 否 否 是 否 否 glass VARCHAR(50) 否 否 是 否 否 sch 表内容 id name glass 1 xiaommg glass 1 2 xiaojun glass 2 mysql> select * from sch; -------------------…

火车头采集器下载中文图片地址报错:发生错误终止..

火车头采集器下载中文图片地址报错&#xff1a;发生错误终止.. 报错信息 该问题时网友发现的&#xff0c;采集的内容中图片URL地址包含中文字符。 然后在采集内容时火车头自动下载图片就提示&#xff1a;发生错误终止&#xff0c;远程服务器返回错误&#xff1a;&#xff08…

MySQL 主从延迟的常见原因及解决方法

主从延迟作为 MySQL 的痛点已经存在很多年了&#xff0c;以至于大家都有一种错觉&#xff1a;有 MySQL 复制的地方就有主从延迟。 对于主从延迟的原因&#xff0c;很多人将之归结为从库的单线程重放。 但实际上&#xff0c;这个说法比较片面&#xff0c;因为很多场景&#xf…

我司的短信接口被刷了

如何发现的 成本分摊系统&#xff0c;将成本分摊给业务部门时&#xff0c;业务部门对账&#xff0c;发现某一类型的短信用量上涨了100多倍 排查调用来源时&#xff0c;发现来源为C端用户&#xff0c;由于调用量异常高&#xff0c;业务反馈近期无活动&#xff0c;因此怀疑被刷…

GAMES101 作业0

Visual Studio 2019下环境配置 课上提供的环境是Linux, 还需要安装Vitrual Box和创建虚拟机&#xff0c;省事就直接在Windows系统下Visual Studio下操作了。 简单的环境配置&#xff1a; 下载Eigen 的库在工程属性中添加目录&#xff1a; 2处地方 注意&#xff1a; 刚添加完…

CONTAINER = ALL是ALTER USER语句的默认值

连接到root时查看有关root&#xff0c;CDB和PDB的数据 当公用用户执行查询时&#xff0c;可以限制X $表和V $&#xff0c;GV $和CDB_ *视图的视图信息。X$表和这些视图包含有关应用程序root及其关联应用程序PDB的信息&#xff0c;或者如果连接到CDB root&#xff0c;则是整个C…

基于非支配排序遗传算法NSGAII的综合能源优化调度(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

松鼠回家(最短路+二分)

D-松鼠回家_2023河南萌新联赛第&#xff08;一&#xff09;场&#xff1a;河南农业大学 (nowcoder.com) #include<bits/stdc.h> using namespace std; #define int long long const int N2e510; map<int,int>a; int n,m,st,ed,h; struct node{int x,y; }; vector&l…