经典网络解析(四) ResNet | 残差模块,网络结构代码实现全解析

news2024/12/23 19:15:19

文章目录

  • 1 设计初衷
  • 2.网络结构
    • 2.1 残差块
    • 2.2 中间的卷积网络特征提取块
      • 1 两层3×3卷积层
      • 2 先1×1卷积层,再3×3卷积层,再3×3卷积层
    • 2.3 结构总览表格
  • 3 为什么残差模块有效?
    • 3.1 前向传播
    • 3.2 反向传播
    • 3.3 恒等映射
    • 3.4 集成模型
  • 4.代码实现

1 设计初衷

我们之前讲了VGG等网络,在之前网络的研究中,研究者感觉

网络越深,分类准确率越高,但是随着网络的加深,科学家们发现分类准确率反而会下降,无论是在训练集上还是测试集上。

ResNet的作者团队发现了这种现象的真正原因是:

训练过程中网络的正、反向的信息流动不顺畅,网络没有被充分训练,他们称之为“退化”

2.网络结构

2.1 残差块

解决方式:

​ 构建了残差模块,通过堆叠残差模块,可以构建任意深度的神经网络,而不会出现退化的现象

​ 提出了批归一化对抗梯度消失,该方法降低了网络训练过程中对于权重初始化的依赖

H ( x ) = F ( x ) + x H(x)=F(x)+x H(x)=F(x)+x

我们网络要学习的是F(x)

F ( x ) = H ( x ) − x F(x)=H(x)-x F(x)=H(x)x

F(x)实际上就是输出与输入的差异,所以叫做残差模块

在这里插入图片描述

2.2 中间的卷积网络特征提取块

中间的块有两种可能

1 两层3×3卷积层

在这里插入图片描述

2 先1×1卷积层,再3×3卷积层,再3×3卷积层

第一个用了1×1把卷积通道降下去(减少运算量),第二个用了1×1把卷积通道再升上去(便于和输入x连接)

在这里插入图片描述

2.3 结构总览表格

表中展示了18层,34层,50层,101层,152层的ResNet的结构

[ ]方括号中即是我们上面讲的两个特征提取块,×几代表堆叠几个

最后经过一个全局平均池化,一个全连接层

在这里插入图片描述

3 为什么残差模块有效?

3.1 前向传播

1 前向传播过程中重要信息不消失

通过残差网络的设计,我们可以理解为原来的有机会信息可以维持不变,对分类有帮助的信息得到加强

能够避免卷积层堆叠存在的信息丢失

3.2 反向传播

2 反向传播中梯度可以控制不消失

在典型的残差模块中,输入数据被分为两部分,一部分通过一个或多个神经层进行变换,而另一部分直接传递给输出。这个直接传递的部分是输入数据的恒等映射,即没有变换。这意味着至少一部分的信息在经过神经网络之后保持不变。

当进行反向传播以更新神经网络参数时,梯度是根据损失函数计算的。在传统的深度神经网络中,由于多层的网络梯度相乘,梯度可以逐渐变小并导致梯度消失。但在残差模块中,由于存在恒等映射,至少一部分梯度可以直接通过跳过变换的路径传播,而不会受到变换的影响。

H ( x ) = F ( x ) + x H(x)=F(x)+x H(x)=F(x)+x

比如这个式子对x求偏导 ∂ F / ∂ x + 1 ∂F/∂x+1 F/x+1

这时候保证了梯度至少会加1 不让梯度连乘逐渐变小

3.3 恒等映射

3,可以理解为当网络变深之后,非线性变得很强,网络很难学会简单的恒等映射,残差模块可以解决这个问题

3.4 集成模型

4 残差网络可以看做一种集成模型

可以看做很多简单或复杂的子网络的组合求和!!!

在这里插入图片描述

但是这样可能会造成冗余,因为其中还可能会有很多不需要的信息,这便是后来的DenseNet,会让速度提升

4.代码实现

实现的是

import torch
from torch import nn
from torch.nn import functional as F

class Residual_block(nn.Module):
    def __init__(self,input_channels,output_channels,first=False):
        super().__init__()
        self.first=first
        if first==True:
            self.conv1=nn.Conv2d(input_channels,output_channels,stride=2,kernel_size=3,padding=1)
            self.conv3=nn.Conv2d(input_channels,output_channels,kernel_size=1,stride=2)
        else:
            self.conv1=nn.Conv2d(output_channels,output_channels,kernel_size=3,padding=1)
        self.bn1=nn.BatchNorm2d(output_channels)
        self.conv2=nn.Conv2d(output_channels,output_channels,kernel_size=3,padding=1)
        self.bn2=nn.BatchNorm2d(output_channels)
        
    def forward(self,x):
        Y=F.relu(self.bn1(self.conv1(x)))
        Y=self.bn2(self.conv2(Y))
        if self.first==True:
            x=self.conv3(x)
        Y=x+Y
        return F.relu(Y)
def resnet_block(input_channels,output_channels,num_residual_block,special=False):
    blk=[]
    for i in range(num_residual_block):
        if i==0 and special==True:
            blk.append(Residual_block(input_channels,input_channels))
        if i==0 and special==False:
            blk.append(Residual_block(input_channels,output_channels,first=True))
        else:
            blk.append(Residual_block(output_channels,output_channels))
    return blk


b1=nn.Sequential(
    nn.Conv2d(kernel_size=7,in_channels=3,out_channels=64,stride=2,padding=3),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3,stride=2,padding=1),
)
R1=Residual_block(64,64)
x=torch.ones(1,3,224,224)
for layer in b1:
    x=layer(x)

b18_2=nn.Sequential(*resnet_block(64,64,2,special=True))
b18_3=nn.Sequential(*resnet_block(64,128,2))
b18_4=nn.Sequential(*resnet_block(128,256,2))
b18_5=nn.Sequential(*resnet_block(256,512,2))

b34_2=nn.Sequential(*resnet_block(64,64,3,special=True))
b34_3=nn.Sequential(*resnet_block(64,128,4))
b34_4=nn.Sequential(*resnet_block(128,256,6))
b34_5=nn.Sequential(*resnet_block(256,512,3))


#Resnet-18
Resnet_18=nn.Sequential(b1,b18_2,b18_3,b18_4,b18_5,nn.AdaptiveAvgPool2d((1,1)),nn.Flatten(),nn.Linear(512,10))
#Resnet-34
Resnet_18=nn.Sequential(b1,b34_2,b34_3,b34_4,b34_5,nn.AdaptiveAvgPool2d((1,1)),nn.Flatten(),nn.Linear(512,10))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1040419.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何学习嵌入式Linux?

今日话题,如何学习嵌入式Linux?嵌入式底层开发是一种重要的技术,它被广泛应用于各种嵌入式系统中。随着科技的不断发展,嵌入式系统已经成为了我们日常生活中不可或缺的一部分。这就使得嵌入式开发的重要性也凸显出来。刚好我这有一…

opencv for unity package在unity中打开相机不需要dll

下载OpenCV for Unity 导入后,里面有很多案例 直接打开就可以运行 打开相机

Linux: errno: EADDRNOTAVAIL; ipv6-bind;Cannot assign requested address

文章目录 解释一种情况tentative 的解释 解释一种情况 #define EADDRNOTAVAIL 99 /* Cannot assign requested address */ 有一种情况是:当IP6的地址处于tentative的时候,就会返回这个错误。下面的是bind的调用的时候会check地址的flag。如果是tentati…

Java学习星球,十月集训,五大赛道(文末送书)

目录 什么是知识星球?我的知识星球能为你提供什么?专属专栏《Java基础教程系列》内容概览:《Java高并发编程实战》、《MySQL 基础教程系列》内容概览:《微服务》、《Redis中间件》、《Dubbo高手之路》、《华为OD机试》内容概览&am…

一款Python认证和授权的利器

迷途小书童 读完需要 7分钟 速读仅需 3 分钟 1 简介 authlib 是一个开源的 Python 库,旨在提供简单而强大的认证和授权解决方案。它支持多种认证和授权协议,如 OAuth、OpenID Connect 和 JWT。authlib 具有灵活的架构和丰富的功能,使开发人员…

lS1028 + 六网口TSN 硬交换+QNX/Linux实时系统解决方案在轨道交通系统的应用

lS1028 六网口TSN 硬交换QNX/Linux实时系统解决方案在轨道交通系统的应用 以下是在轨道交通应用的实物: CPUNXP LS1028A架构双核Cortex-A72主频1.5GHzRAM2GB DDR4ROM8GB eMMCOSUbuntu20.04供电DC 12V工作温度-40℃~ 80℃ 功能数量参数Display Port≤1路支持DP1.3…

UniAccess Agent卸载

异常场景: UniAccess Agent导致系统中的好多设置打不开 例如:ipv4的协议,注册表,host等等 需要进行删除,亲测有效,及多家答案平凑的 借鉴了这位大神及他里面引用的大神的内容 https://blog.csdn.net/weixin_44476410/article/details/121605455 问题描述 这个进…

Android开发MVP架构记录

Android开发MVP架构记录 安卓的MVP(Model-View-Presenter)架构是一种常见的软件设计模式,用于帮助开发者组织和分离应用程序的不同组成部分。MVP架构的目标是将应用程序的业务逻辑(Presenter)、用户界面(V…

由于找不到msvcr110.dll的5种解决方法

在使用电脑的过程中,我们可能会遇到一些问题,比如打开软件时提示找不到 msvcr110.dll 文件丢失。这通常意味着该文件已被删除或损坏,导致程序无法正常运行。本文将介绍几种解决方案,帮助您解决这个问题。 首先,我们需…

Linxu下c语言实现socket+openssl数据传输加密

文章目录 1. Socket连接建立流程2、SocketSSL的初始化流程3、初始化SSL环境,证书和密钥4、SocketSSL 的c语言实现4.1 编写SSL连接函数4.2 编写加密服务端server.c4.3 编写加密客户端client.c 5、使用tcpdump检验源码获取 在进行网络编程的时候,我们通常使…

分布式算法相关,使用Redis落地解决1-2亿条数据缓存

面试题:1~2亿数据需要缓存,请问如何设计个存储案例 回答:单机单台100%不可能,肯定是分布式存储,用redis如何落地? 一般业界有三种解决方案: 哈希取余分区 2亿条记录就是2亿个k,v&…

Linux学习-HIS部署(3)

Jenkins插件资源下载 Jenkins部署 Jenkins部署 #Jenkins主机安装OpenJDK环境 [rootJenkins ~]# yum clean all; yum repolist -v ... Total packages: 8,265 [rootJenkins ~]# yum -y install java-11-openjdk-devel.x86_64 #安装OpenJDK11 [rootJenkins ~]# ln -s /usr/l…

Css 美化滚动条

/*设置滚动条宽度为 6px*/ ::-webkit-scrollbar {width: 6px; } /*设置背景颜色,并设置边框倒角,设置滚动动画,0.2 */ ::-webkit-scrollbar-thumb {background-color: #0003;border-radius: 10px;transition: all .2s ease-in-out; } /*设置滚…

探索创意的新辅助,AI与作家的完美合作

在现代社会,文学创作一直是人类精神活动中的重要一环。从古典文学到现代小说,从诗歌到戏剧,作家们以他们的独特视角和文学天赋为我们展示了丰富多彩的人生世界。而近年来,人工智能技术的快速发展已经渗透到各行各业,文…

Leetcode191. 位1的个数

力扣(LeetCode)官网 - 全球极客挚爱的技术成长平台 编写一个函数,输入是一个无符号整数(以二进制串的形式),返回其二进制表达式中数字位数为 1 的个数(也被称为汉明重量)。 思路&…

人工智能AI 全栈体系(六)

第一章 神经网络是如何实现的 这些年神经网络的发展越来越复杂,应用领域越来越广,性能也越来越好,但是训练方法还是依靠 BP 算法。也有一些对 BP 算法的改进算法,但是大体思路基本是一样的,只是对 BP 算法个别地方的一…

mapper文件添加@Mapper注解爆红

如图所示 报错原因&#xff1a;缺少相关的依赖 <dependency><groupId>org.mybatis.spring.boot</groupId><artifactId>mybatis-spring-boot-starter</artifactId><version>2.2.2</version> </dependency> 添加之后并刷新依赖…

C++: stack 与 queue

目录 1.stack与queue stack queue 2.priority_queue 2.1相关介绍 2.2模拟实现priority_queue --仿函数: --push --pop --top --size --empty --迭代器区间构造 2.3仿函数 3.容器适配器 stack模拟实现 queue模拟实现 学习目标: 1.stack和queue介绍与使用 2.pri…

Toaster - Android 吐司框架,专治 Toast 各种疑难杂症

官网 https://github.com/getActivity/Toaster 这可能是性能优、使用简单&#xff0c;支持自定义&#xff0c;不需要通知栏权限的吐司 想了解实现原理的可以点击此链接查看&#xff1a;Toaster 源码 集成步骤 如果你的项目 Gradle 配置是在 7.0 以下&#xff0c;需要在 bui…

如何使用Docker安装最新版本的Redis并设置远程访问(含免费可视化工具)

文章目录 安装Docker安装Redisredis.conf文件远程访问Redis免费可视化工具相关链接Docker是一种开源的应用容器引擎,使用Docker可以让我们快速部署应用环境,本文介绍如何使用Docker安装最新版本的Redis。 安装Docker 首先需要安装Docker,具体的安装方法可以参考Docker官方文…