深度学习中的残差网络、加权残差连接(WRC)与跨阶段部分连接(CSP)详解

news2024/12/29 10:10:25

随着深度学习技术的不断发展,神经网络架构变得越来越复杂,而这些复杂网络在训练时常常遇到梯度消失、梯度爆炸以及计算效率低等问题。为了克服这些问题,研究者们提出了多种网络架构,包括 残差网络(ResNet)加权残差连接(WRC)跨阶段部分连接(CSP)

本文将详细介绍这三种网络架构的基本概念、工作原理以及如何在 PyTorch 中实现它们。我们会通过代码示例来展示每个技术的实现方式,并重点讲解其中的核心部分。

目录

一、残差网络(ResNet)

1.1 残差网络的背景与原理

1.2 残差块的实现

重点

二、加权残差连接(WRC)

2.1 WRC的提出背景

2.2 WRC的实现

重点

三、跨阶段部分连接(CSP)

3.1 CSP的提出背景

3.2 CSP的实现

重点

四、总结


一、残差网络(ResNet)

1.1 残差网络的背景与原理

有关于残差网络,详情可以查阅以下博客,更为详细与新手向:

YOLO系列基础(三)从ResNet残差网络到C3层-CSDN博客

深层神经网络的训练常常遭遇梯度消失或梯度爆炸的问题,导致训练效果不好。为了解决这一问题,微软的何凯明等人提出了 残差网络(ResNet),引入了“跳跃连接(skip connections)”的概念,使得信息可以直接绕过某些层传播,从而避免了深度网络训练中的问题。

在传统的神经网络中,每一层都试图学习输入到输出的映射。但在 ResNet 中,网络不再直接学习从输入到输出的映射,而是学习输入与输出之间的“残差”,即

H(x) = F(x) + x

其中 F(x) 是网络学到的残差部分,x 是输入。

这种方式显著提升了网络的训练效果,并且让深层网络的训练变得更加稳定。

1.2 残差块的实现

下面是一个简单的残差块实现,它包括了两层卷积和一个跳跃连接。跳跃连接帮助保持梯度的流动,避免深层网络中的梯度消失问题。

图例如下:

代码示例如下:

import torch
import torch.nn as nn
import torch.nn.functional as F

# 定义残差块
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        # 如果输入和输出的通道数不同,则使用1x1卷积调整尺寸
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))  # 第一层卷积后激活
        out = self.bn2(self.conv2(out))        # 第二层卷积
        out += self.shortcut(x)                # 残差连接
        return F.relu(out)                     # ReLU激活

# 构建ResNet
class ResNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ResNet, self).__init__()
        self.layer1 = ResidualBlock(3, 64)
        self.layer2 = ResidualBlock(64, 128)
        self.layer3 = ResidualBlock(128, 256)
        self.fc = nn.Linear(256, num_classes)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = F.adaptive_avg_pool2d(x, (1, 1))  # 全局平均池化
        x = torch.flatten(x, 1)                # 展平
        x = self.fc(x)                         # 全连接层
        return x

# 示例:构建一个简单的 ResNet
model = ResNet(num_classes=10)
print(model)
重点
  1. 残差连接的实现:在 ResidualBlock 类中,out += self.shortcut(x) 实现了输入与输出的加法操作,这是残差学习的核心。
  2. 处理输入和输出通道数不一致的情况:如果输入和输出的通道数不同,通过使用 1x1 卷积调整输入的维度,确保加法操作能够进行。

二、加权残差连接(WRC)

2.1 WRC的提出背景

传统的残差网络通过简单的跳跃连接将输入和输出相加,但在某些情况下,不同层的输出对最终结果的贡献是不同的。为了让网络更灵活地调整各层贡献,加权残差连接(WRC) 引入了可学习的权重。公式如下

H(x) =\alpha F(x) + \beta x

其中 F(x) 是网络学到的残差部分,x 是输入,\alpha 和 \beta是权重。

WRC通过为每个残差连接引入可学习的权重 \alpha\beta,使得网络能够根据任务需求自适应地调整每个连接的贡献。

2.2 WRC的实现

以下是 WRC 的实现代码,我们为每个残差连接引入了权重参数 alphabeta,这些参数通过训练进行优化。

图例如下:

可以看到,加权残差快其实就是给残差网络的两条分支加个权而已 

代码示例如下: 

class WeightedResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(WeightedResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        # 权重初始化
        self.alpha = nn.Parameter(torch.ones(1))  # 可学习的权重
        self.beta = nn.Parameter(torch.ones(1))   # 可学习的权重

        # 如果输入和输出的通道数不同,则使用1x1卷积调整尺寸
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        
        # 加权残差连接:使用可学习的权重 alpha 和 beta
        out = self.alpha * out + self.beta * self.shortcut(x)
        return F.relu(out)

# 示例:构建一个加权残差块
model_wrc = WeightedResidualBlock(3, 64)
print(model_wrc)
重点
  1. 可学习的权重 alphabeta:我们为残差块中的两个加法项(即残差部分和输入部分)引入了可学习的权重。通过训练,这些权重可以自动调整,使网络能够根据任务需求更好地融合输入和输出。

  2. 加权残差连接的实现:在 forward 方法中,out = self.alpha * out + self.beta * self.shortcut(x) 表示加权残差连接,其中 alphabeta 是可学习的参数。

三、跨阶段部分连接(CSP)

3.1 CSP的提出背景

虽然 ResNet 和 WRC 提供了有效的残差学习和信息融合机制,但在一些更复杂的网络中,信息的传递依然面临冗余和计算开销较大的问题。为了解决这一问题,跨阶段部分连接(CSP) 提出了更加高效的信息传递方式。CSP通过选择性地传递部分信息而不是所有信息,减少了计算量并保持了模型的表达能力。

3.2 CSP的实现

CSP通过分割输入特征,并在不同阶段进行不同的处理,从而减少冗余的信息传递。下面是 CSP 的实现代码。

CSP思想图例如下:

特征分割(Feature Splitting):CSP通过分割输入特征图,并将分割后的特征图分别送入不同的子网络进行处理。一般来说,一条分支的子网络会比较简单,一条分支的自网络则是原来主干网络的一部分。

重点
  1. 部分特征选择性连接:将输入特征分为两部分。每部分特征单独经过卷积处理后,通过 torch.cat() 进行拼接,形成最终的输出。
  2. 跨阶段部分连接:CSP块通过分割输入特征并在不同阶段处理,有效地减少了计算开销,并且保持了网络的表达能力。

四、总结

本文介绍了 残差网络(ResNet)加权残差连接(WRC)跨阶段部分连接(CSP) 这三种网络架构。

finally,求赞求赞求赞~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2267394.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Excel将混乱的多行做成1列

目标是将数据按从左到右,再从上到下排成一列。 公式法 首先用textjoin函数将文本包起来,做成一个超长文本。 然后用公式 截取文本 Mid(m1,n,3),意思就是对m1单元格,从第n个字符开始,截取3个字符出来。 这个公式如何自…

在vscode的ESP-IDF中使用自定义组件

以hello-world为例,演示步骤和注意事项 1、新建ESP-IDF项目 选择模板 从hello-world模板创建 2、打开项目 3、编译结果没错 正在执行任务: /home/azhu/.espressif/python_env/idf5.1_py3.10_env/bin/python /home/azhu/esp/v5.1/esp-idf/tools/idf_size.py /home…

基于springboot校园招聘系统源码和论文

可做计算机毕业设计JAVA、PHP、爬虫、APP、小程序、C#、C、python、数据可视化、大数据、文案 使用旧方法对校园招聘系统的信息进行系统化管理已经不再让人们信赖了,把现在的网络信息技术运用在校园招聘系统的管理上面可以解决许多信息管理上面的难题,比…

PaddleOCR文字识别模型的FineTune

一、paddleOCR paddle框架为百度开发的深度学习框架,其中对于文字检测、识别具有较为便利的开发条件。同时PaddleOCR文字识别工具较为轻量化,并可按照任务需求进行model的finetune,满足实际的业务需求。 源码来源:githubOCR 在gi…

【数据库初阶】Ubuntu 环境安装 MySQL

🎉博主首页: 有趣的中国人 🎉专栏首页: 数据库初阶 🎉其它专栏: C初阶 | C进阶 | 初阶数据结构 小伙伴们大家好,本片文章将会讲解 Ubuntu 系统安装 MySQL 的相关内容。 如果看到最后您觉得这篇…

MoH:将多头注意力(Multi-Head Attention)作为头注意力混合(Mixture-of-Head Attention)

摘要 https://arxiv.org/pdf/2410.11842? 在本文中,我们对Transformer模型的核心——多头注意力机制进行了升级,旨在提高效率的同时保持或超越先前的准确度水平。我们表明,多头注意力可以表示为求和形式。鉴于并非所有注意力头都具有同等重…

AI助力古诗视频制作全流程化教程

AI助力古诗视频制作全流程化教程 目录 1. 制作视频的原材料(全自动) 2.文生图:图像生成(手动) 3.文生音频:TTS技术(全自动) 4.视频编辑(手动) 5.自动发…

基于SSM的“快递管理系统”的设计与实现(源码+数据库+文档+PPT)

基于SSM的“快递管理系统”的设计与实现(源码数据库文档PPT) 开发语言:Java 数据库:MySQL 技术:SSM 工具:IDEA/Ecilpse、Navicat、Maven 系统展示 登陆页面 注册页面 快递员页面 派单员订单管理页面 派单员订单添…

AWTK 在全志 tina linux 上支持 2D 图形加速

全志 tina linux 2D 图形加速插件。 开发环境为 全志 Tina Linux 虚拟机。 1. 准备 下载 awtk git clone https://github.com/zlgopen/awtk.git下载 awtk-linux-fb git clone https://github.com/zlgopen/awtk-linux-fb.git下载 awtk-tina-g2d git clone https://github.co…

Unity游戏环境交互系统

概述 交互功能使用同一个按钮或按钮列表,在不同情况下显示不同的内容,按下执行不同的操作。 按选项个数分类 环境交互系统可分为两种,单选项交互,一般使用射线检测;多选项交互,一般使用范围检测。第一人…

线性直流电流

电阻网络的等效 等效是指被化简的电阻网络与等效电阻具有相同的 u-i 关系 (即端口方程),从而用等效电阻代替电阻网络之后,不 改变其余部分的电压和电流。 串联等效: 并联等效: 星角变换 若这两个三端网络是等效的,从任…

攻防世界web第二题unseping

这是题目 <?php highlight_file(__FILE__);class ease{private $method;private $args;function __construct($method, $args) {$this->method $method;$this->args $args;}function __destruct(){if (in_array($this->method, array("ping"))) {cal…

[文献阅读]ReAct: Synergizing Reasoning and Acting in Language Models

文章目录 摘要Abstract:思考与行为协同化Reason(Chain of thought)ReAct ReAct如何协同推理 响应Action&#xff08;动作空间&#xff09;协同推理 结果总结 摘要 ReAct: Synergizing Reasoning and Acting in Language Models [2210.03629] ReAct: Synergizing Reasoning an…

ISDP010_基于DDD架构实现收银用例主成功场景

信息系统开发实践 &#xff5c; 系列文章传送门 ISDP001_课程概述 ISDP002_Maven上_创建Maven项目 ISDP003_Maven下_Maven项目依赖配置 ISDP004_创建SpringBoot3项目 ISDP005_Spring组件与自动装配 ISDP006_逻辑架构设计 ISDP007_Springboot日志配置与单元测试 ISDP008_SpringB…

Linux -- 从抢票逻辑理解线程互斥

目录 抢票逻辑代码&#xff1a; thread.hpp thread.cc 运行结果&#xff1a; 为什么票会抢为负数&#xff1f; 概念前言 临界资源 临界区 原子性 数据不一致 为什么数据不一致&#xff1f; 互斥 概念 pthread_mutex_init&#xff08;初始化互斥锁&#xff09; p…

1.微服务灰度发布落地实践(方案设计)

前言 微服务架构中的灰度发布&#xff08;也称为金丝雀发布或渐进式发布&#xff09;是一种在不影响现有用户的情况下&#xff0c;逐步将新版本的服务部署到生产环境的策略。通过灰度发布&#xff0c;你可以先将新版本的服务暴露给一小部分用户或特定的流量&#xff0c;观察其…

从 Coding (Jenkinsfile) 到 Docker:全流程自动化部署 Spring Boot 实战指南(简化篇)

前言 本文记录使用 Coding (以 Jenkinsfile 为核心) 和 Docker 部署 Springboot 项目的过程&#xff0c;分享设置细节和一些注意问题。 1. 配置服务器环境 在实施此过程前&#xff0c;确保服务器已配置好 Docker、MySQL 和 Redis&#xff0c;可参考下列链接进行操作&#xff1…

丢失的MD5

丢失的MD5 源代码&#xff1a; import hashlib for i in range(32,127):for j in range(32,127):for k in range(32,127):mhashlib.md5()m.update(TASCchr(i)O3RJMVchr(j)WDJKXchr(k)ZM)desm.hexdigest()if e9032 in des and da in des and 911513 in des:print des 发现给…

基于51单片机的交通灯外部中断proteus仿真

地址&#xff1a; https://pan.baidu.com/s/1WSlta_7pz5HdWsyIGoviHg 提取码&#xff1a;1234 仿真图&#xff1a; 芯片/模块的特点&#xff1a; AT89C52/AT89C51简介&#xff1a; AT89C52/AT89C51是一款经典的8位单片机&#xff0c;是意法半导体&#xff08;STMicroelectro…

JavaWeb(一) | 基本概念(web服务器、Tomcat、HTTP、Maven)、Servlet 简介

1. 基本概念 1.1、前言 web开发&#xff1a; web&#xff0c;网页的意思&#xff0c;www.baidu.com静态 web html,css提供给所有人看的数据始终不会发生变化&#xff01; 动态 web 淘宝&#xff0c;几乎是所有的网站&#xff1b;提供给所有人看的数据始终会发生变化&#xf…