经典CNN(二):ResNet50V2算法实战与解析

news2025/1/10 10:30:21
  • 🍨 本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖 原作者:K同学啊|接辅导、项目定制

 1 论文解读

    在《Identity Mappings in Deep Residual Networks》中,作者何凯明先生提出了一种新的残差单元,为区别原始的ResNet结构,这里称其为ResNetV2。

1.1 ResNetV2 & ResNet之结构对比和性能对比

    上图为原始论文中的截图,展示了ResNet和ResNetV2的结构对比,以及测试结果。根据说明可知,右图的实线表示测试误差,对应右边y轴的Test Error,虚线表示训练损失,对应左边y轴的Train Loss,x轴表示迭代次数Iterations。

  • 结构调整点:ResNet(图a中的original)结构是(卷积+BN+激活+卷积+BN)+addition+激活,而ResNetV2(图b中的proposed)结构是(BN+激活+卷积+BN+激活+卷积)+addition。对比发现,两者的总模块数和类型未发生改变,只是在顺序上做了调整。
  • 结果提升:作者使用两种不同的结构在CIFAR-10数据集上做测试,模型使用1001层的RestNet模型,从右图结果可以看出,ResNetV2的测试集错误率(4.92%)明显低于原始的ResNet(7.61%)。loss方面,同一个Iteration上,ResNetV2都低于ResNet。

1.2 残差的不同尝试

    上图是论文中作者对残差结构的shortcut部分进行的不同尝试,从图示说明中得知,为简化插图,我们不显示BN层,图中所有的conv层之后都有BN层。其测试结果如下表所示,该表是使用ResNet-110在CIFAR-10测试集上的分类错误,对所有残差单元应用了不同类型的shortcut connections,当测试误差大于20%时,标注为“fail”。测试结果表明,原始的ResNet结构是最好的,即恒等映射是最好的

 1.3 激活的不同尝试

     使用不同的激活函数进行尝试,由此可见,最好的结果是full pre-activation,其次是original。

2 代码实现

2.1 开发环境

电脑系统:ubuntu16.04

编译器:Jupter Lab

语言环境:Python 3.7

深度学习环境:tensorflow

2.2 数据准备代码

    这部分代码包括设置GPU和数据处理部分,其中数据处理包括导入数据、查看数据、加载数据、可视化数据、检查数据、配置数据。

2.2.1 设置GPU

import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True) # 设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpus[0]], "GPU")

2.2.2 导入数据

import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号

import os, PIL, pathlib
import numpy as np

from tensorflow import keras
from tensorflow.keras import layers,models

data_dir = "../data/bird_photos"
data_dir = pathlib.Path(data_dir)

image_count = len(list(data_dir.glob('*/*')))
print("图片总数为:", image_count)

2.2.3 加载数据

batch_size = 8
img_height = 224
img_width = 224

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

class_Names = train_ds.class_names
print("class_Names:",class_Names)

    输出结果如下:

2.2.4 可视化数据

plt.figure(figsize=(10, 5)) # 图形的宽为10,高为5
plt.suptitle("imshow data")

for images,labels in train_ds.take(1):
    for i in range(8):
        ax = plt.subplot(2, 4, i+1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_Names[labels[i]])
        plt.axis("off")

    输出结果如下:

2.2.5 检查数据

for image_batch, lables_batch in train_ds:
    print(image_batch.shape)
    print(lables_batch.shape)
    break

    输出结果如下:

2.2.6 配置数据集

AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

2.3 ResNet50V2模型复现

2.3.1 ResNet50V2网络结构

    如下所示,最左边是ResNet50V2的网络结构,与ResNet50类似,但又比ResNet更复杂一点,包括了三个基本的Residual block,分别用蓝色、橙色、灰色块表示。右边是这三个Residual block的网络结构。仔细查看可以得出,三个Residual block左边的分支模块和顺序完全一样,分别为(BN+ReLu)+(Conv2D+BN+ReLU)+ZeroPad+(Conv2D+BN+ReLU)+Conv2D,右边分支有所差异,因此在编写代码的时候,可以共用一个函数,根据传入参数的不同而产生相应的Residual block。

 2.3.2 ResNetV2代码

import tensorflow as tf
import tensorflow.keras.layers as layers
from tensorflow.keras.models import Model


''' 残差块
Arguments:
    x: 输入张量
    filters: integer, filters, of the bottleneck layer.
    kernel_size: default 3, kernel size of the bottleneck layer.
    stride: default 1, stride of the first layer.
    conv_shortcut: default False, use convolution shortcut if True, otherwise identity shortcut.
    name: string, block label.
Returns:
    Output tensor for the residual block.
'''
def block2(x, filters, kernel_size=3, stride=1, conv_shortcut=False, name=None):
    preact = layers.BatchNormalization(name=name+'_preact_bn')(x)
    preact = layers.Activation('relu', name=name+'_preact_relu')(preact)
    
    if conv_shortcut:
        shortcut = layers.Conv2D(4*filters, 1, strides=stride, name=name+'_0_conv')(preact)
    else:
        shortcut = layers.MaxPooling2D(1, strides=stride)(x) if stride>1 else x
    
    x = layers.Conv2D(filters, 1, strides=1, use_bias=False, name=name+'_1_conv')(preact)
    x = layers.BatchNormalization(name=name+'_1_bn')(x)
    x = layers.Activation('relu', name=name+'_1_relu')(x)
    
    x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)), name=name+'_2_pad')(x)
    x = layers.Conv2D(filters, kernel_size, strides=stride, use_bias=False, name=name+'_2_conv')(x)
    x = layers.BatchNormalization(name=name+'_2_bn')(x)
    x = layers.Activation('relu', name=name+'_2_relu')(x)
    
    x = layers.Conv2D(4*filters, 1, name=name+'_3_conv')(x)
    x = layers.Add(name=name+'_out')([shortcut, x])
    return x


def stack2(x, filters, blocks, stride1=2, name=None):
    x = block2(x, filters, conv_shortcut=True, name=name+'_block1')
    for i in range(2, blocks):
        x = block2(x, filters, name=name+'_block'+str(i))
    x = block2(x, filters, stride=stride1, name=name+'_block'+str(blocks))
    return x


''' 构建ResNet50V2 '''
def ResNet50V2(include_top=True,  # 是否包含位于网络顶部的全链接层
               preact=True,  # 是否使用预激活
               use_bias=True,  # 是否对卷积层使用偏置
               weights='imagenet',
               input_tensor=None,  # 可选的keras张量,用作模型的图像输入
               input_shape=None,
               pooling=None,
               classes=1000,  # 用于分类图像的可选类数
               classifer_activation='softmax'):  # 分类层激活函数
    img_input = layers.Input(shape=input_shape)
    x = layers.ZeroPadding2D(padding=((3, 3), (3, 3)), name='conv1_pad')(img_input)
    x = layers.Conv2D(64, 7, strides=2, use_bias=use_bias, name='conv1_conv')(x)
    
    if not preact:
        x = layers.BatchNormalization(name='conv1_bn')(x)
        x = layers.Activation('relu', name='conv1_relu')(x)
    
    x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)), name='pool1_pad')(x)
    x = layers.MaxPooling2D(3, strides=2, name='pool1_pool')(x)
    
    x = stack2(x, 64, 3, name='conv2')
    x = stack2(x, 128, 4, name='conv3')
    x = stack2(x, 256, 6, name='conv4')
    x = stack2(x, 512, 3, stride1=1, name='conv5')
    
    if preact:
        x = layers.BatchNormalization(name='post_bn')(x)
        x = layers.Activation('relu', name='post_relu')(x)
    if include_top:
        x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
        x = layers.Dense(classes, activation=classifer_activation, name='predictions')(x)
    else:
        if pooling=='avg':
            # GlobalAveragePooling2D就是将每张图片的每个通道值各自加起来再求平均,
            # 最后结果是没有了宽高维度,只剩下个数与平均值两个维度
            # 可以理解成变成了多张单像素图片
            x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
        elif pooling=='max':
            x = layers.GlobalMaxPooling2D(name='max_pool')(x)
    
    model = Model(img_input, x, name='resnet50v2')
    return model
    
model = ResNet50V2(input_shape=(224,224,3))
model.summary()

    运行结果如下所示(由于输出结果太长,只截取最前面和最后面部分内容): 

 (中间部分省略)

2.4 设置loss和优化器 

     在对模型进行训练之前,还需要对其设置,包括:

损失函数(loss):用于衡量模型在训练期间的准确率
优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
指标(metrics):用于监控训练和测试步骤。下面的代码使用了准确率,即被正确分类的图像的比率。

# 设置优化器
opt = tf.keras.optimizers.Adam(learning_rate=1e-7)
 
model.compile(optimizer="adam",
             loss='sparse_categorical_crossentropy',
             metrics=['accuracy'])

2.5 训练模型

epochs = 10
 
history = model.fit(
                train_ds,
                validation_data=val_ds,
                epochs=epochs)

    模型训练时结果:

2.6 模型评估

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
 
loss = history.history['loss']
val_loss = history.history['val_loss']
 
epochs_range = range(epochs)
 
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.suptitle("ResNet test")
 
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
 
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation loss')
plt.legend(loc='upper right')
plt.title('Training and Validation loss')
plt.show()

    结果如图所示: 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/783710.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MOS,PCB如何添加散热孔、过孔

一、什么是 PCB 散热孔? 散热孔是利用贯通PCB板的通道(过孔)使热量传导到背面来散热的手法,配置在发热体的正下方或尽可能靠近发热体。 散热孔是利用PCB板来提高表面贴装部件散热效果的一种方法,在结构上是在PCB板上…

element-ui里的el-table在grid布局下切换数据有滚动条时不断增加?

今天在项目里面遇到了这个问题,相当炸裂,看了半天都没有看出什么问题,很是逆天,记录一下 下面使用代码情景复现一下:el-table 是在 grid 布局下面的,不是子层级,中间还有一层 content 的元素包…

【数据结构】差分数组

【数据结构】差分数组 差分数组二维差分数组二维数组的前缀和 差分数组 如果给定一个包含1000万个元素的数组,同时假定会有频繁区间修改操作,但是不会有频繁的查询操作,比如对某个范围【l,r】内的数字加上某个数字,此时…

Java基础-->异常

什么是异常? 异常:异常就是代表程序出现的问题 误区:不是让我们以后不出现异常,而是程序出了异常之后该如何处理 Error 代表系统级别的错误(属于原重问题) 系统一旦出现问题,sun公司会把这些…

数据库应用:Mycat实现读写分离

目录 一、理论 1.Mycat 2.Mycat安装启动 3.Mycat搭建读写分离 4.垂直分库 5.水平分表 6.Mycat高可用 7.Mycat安全设置 8.Mycat监控工具 二、实验 1.Mycat读写分离 2.Mycat监控安装 三、问题 1.Mycat命令无法补全 2.Mycat启动失败 3.zookeeper启动报错 四、总结…

基于SpringBoot+Vue的冬奥会科普平台设计与实现

博主介绍: 大家好,我是一名在Java圈混迹十余年的程序员,精通Java编程语言,同时也熟练掌握微信小程序、Python和Android等技术,能够为大家提供全方位的技术支持和交流。 我擅长在JavaWeb、SSH、SSM、SpringBoot等框架…

Java通过URL对象实现简单爬虫功能

目录 一、URL类 1. URL类基本概念 2. 构造器 3. 常用方法 二、爬虫实例 1. 爬取网络图片(简易) 2. 爬取网页源代码 3. 爬取网站所有图片 一、URL类 1. URL类基本概念 URL:Uniform Resource Locator 统一资源定位符 表示统一资源定位…

动态规划--回文串问题

一)回文子串: 647. 回文子串 - 力扣&#xff08;LeetCode&#xff09; 思路1:暴力枚举: for(int i0;i<array.length;i) for(int ji;j<array.length;j) 我们的中心思路就是枚举出所有的子字符串&#xff0c;然后进行判断所有的子串是否是回文串 思路2:中心扩散: 我们从左向…

​MySQL高阶语句(三)

目录 1、内连接 2、左连接 3、右连接&#xff1a; 二、存储过程⭐⭐⭐ 4. 调用存储过程 5.查看存储过程 5.1 查看存储过程 5.2查看指定存储过程信息 三. 存储过程的参数 3.1存储过程的参数 3.2修改存储过程 四.删除存储过程 MySQL 的连接查询&#xff0c;通常都是将来…

ElasticSearch学习--RestClient及案例

目录 RestClient查询文档 快速入门 总结 全文检索&#xff08;match&#xff09;查询 精确查询 复合查询 查询总结 排序&#xff0c;分页 高亮 RestClient查询文档 快速入门 总结 全文检索&#xff08;match&#xff09;查询 多种查询的差异都在做类型和条件上&#x…

JS 自定义的悬浮窗被浏览器遮挡问题解决方案

遮挡问题解决思路&#xff0c;首先拿到外层的DOM元素div的宽高&#xff0c;然后根据鼠标悬浮事件的元素e e.clientX表距离页面窗口宽的位置 e.clientY代表距离页面窗口高的位置 然后设置这个悬浮窗为200px 那个这个div的宽高 dom.getElementById(xxxx).cliengHeight dom.g…

FutureTask

Future接口 Future接口&#xff08;FutureTask实现类&#xff09;定义了操作异步任务执行一些方法&#xff0c;如获取异步任务执行的结果、取消任务的执行、判断任务是否取消、判断任务执行是否完成等。它提供了一种并行异步计算的功能。比如主线程让子线程去执行任务&#xff…

C语言两种方法求证大小端存储

目录 什么是大小端存储&#xff1f; 字节序的概念&#xff1a; 小端字节序存储&#xff1a; 大端字节序存储&#xff1a; 什么是低位字节、高位字节&#xff1f; 记忆技巧&#xff1a; C语言求证大小端存储 法一&#xff1a; 法二&#xff1a; 总结&#xff1a; 什么是…

CAXA中.exb或者.dwg文件保存为PDF

通常CAXAZ中的文件为.exb或者.dwg格式&#xff0c;我们想打印或者保存为PDF文件格式&#xff0c;那么就用一下的方法&#xff1a; CAXA文件如图所示&#xff1a; 框选出你要打印的图纸&#xff01;&#xff01;&#xff01;&#xff01; 我们选择"菜单"->"…

用户订单信息案例

需求: 用户输入商品价格和商品数量&#xff0c;以及收货地址&#xff0c;可以自动打印订单信息 分析: ① 需要输入3个数据&#xff0c;所以需要3个变量来存储price num address ② 需要计算总的价格total ③ 页面打印生成表格, 里面填充数据即可 ④ 记得最好使用模板字符串 【…

java.io.InputStreamReader的read()函数返回值是字符对应的Unicode码点

java.io.InputStreamReader的read()函数定义&#xff1a; https://docs.oracle.com/en/java/javase/19/docs/api/java.base/java/io/InputStreamReader.html#read() 这个返回的值其实就是解码后的字符对应的Unicode码点&#xff08;Unicode code point&#xff09;。 举例 例如…

MySQL表的管理

目录 1.mysql中&#xff0c;数据存储过程分为四步 2.数据库命名规则 3.创建数据库 4.管理数据库的方法 5.修改数据库&#xff08;一般不改&#xff0c;最多改字符集&#xff09; 6.删除数据库 7.如何创建数据表 8.修改表 9.重命名表 10.删除表&#xff08;注意⚠️无…

Java日志slf4j+logback

一、maven依赖 在pom文件增加slf4jlogback依赖 <!-- 版本配置 --> <properties><slf4j.version>1.7.21</slf4j.version><logback.version>1.1.7</logback.version> </properties><dependencies><!-- slf4j依赖包 -->&…

JVM源码剖析之达到什么条件进行JIT优化

版本信息&#xff1a; jdk版本&#xff1a;jdk8u40 思想至上 技术经过数百年的迭代&#xff0c;如今虚拟机中都存在JIT模块&#xff0c;JVM中Hotspot&#xff0c;Android虚拟机中dalvik、Art等等。并且存在一个共性&#xff0c;全部都是解释器和JIT共存。当然&#xff0c;如今…