ResNet残差网络

news2025/1/20 5:44:27

ResNet

目的

Resnet网络是为了解决深度网络中的退化问题,即网络层数越深时,在数据集上表现的性能却越差。

原理

ResNet的单元结构如下:

在这里插入图片描述

类似动态规划的选择性继承,同时会在训练过程中逐渐增大(/缩小)该单元中权重层的参数,主要取决于是否是直接继承前面块更优。

实现

对于ResNet50及以上来说,采用的单元块是Bottleneck模块。

在实现Bottleneck模块前,需要先对ResNet中使用到的卷积核进行简化定义。

  1. 首先是卷积核kernel_sizef分别为1和3的定义:

    def conv1x1(in_channel, out_channel, stride=1):
        return nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride, bias=False)
    
    def conv3x3(in_channel, out_channel, stride=1):
        return nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=1, bias=False)
    
  2. 接着定义Bottleneck单元模块:

    这里有一个涉及到梯度是否能计算的问题,如果在bn3之后又进行了一次relu操作,然后再自加等,由于relu操作是限定为原地进行的,这就会导致在反向推导时无法计算出梯度,具体原因有待考究

    class Bottleneck(nn.Module):
        extension = 4
    
        # Bottleneck only decrease the [h,w] in conv1 when stride > 1,
        # so the [h,w] is to be [(h-1)/stride+1,(w-1)/stride+1].
        # the in_channel will be change to channel*extension.
        # channel is the temp variable.
        def __init__(self, in_channel, channel, stride, downsample=None):
            super(Bottleneck, self).__init__()
    
            self.conv1 = conv1x1(in_channel, channel, stride)
            self.bn1 = nn.BatchNorm2d(channel)
    
            self.conv2 = conv3x3(channel, channel)
            self.bn2 = nn.BatchNorm2d(channel)
    
            self.conv3 = conv1x1(channel, channel * self.extension)
            self.bn3 = nn.BatchNorm2d(channel * self.extension)
    
            self.relu = nn.ReLU(inplace=True)
            self.downsample = downsample
            self.stride = stride
    
        def forward(self, x):
            identity = x
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
    
            x = self.conv2(x)
            x = self.bn2(x)
            x = self.relu(x)
    
            x = self.conv3(x)
            x = self.bn3(x)
    
            if self.downsample is not None:
                identity = self.downsample(identity)
    
            x += identity
            x = self.relu(x)
            return x
    
  3. 最后是ResNet的主体,主体包含前向传播函数和构造集合体模块层函数:

在这里插入图片描述

ResNet总共可看作6层结构。

  • 第一层为大卷积核层,主要是以大卷积核进行卷积,同时将通道数上升到64。[h,w]=[h/2,w/2]。

  • 第二至五层是残差模块,其中残差模块由多层Bottleneck组成。多层Bottleneck的第一层的in_channel为上一个模块的out_channel,中间的in_channel则为多层Bottleneck的上一层out_channel,每个Bottleneck的plane为其in_channel的1/2。

    第二层的stride为1,但是有maxpool来使得图片尺寸缩小,其他层则通过stride=2使得图片尺寸缩小。

  • 第六层则是全连接层,使用torch.flatten进行缩维度处理。

class ResNet(nn.Module):
    # size / 32 / 7
    def __init__(self, block, layers, num_class):
        super(ResNet, self).__init__()
        # the first layer changes the channel to 64,
        # and the [h,w] will be change to [(h-1)/stride+1,(w-1)/stride+1] after the first layer.
        self.in_channel = 64
        self.block = block
        self.layers = layers

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # there are four block layers, each layers contains more than one block.
        self.stage1 = self.make_layer(self.block, 64, layers[0], stride=1)
        self.stage2 = self.make_layer(self.block, 128, layers[1], stride=2)
        self.stage3 = self.make_layer(self.block, 256, layers[2], stride=2)
        self.stage4 = self.make_layer(self.block, 512, layers[3], stride=2)

        # in the end, there will be a linear layer to classify all the classes.
        # self.avgpool = nn.AvgPool2d(7)
        self.fc = nn.Linear(512 * block.extension, num_class)

    def make_layer(self, block, plane, block_num, stride=1):
        block_list = []
        downsample = None

        # if the in_channel isn't equal to the out_channel,
        # downsample will be needed to process the in_channel to same size as the out_channel
        # so that the in_channel can be added to the out_channel to achieve the resnet struct.
        if stride != 1 or self.in_channel != plane * block.extension:
            downsample = nn.Sequential(
                conv1x1(self.in_channel, plane * block.extension, stride),
                nn.BatchNorm2d(plane * block.extension)
            )

        conv_block = block(self.in_channel, plane, stride, downsample=downsample)

        # the first block's in_channel is different to the another block_num-1 in_channel.
        block_list.append(conv_block)
        # modify the in_channel for the next stage layer.
        self.in_channel = plane * block.extension

        for _ in range(1, block_num):
            block_list.append(block(self.in_channel, plane, stride=1))

        return nn.Sequential(*block_list)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)

        # x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        x = nn.Softmax(dim=1)(x)
        return x

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/434710.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数字图像基础【7】应用线性回归最小二乘法(矩阵版本)求解几何变换(仿射、透视)

这一章主要讲图像几何变换模型,可能很多同学会想几何变换还不简单嚒?平移缩放旋转。在传统的或者说在同一维度上的基础变换确实是这三个,但是今天学习的是2d图像转投到3d拼接的基础变换过程。总共包含五个变换——平移、刚性、相似、仿射、透…

尚融宝10-Excel数据批量导入

目录 一、数据字典 (一)、什么是数据字典 (二)、数据字典的设计 二、Excel数据批量导入 (一)后端接口 1、添加依赖 2、创建Excel实体类 3、创建监听器 4、Mapper层批量插入 5、Service层创建监听…

2023年,想要靠做软件测试获得高薪,我还有机会吗?

时间过得很快,一眨眼,马上就要进入2023年了,到了年底,最近后台不免又出现了经常被同学问道这几个问题:2023年还能转行软件测试吗?零基础转行可行吗? 本期小编就“2023年,入行软件测…

一文解决nltk安装问题ModuleNotFoundError: No module named ‘nltk‘,保姆级教程

目录 问题一:No module named ‘nltk‘ 问题二:Please use the NLTK Downloader to obtain the resource 下载科学上网工具 问题三:套娃报错 如果会科学上网,可以直接看问题三 问题一:No module named ‘nltk‘ Mo…

【微服务笔记16】微服务组件之Gateway服务网关基础环境搭建

这篇文章,主要介绍微服务组件之Gateway服务网关基础环境搭建。 目录 一、Gateway服务网关 1.1、什么是Gateway 1.2、Gateway基础环境搭建 (1)基础环境介绍 (2)引入依赖 (3)添加路由配置信…

软件测试工程师的进阶之旅

很多人对测试工程师都有一些刻板印象,比如觉得测试“入门门槛低,没有技术含量”、“对公司不重要”、“操作简单工作枯燥”“一百个开发,一个测试”等等。 会产生这种负面评论,是因为很多人对测试的了解,还停留在几年…

Lesson12 udptcp协议

netstat命令->查看网络状态 n 拒绝显示别名,能显示数字的全部转化成数字l 仅列出有在 Listen (监听) 的服務状态p 显示建立相关链接的程序名t (tcp)仅显示tcp相关选项u (udp)仅显示udp相关选项a (all)显示所有选项,默认不显示LISTEN相关 pidof命令-&…

SQL select详解(基于选课系统)

表详情: 学生表: 学院表: 学生选课记录表: 课程表: 教师表: 查询: 1. 查全表 -- 01. 查询所有学生的所有信息 -- 方法一:会更复杂,进行了两次查询,第一…

机器学习笔记之正则化(六)批标准化(BatchNormalization)

机器学习笔记之正则化——批标准化[Batch Normalization] 引言引子:梯度消失梯度消失的处理方式批标准化 ( Batch Normalization ) (\text{Batch Normalization}) (Batch Normalization)场景构建梯度信息比例不平衡批标准化对于梯度比例不平衡的处理方式 ICS \text{…

《抄送列表》:过滤次要文件,优先处理重要文件

目录 一、题目 二、思路 1、查找字符/字符串方法:str1.indexOf( ) 2、字符串截取方法:str1.substring( ) 三、代码 详细注释版: 简化注释版: 一、题目 题目:抄送列表 题目链接:抄送列表 …

Java[集合] Map 和 Set

哈喽,大家好~ 我是保护小周ღ,本期为大家带来的是 Java Map 和 Set 集合详细介绍了两个集合的概念及其常用方法,感兴趣的朋友可以来学习一下。更多精彩敬请期待:保护小周ღ *★,*:.☆( ̄▽ ̄)/$:*.★* ‘ 一、…

JVM知识汇总

1、JVM架构图 2、Java编译器 Java编译器做的事情很简单,其实就是就是将Java的源文件转换为字节码文件。 1. 源文件存储的是高级语言的命令,JVM只认识"机器码"; 2. 因此将源文件转换为字节码文件,即是JVM看得懂的"…

Node.js—Buffer(缓冲器)

文章目录 1、概念2.、特点3、创建Buffer3.1 Buffer.alloc3.2 Buffer.allocUnsafe3.3 Buffer.from 4、操作Buffer4.1 Buffer 与字符串的转化4.2 Buffer 的读写 参考 1、概念 Buffer 是一个类似于数组的对象 ,用于表示固定长度的字节序列。Buffer 本质是一段内存空间…

视觉学习(四) --- 基于yolov5进行数据集制作和模型训练

环境信息 Jetson Xavier NX:Jetpack 4.4.1 Ubuntu:18.04 CUDA: 10.2.89 OpenCV: 4.5.1 cuDNN:8.0.0.180一.yolov5 项目代码整体架构介绍 1. yolov5官网下载地址: GitHub: https://github.com/ultralytics/yolov5/tree/v5.0 2. …

单元测试中的独立运行

单元测试中的独立运行 单元测试是针对代码单元的独立测试。要测试代码单元,首先要其使能够独立运行。项目中的代码具有依赖关系,例如,一个源文件可能直接或间接包含大量头文件,并调用众多其他源文件的代码,抽取其中的一…

论文阅读:Unsupervised Manifold Linearizing and Clustering

Author: Tianjiao Ding, Shengbang Tong, Kwan Ho Ryan Chan, Xili Dai, Yi Ma, Benjamin D. Haeffele Abstract 在本文中,我们建议同时执行聚类并通过最大编码率降低来学习子空间联合表示。 对合成和现实数据集的实验表明,所提出的方法实现了与最先进的…

limit、排序、分组单表查询(三)MySQL数据库(头歌实践教学平台)

文章目的初衷是希望学习笔记分享给更多的伙伴,并无盈利目的,尊重版权,如有侵犯,请官方工作人员联系博主谢谢。 目录 第1关:对查询结果进行排序 任务描述 相关知识 对查询结果排序 指定排序方向 编程要求 第2关&a…

浏览器架构和事件循环

浏览器架构 早期浏览器【单进程多线程】 Page Thread 页面渲染,负责执行js,plugin,drawNetWork Thread 网络请求其余线程 file, storage缺点:只要其中一个线程崩溃,页面就会崩溃。 现代浏览器架构 多进程的浏览器,浏览器的每一个…

几种常见的激活函数

文章目录 常见的激活函数介绍Sigmoid函数ReLU函数LeakyReLU函数Tanh函数Softmax函数总结 常见的激活函数介绍 激活函数是神经网络中的重要组成部分,它决定了神经元的输出。在神经网络的前向传播中,输入数据被传递给神经元,经过加权和和激活函…

Unity自动化打包(1)

一 安装Jenkins https://www.jenkins.io/download/ 官网 1) 使用 brew 安装 2) 安装完成后一般都会遇到问题 我用的是jenkins-lts 稳定版 解决办法 删除掉对应的文件夹 1 rm -rf /usr/local/Homebrew/Library/Taps/homebrew/homebrew-services 2…