ResNet与注意力机制:深度学习中的强强联合

news2025/4/23 1:21:48

引言

在深度学习领域,卷积神经网络(CNN)一直是图像处理任务的主流架构。然而,随着网络深度的增加,梯度消失和梯度爆炸问题逐渐显现,限制了网络的性能。为了解决这一问题,ResNet(残差网络)应运而生,通过引入残差连接,使得网络可以训练得更深,从而在多个视觉任务中取得了显著的效果。

然而,尽管ResNet在图像分类、目标检测等任务中表现出色,但在处理复杂场景时,仍然存在一些局限性。例如,网络可能会忽略一些重要的细节信息,或者对某些区域过度关注。为了进一步提升网络的性能,研究者们开始将注意力机制引入到ResNet中,通过自适应地调整特征图的重要性,使得网络能够更加关注于关键区域。

本文将详细介绍ResNet和注意力机制的基本原理,并探讨如何将两者结合,以提升网络的性能。我们还将通过代码示例,展示如何在实践中实现这一结合。

1. ResNet的基本原理

1.1 残差连接

ResNet的核心思想是引入残差连接(Residual Connection),即通过跳跃连接(Skip Connection)将输入直接传递到输出,使得网络可以学习残差映射,而不是直接学习原始映射。这种设计有效地缓解了梯度消失问题,使得网络可以训练得更深。

残差块(Residual Block)是ResNet的基本构建单元,其结构如下:

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
    
    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += self.shortcut(residual)
        out = self.relu(out)
        return out

1.2 ResNet的网络结构

ResNet的网络结构由多个残差块堆叠而成,通常包括多个阶段(Stage),每个阶段包含多个残差块。随着网络的加深,特征图的尺寸逐渐减小,而通道数逐渐增加。常见的ResNet架构包括ResNet-18、ResNet-34、ResNet-50、ResNet-101和ResNet-152等。

2. 注意力机制的基本原理

2.1 注意力机制的概念

注意力机制(Attention Mechanism)最初在自然语言处理(NLP)领域中被提出,用于解决序列到序列(Seq2Seq)模型中的长距离依赖问题。其核心思想是通过计算输入序列中每个元素的重要性,动态地调整每个元素的权重,从而使得模型能够更加关注于关键信息。

在计算机视觉领域,注意力机制被广泛应用于图像分类、目标检测、图像分割等任务中。通过引入注意力机制,网络可以自适应地调整特征图的重要性,从而提升模型的性能。

2.2 常见的注意力机制

2.2.1 通道注意力机制

通道注意力机制(Channel Attention)通过计算每个通道的重要性,动态地调整每个通道的权重。常见的通道注意力机制包括SENet(Squeeze-and-Excitation Network)和CBAM(Convolutional Block Attention Module)等。

SENet的结构如下:

class SEBlock(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)
2.2.2 空间注意力机制

空间注意力机制(Spatial Attention)通过计算每个空间位置的重要性,动态地调整每个空间位置的权重。常见的空间注意力机制包括CBAM和Non-local Neural Networks等。

CBAM的结构如下:

class CBAMBlock(nn.Module):
    def __init__(self, channel, reduction=16, kernel_size=7):
        super(CBAMBlock, self).__init__()
        self.channel_attention = SEBlock(channel, reduction)
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size//2, bias=False),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        x = self.channel_attention(x)
        y = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)
        y = self.spatial_attention(y)
        return x * y

3. ResNet与注意力机制的结合

3.1 为什么要在ResNet中引入注意力机制?

尽管ResNet通过残差连接有效地缓解了梯度消失问题,使得网络可以训练得更深,但在处理复杂场景时,仍然存在一些局限性。例如,网络可能会忽略一些重要的细节信息,或者对某些区域过度关注。通过引入注意力机制,网络可以自适应地调整特征图的重要性,从而更加关注于关键区域,提升模型的性能。

3.2 如何在ResNet中引入注意力机制?

在ResNet中引入注意力机制的方法有很多种,常见的方法包括在残差块中引入通道注意力机制、空间注意力机制,或者在网络的最后引入全局注意力机制等。

3.2.1 在残差块中引入通道注意力机制

在残差块中引入通道注意力机制的方法如下:

class ResidualBlockWithSE(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, reduction=16):
        super(ResidualBlockWithSE, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.se = SEBlock(out_channels, reduction)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
    
    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.se(out)
        out += self.shortcut(residual)
        out = self.relu(out)
        return out
3.2.2 在残差块中引入空间注意力机制

在残差块中引入空间注意力机制的方法如下:

class ResidualBlockWithCBAM(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, reduction=16, kernel_size=7):
        super(ResidualBlockWithCBAM, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.cbam = CBAMBlock(out_channels, reduction, kernel_size)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
    
    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.cbam(out)
        out += self.shortcut(residual)
        out = self.relu(out)
        return out

3.3 实验结果

通过在ResNet中引入注意力机制,网络的性能得到了显著提升。例如,在ImageNet数据集上,ResNet-50的Top-1准确率为76.15%,而引入SENet后,Top-1准确率提升至77.62%。类似地,引入CBAM后,Top-1准确率提升至77.98%。

4. 总结

本文详细介绍了ResNet和注意力机制的基本原理,并探讨了如何将两者结合,以提升网络的性能。通过在ResNet中引入注意力机制,网络可以自适应地调整特征图的重要性,从而更加关注于关键区域,提升模型的性能。实验结果表明,引入注意力机制后,ResNet的性能得到了显著提升。

未来,随着注意力机制的不断发展,我们可以期待更多创新的网络架构和训练方法,进一步提升深度学习模型的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2321142.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Flutter项目之页面实现以及路由fluro

目录: 1、项目代码结构2、页面编写以及路由配置main.dart(入口文件)page_content.dartindex.dartapplication.dartpubspec.yamllogin.dartdio_http.dart 3、Fluro路由routes.dartnot_found_page.dart(路由优化,找不到页面时展示此页面) 4、注册页面 1、项…

《Python实战进阶》第31集:特征工程:特征选择与降维技术

第31集:特征工程:特征选择与降维技术 摘要 特征工程是机器学习和数据科学中不可或缺的一环,其核心目标是通过选择重要特征和降低维度来提升模型性能并减少计算复杂度。本集聚焦于特征选择与降维技术,涵盖过滤法、包裹法、嵌入法等…

C++类与对象的第二个简单的实战练习-3.24笔记

哔哩哔哩C面向对象高级语言程序设计教程(118集全) 实战二 Cube.h #pragma once class Cube { private:double length;double width;double height; public:double area(void);double Volume(void);//bool judgement(double L1, double W1, double H1);…

Rk3568驱动开发_设备树点亮LED_10

设备树中添加节点 在设备树文件中添加led节点,添加完后需要重新编译内核,因为单独编译这个设备树文件生成的dtb文件目前不能直接做替换,所以要编译内核将编译好的boot.img文件烧录到设备里,boot.img里包含新添加的设备树节点&…

Unity学习之Shader(Phong与Blinn-Phong)

三、Lesson3 1、关键名称 向量 • nDir:法线方向,点乘操作时简称n; • lDir:光照方向,点乘操作时简称l; • vDir:观察方向,点乘操作时简称v; • rDir:光反…

uniapp笔记-swiper组件实现轮播图

思路 主要就是参考 swiper | uni-app官网 实现轮播图。 实例 新建一个banner.vue通用组件。 代码如下&#xff1a; <template><view>轮播图</view> </template><script> </script><style> </style> 随后在index.vue中导…

【C++ 继承】—— 青花分水、和而不同,继承中的“明明德”与“止于至善”

欢迎来到ZyyOvO的博客✨&#xff0c;一个关于探索技术的角落&#xff0c;记录学习的点滴&#x1f4d6;&#xff0c;分享实用的技巧&#x1f6e0;️&#xff0c;偶尔还有一些奇思妙想&#x1f4a1; 本文由ZyyOvO原创✍️&#xff0c;感谢支持❤️&#xff01;请尊重原创&#x1…

FPGA_YOLO(二)

上述对cnn卷积神经网络进行介绍,接下来对YOLO进行总结,并研究下怎么在FPGA怎么实现的方案。 对于一个7*7*30的输出 拥有49个cell 每一个cell都有两个bbox两个框,并且两个框所包含的信息拥有30个 4个坐标信息和一个置信度5个,剩下就是20个类别。 FPGA关于YOLO的部署 1…

蓝桥杯学习-14子集枚举,二进制枚举

子集枚举 一、回溯3-子集枚举&#xff08;递归实现指数型枚举&#xff09; 一旦涉及选与不选&#xff0c;删和不删&#xff0c;留和不留-->两种状态-->就要想到子集枚举例题1–递归实现指数型枚举19685 其实看不懂这个题目&#xff0c;好奇怪的题目。根据老师的解析来写…

人工智能时代大学教育范式重构:基于AI编程思维的能力培养路径研究

人工智能技术的快速发展正在重塑高等教育的内容与方法。本文以AI编程教育为切入点&#xff0c;通过文献分析与案例研究&#xff0c;探讨AI时代大学教育的核心能力需求与教学范式转型路径。研究发现&#xff0c;AI编程中蕴含的系统性思维训练、项目架构能力和元认知能力培养机制…

<数据集>轨道异物识别数据集<目标检测>

数据集下载链接&#xff1a;https://download.csdn.net/download/qq_53332949/90527370 数据集格式&#xff1a;VOCYOLO格式 图片数量&#xff1a;1659张 标注数量(xml文件个数)&#xff1a;1659 标注数量(txt文件个数)&#xff1a;1659 标注类别数&#xff1a;6 标注类别…

Pyecharts功能详解与实战示例

一、Pyecharts简介 Pyecharts是一个基于Python的开源数据可视化库&#xff0c;它基于百度的Echarts库&#xff0c;提供了丰富的图表类型和强大的交互功能。通过Pyecharts&#xff0c;你可以轻松创建各种精美的图表&#xff0c;如折线图、柱状图、饼图、散点图、地图等&#xf…

EasyUI数据表格中嵌入下拉框

效果 代码 $(function () {// 标记当前正在编辑的行var editorIndex -1;var data [{code: 1,name: 1,price: 1,status: 0},{code: 2,name: 2,price: 2,status: 1}]$(#dg).datagrid({data: data,onDblClickCell:function (index, field, value) {var dg $(this);if(field ! …

C语言:扫雷

在编程的世界里&#xff0c;扫雷游戏是一个经典的实践项目。它不仅能帮助我们巩固编程知识&#xff0c;还能锻炼逻辑思维和解决问题的能力。今天&#xff0c;就让我们一起用 C 语言来实现这个有趣的游戏&#xff0c;并且通过图文并茂的方式&#xff0c;让每一步都清晰易懂 1. 游…

操作系统必知的面试题

&#x1f9d1; 博主简介&#xff1a;CSDN博客专家&#xff0c;历代文学网&#xff08;PC端可以访问&#xff1a;https://literature.sinhy.com/#/literature?__c1000&#xff0c;移动端可微信小程序搜索“历代文学”&#xff09;总架构师&#xff0c;15年工作经验&#xff0c;…

清华大学.智灵动力-《DeepSeek行业应用实践报告》附PPT下载方法

导 读INTRODUCTION 今天分享是由清华大学.智灵动力&#xff1a;《DeepSeek行业应用实践报告》&#xff0c;主要介绍了DeepSeek模型的概述、优势、使用技巧、与其他模型的对比&#xff0c;以及在多个行业中的应用和未来发展趋势。为理解DeepSeek模型的应用和未来发展提供了深入的…

可视化图解算法:链表的奇偶重排(排序链表)

1. 题目 描述 给定一个单链表&#xff0c;请设定一个函数&#xff0c;将链表的奇数位节点和偶数位节点分别放在一起&#xff0c;重排后输出。 注意是节点的编号而非节点的数值。 数据范围&#xff1a;节点数量满足 0≤n≤105&#xff0c;节点中的值都满足 0≤val≤10000 要…

SAP Activate Methodology in a Nutshell Phases of SAP Activate Methodology

SAP Activate Methodology in a Nutshell Phases of SAP Activate Methodology

开源AI大模型、AI智能名片与S2B2C商城小程序源码:实体店引流的破局之道

摘要&#xff1a;本文聚焦实体店引流困境&#xff0c;提出基于"开源AI大模型AI智能名片S2B2C商城小程序源码"的技术整合方案。通过深度解析各技术核心机制与协同逻辑&#xff0c;结合明源云地产营销、杭州美甲店裂变等实际案例&#xff0c;论证其对流量精准获取、客户…

JVM 02

今天是2025/03/23 19:07 day 10 总路线请移步主页Java大纲相关文章 今天进行JVM 3,4 个模块的归纳 首先是JVM的相关内容概括的思维导图 3. 类加载机制 加载过程 加载&#xff08;Loading&#xff09; 通过类全限定名获取类的二进制字节流&#xff08;如从JAR包、网络、动态…