SK-Net eca注意力机制应用于ResNet (附代码)

news2025/1/12 17:50:30

resnet发展历程

论文地址:https://arxiv.org/pdf/1903.06586.pdf

代码地址:https://github.com/pppLang/SKNet

1.是什么?

SK-net网络是一种增加模块嵌入到一些网络中的注意力机制,它可以嵌入和Resnet中进行补强,嵌入和方法和Se-net类似。SKNet的全称是“Selective Kernel Network”,它和SENet都是一个团队提出来的。SENet是对通道执行注意力机制,而SKNet则是对卷积核执行注意力机制,即让网络自己选择合适的卷积核。SKNet的核心模块包括多分支卷积网络、组卷积、空洞卷积以及注意力机制。在该模块中,作者使用了多尺寸的卷积核,让网络自己选择合适的尺度。SK-net网络的优点是可以提高网络的准确性和泛化能力,同时减少了网络的参数数量和计算量。

2.为什么?

SKNet的设计在以下几个方面具有优势:

多尺度信息融合
通过选择性地应用不同尺度的卷积核,SKNet能够有效地融合多尺度的特征信息。这有助于网络捕捉不同层次的视觉特征,提高了特征的表征能力。

自适应性
选择模块使网络能够自适应地选择卷积核的尺度,从而适应不同任务和图像的特点。这种自适应性能够使网络在各种场景下都能表现出色。

减少计算成本
尽管引入了多尺度卷积核,但由于选择模块的存在,SKNet只会选择一部分卷积核进行计算,从而减少了计算成本,保持了网络的高效性。
 

3.怎么样?

3.1网络模型

SK-net的整个流程如下:
(1)Split划分
    首先使用不同的卷积核获得两个不同的特征图。
(2)Fuse融合
    融合两个不同的特征图,然后将它们通过全局平均池化的方式提取出每一个通道的融合值。再通过线性层进行通过压缩,这里和Se-net的通道注意力是一样的。
(3)Select选择
    在通过一个Softmax层进行选择,后续的操作和Se-net是一样的,得到通过挑选的特征图。这里的特征图可以看到每一个通道所使用的卷积核是不一样的,也是通过学习得到的。

3.2 网络结构


 

3.3 代码实现

SKConv

import torch.nn as nn
import torch

class SKConv(nn.Module):
    def __init__(self, features, WH, M, G, r, stride=1, L=32):
        super(SKConv, self).__init__()
        d = max(int(features / r), L)
        self.M = M
        self.features = features
        self.convs = nn.ModuleList([])
        for i in range(M):
            # 使用不同kernel size的卷积
            self.convs.append(
                nn.Sequential(
                    nn.Conv2d(features,
                              features,
                              kernel_size=3 + i * 2,
                              stride=stride,
                              padding=1 + i,
                              groups=G), nn.BatchNorm2d(features),
                    nn.ReLU(inplace=False)))
            
        self.fc = nn.Linear(features, d)
        self.fcs = nn.ModuleList([])
        for i in range(M):
            self.fcs.append(nn.Linear(d, features))
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        for i, conv in enumerate(self.convs):
            fea = conv(x).unsqueeze_(dim=1)
            if i == 0:
                feas = fea
            else:
                feas = torch.cat([feas, fea], dim=1)
        fea_U = torch.sum(feas, dim=1)
        fea_s = fea_U.mean(-1).mean(-1)
        fea_z = self.fc(fea_s)
        for i, fc in enumerate(self.fcs):
            print(i, fea_z.shape)
            vector = fc(fea_z).unsqueeze_(dim=1)
            print(i, vector.shape)
            if i == 0:
                attention_vectors = vector
            else:
                attention_vectors = torch.cat([attention_vectors, vector],
                                              dim=1)
        attention_vectors = self.softmax(attention_vectors)
        attention_vectors = attention_vectors.unsqueeze(-1).unsqueeze(-1)
        fea_v = (feas * attention_vectors).sum(dim=1)
        return fea_v

if __name__ == "__main__":
    t = torch.ones((32, 256, 24,24))
    sk = SKConv(256,WH=1,M=2,G=1,r=2)
    out = sk(t)
    print(out.shape)

SKNet

import torch.nn as nn
import torch
from functools import reduce
class SKConv(nn.Module):
    def __init__(self,in_channels,out_channels,stride=1,M=2,r=16,L=32):
        '''
        :param in_channels:  输入通道维度
        :param out_channels: 输出通道维度   原论文中 输入输出通道维度相同
        :param stride:  步长,默认为1
        :param M:  分支数
        :param r: 特征Z的长度,计算其维度d 时所需的比率(论文中 特征S->Z 是降维,故需要规定 降维的下界)
        :param L:  论文中规定特征Z的下界,默认为32
        '''
        super(SKConv,self).__init__()
        d=max(in_channels//r,L)   # 计算向量Z 的长度d
        self.M=M
        self.out_channels=out_channels
        self.conv=nn.ModuleList()  # 根据分支数量 添加 不同核的卷积操作
        for i in range(M):
            # 为提高效率,原论文中 扩张卷积5x5为 (3X3,dilation=2)来代替。 且论文中建议组卷积G=32
            self.conv.append(nn.Sequential(nn.Conv2d(in_channels,out_channels,3,stride,padding=1+i,dilation=1+i,groups=32,bias=False),
                                           nn.BatchNorm2d(out_channels),
                                           nn.ReLU(inplace=True)))
        self.global_pool=nn.AdaptiveAvgPool2d(1) # 自适应pool到指定维度    这里指定为1,实现 GAP
        self.fc1=nn.Sequential(nn.Conv2d(out_channels,d,1,bias=False),
                               nn.BatchNorm2d(d),
                               nn.ReLU(inplace=True))   # 降维
        self.fc2=nn.Conv2d(d,out_channels*M,1,1,bias=False)  # 升维
        self.softmax=nn.Softmax(dim=1) # 指定dim=1  使得两个全连接层对应位置进行softmax,保证 对应位置a+b+..=1
    def forward(self, input):
        batch_size=input.size(0)
        output=[]
        #the part of split
        for i,conv in enumerate(self.conv):
            #print(i,conv(input).size())
            output.append(conv(input))
        #the part of fusion
        U=reduce(lambda x,y:x+y,output) # 逐元素相加生成 混合特征U
        s=self.global_pool(U)
        z=self.fc1(s)  # S->Z降维
        a_b=self.fc2(z) # Z->a,b 升维  论文使用conv 1x1表示全连接。结果中前一半通道值为a,后一半为b
        a_b=a_b.reshape(batch_size,self.M,self.out_channels,-1) #调整形状,变为 两个全连接层的值
        a_b=self.softmax(a_b) # 使得两个全连接层对应位置进行softmax
        #the part of selection
        a_b=list(a_b.chunk(self.M,dim=1))#split to a and b   chunk为pytorch方法,将tensor按照指定维度切分成 几个tensor块
        a_b=list(map(lambda x:x.reshape(batch_size,self.out_channels,1,1),a_b)) # 将所有分块  调整形状,即扩展两维
        V=list(map(lambda x,y:x*y,output,a_b)) # 权重与对应  不同卷积核输出的U 逐元素相乘
        V=reduce(lambda x,y:x+y,V) # 两个加权后的特征 逐元素相加
        return V
class SKBlock(nn.Module):
    '''
    基于Res Block构造的SK Block
    ResNeXt有  1x1Conv(通道数:x) +  SKConv(通道数:x)  + 1x1Conv(通道数:2x) 构成
    '''
    expansion=2 #指 每个block中 通道数增大指定倍数
    def __init__(self,inplanes,planes,stride=1,downsample=None):
        super(SKBlock,self).__init__()
        self.conv1=nn.Sequential(nn.Conv2d(inplanes,planes,1,1,0,bias=False),
                                 nn.BatchNorm2d(planes),
                                 nn.ReLU(inplace=True))
        self.conv2=SKConv(planes,planes,stride)
        self.conv3=nn.Sequential(nn.Conv2d(planes,planes*self.expansion,1,1,0,bias=False),
                                 nn.BatchNorm2d(planes*self.expansion))
        self.relu=nn.ReLU(inplace=True)
        self.downsample=downsample
    def forward(self, input):
        shortcut=input
        output=self.conv1(input)
        output=self.conv2(output)
        output=self.conv3(output)
        if self.downsample is not None:
            shortcut=self.downsample(input)
        output+=shortcut
        return self.relu(output)
class SKNet(nn.Module):
    '''
    参考 论文Table.1 进行构造
    '''
    def __init__(self,nums_class=1000,block=SKBlock,nums_block_list=[3, 4, 6, 3]):
        super(SKNet,self).__init__()
        self.inplanes=64
        # in_channel=3  out_channel=64  kernel=7x7 stride=2 padding=3
        self.conv=nn.Sequential(nn.Conv2d(3,64,7,2,3,bias=False),
                                nn.BatchNorm2d(64),
                                nn.ReLU(inplace=True))
        self.maxpool=nn.MaxPool2d(3,2,1) # kernel=3x3 stride=2 padding=1
        self.layer1=self._make_layer(block,128,nums_block_list[0],stride=1) # 构建表中 每个[] 的部分
        self.layer2=self._make_layer(block,256,nums_block_list[1],stride=2)
        self.layer3=self._make_layer(block,512,nums_block_list[2],stride=2)
        self.layer4=self._make_layer(block,1024,nums_block_list[3],stride=2)
        self.avgpool=nn.AdaptiveAvgPool2d(1) # GAP全局平均池化
        self.fc=nn.Linear(1024*block.expansion,nums_class) # 通道 2048 -> 1000
        self.softmax=nn.Softmax(-1) # 对最后一维进行softmax
    def forward(self, input):
        output=self.conv(input)
        output=self.maxpool(output)
        output=self.layer1(output)
        output=self.layer2(output)
        output=self.layer3(output)
        output=self.layer4(output)
        output=self.avgpool(output)
        output=output.squeeze(-1).squeeze(-1)
        output=self.fc(output)
        output=self.softmax(output)
        return output
    def _make_layer(self,block,planes,nums_block,stride=1):
        downsample=None
        if stride!=1 or self.inplanes!=planes*block.expansion:
            downsample=nn.Sequential(nn.Conv2d(self.inplanes,planes*block.expansion,1,stride,bias=False),
                                     nn.BatchNorm2d(planes*block.expansion))
        layers=[]
        layers.append(block(self.inplanes,planes,stride,downsample))
        self.inplanes=planes*block.expansion
        for _ in range(1,nums_block):
            layers.append(block(self.inplanes,planes))
        return nn.Sequential(*layers)
def SKNet50(nums_class=1000):
    return SKNet(nums_class,SKBlock,[3, 4, 6, 3]) # 论文通过[3, 4, 6, 3]搭配出SKNet50
def SKNet101(nums_class=1000):
    return SKNet(nums_class,SKBlock,[3, 4, 23, 3])
if __name__=='__main__':
    x = torch.rand(2, 3, 224, 224)
    model=SKNet50()
    y=model(x)
    print(y) # shape [2,1000]

参考:SKNet学习和使用-pytorch

【深度学习注意力机制系列】—— SKNet注意力机制(附pytorch实现)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1134634.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

composer安装thinkphp6报错

composer安装thinkphp6报错, 查看是否安装了对应的PHP扩展,我这边使用的是宝塔的环境,全程可以可视化操作 这样就可以安装完成了

linux工具篇

文章目录 linux工具篇1. linux 软件包管理器-yum1.1 什么是软件包1.2 yum的使用1.3 yum源 2. linux编辑器-vim2.1 vim概念2.2 vim各个模式切换2.3 vim正常模式命令汇总2.4 vim底行模式各命令汇总2.5 vim的简单配置 3. Linux编译器-gcc/g使用3.1 复习程序编译过程(1) 预处理(2) …

【Oracle】Navicat Premium 连接 Oracle的两种方式

Navicat Premium 使用版本说明 Navicat Premium 版本 11.2.16 (64-bit) 一、配置OCI 1.1 配置OCI环境变量 1.1.2 设置\高级系统设置 1.1.2 系统属性\高级\环境变量(N) 1.1.3 修改/添加系统变量 ORACLE_HOME ORACLE_HOME D:\app\root\product\12.1.0\dbhome_11.1.4 添加系…

Redis快速上手篇(一)(安装与配置)

NoSQL NoSQL 是 Not Only SQL 的缩写,意即"不仅仅是 SQL"的意思,泛指非关系型的数据库。强调 Key-Value Stores 和文档数据库的优点。 NoSQL 产品是传统关系型数据库的功能阉割版本,通过减少用不到或很少用的功能,来大…

vue3中使用svg并封装成组件

打包svg地图 安装插件 yarn add vite-plugin-svg-icons -D # or npm i vite-plugin-svg-icons -D # or pnpm install vite-plugin-svg-icons -D使用插件 vite.config.ts import { VantResolver } from unplugin-vue-components/resolvers import { createSvgIconsPlugin } from…

【C语言初阶】switch语句的基本语法

🎬 鸽芷咕:个人主页 🔥 个人专栏:《速学C语言》《数据结构篇》 ⛺️生活的理想,就是为了理想的生活! 文章目录 前言💬 switch语句的介绍💬 switch语句的语法形式💭 在switch语句中的 break&…

景联文科技提供4D-BEV标注工具:提升自动驾驶感知能力的精准数据支持

4D-BEV标注是一种用于自动驾驶领域的数据标注方法。在3D空间的基础上,加入了时间维度,形成了四个维度。这种方法通过精准地跟踪和记录动态对象(如车辆、行人)的运动轨迹、姿势变化以及速度等信息,全面理解和分析动态对…

JWT的登录认证与自校验原理分析

目录 一、JWT的概述 1.什么是JWT? 2.JWT的用户认证 3.JWT解决了什么问题? 4.关于JWT中的签名如何理解? 5.JWT的优势 二、JWT的结构 1.令牌的组成: 2.JWT的工具类 3.JWT所需的依赖 4.JWT登录生成Token的原理 三、JWT的自…

Linux中关于glibc包导致的服务器死机或者linux命令无法使用的情况

glibc是gnu发布的libc库,即c运行库。glibc是linux系统中最底层的api,几乎其它任何运行库都会依赖于glibc。glibc除了封装linux操作系统所提供的系统服务外,它本身也提供了许多其它一些必要功能服务的实现。由于 glibc 囊括了几乎所有的 UNIX …

智能巡检系统可以应用在哪些地方?巡检系统有什么优势?

智能巡检系统可以广泛应用于学校、物业、工厂、酒店和民宿、运维商以及能源行业等不同领域的巡检管理之中,为用户提供了巡检任务安排管理、签到打卡、工单上报处理以及数据统计分析等多种功能。 一、巡检系统的应用场景   1、学校:为了确保学校各项设施…

【抓包分析】通过ChatGPT解密还原某软件登录算法实现绕过手机验证码登录

文章目录 🍋前言实现效果广告抓包分析一、定位加密文件二、编辑JS启用本地替换 利用Chatgpt进行代码转换最后 🍋前言 由于C站版权太多,所有的爬虫相关均为记录,不做深入! 今天发现github上没有这个东西,抓…

Centos7 防火墙的关闭

Centos7 防火墙的关闭 Centos7默认使用的是firewall作为防火墙; 查看防火墙状态:firewall-cmd --state 停止防火墙:systemctl stop firewalld.service; 禁止防火墙开机启动:systemctl disable firewalld.service; 放行端口…

项目总结-新增商品-Pagehelper插件分页查询

(1)新增商品 工具类: /** * Title: FileUtils.java * Package com.qfedu.common.utils * Description: TODO(用一句话描述该文件做什么) * author Feri * date 2018年5月29日 * version V1.0 */ package com.gdsdxy.common.u…

webpack 解决:TypeError: merge is not a function 的问题

1、问题描述: 其一、存在的问题为: TypeError: merge is not a function 中文为: 类型错误:merge 不是函数 其二、问题描述为: 想执行 npm run dev 命令,运行起项目时,控制台报错 TypeErro…

066:mapboxGL的marker的drag,dragstart,dragend三种触发事件示例

第066个 点击查看专栏目录 本示例是演示如何在vue+mapbox中处理marker的三种触发事件drag,dragstart,dragend。 marker通过on(‘XXX’, callback),的方式进行触发处理。 直接复制下面的 vue+mapbox源代码,操作2分钟即可运行实现效果 文章目录 示例效果配置方式示例源代码(…

医学YOLOv8 | 脑肿瘤检测 Accuracy 99%

在医疗保健领域,准确和高效地识别脑肿瘤是一个重大挑战。本文中,我们将探讨一种使用 YOLOv8,一种先进的目标检测模型,将脑肿瘤进行分类的新方法,其准确率达到了 99%。通过将深度学习与医学图像相结合,我们希…

Python算法例1 完美平方

例1 完美平方 1. 问题描述 给定一个正整数n,找到若干个完全平方数(例如:1,4,9,…),使得它们的和等于n,完全平方数的个数最少。 2.问题示例 给出n8,返回2&…

【51单片机】:智能施工电梯系统

项目效果: 基于51单片机的智能施工电梯系统 摘 要 智能施工电梯系统目前广泛应用于人们建筑工程中,为人们施工时上下搬运提供了极大的便利。智能施工电梯系统包括密码开启、超重提示,电梯运作及相关信息显示等等功能,施工电梯为我…

数组中出现次数超过一半的数字整型数组有一个数字出现的次数超过总数的一半,请找出该数字

例如长度为 9 的数组{1,2,3,2,4,2,5,2,2}。 由于 2 出现的次数是 5 次,超过一半,所以结果为2。 算法一: 先排序,然后中间值就是要找的数字 图解: int Cmp_int(const void* vp1, const void* vp2) //定义排序规则 {return * (int*)vp1 - *(int*)vp2; } …

ts | js | 爬虫小公举分享

Curl转Code 快速将curl转为各种语言的代码; 便于提取请求头之类, 或者微改直接使用 https://curlconverter.com/node-axios/ (有点慢, 但是很全)https://www.lddgo.net/convert/curl-to-code (没有axios, 我喜欢用axios) 使用… 抓取地址, 使用浏览器或者其他抓包工具都可, 这…