pytorch代码实现之动态卷积模块ODConv

news2024/9/30 17:23:07

ODConv动态卷积模块

ODConv可以视作CondConv的延续,将CondConv中一个维度上的动态特性进行了扩展,同时了考虑了空域、输入通道、输出通道等维度上的动态性,故称之为全维度动态卷积。ODConv通过并行策略采用多维注意力机制沿核空间的四个维度学习互补性注意力。作为一种“即插即用”的操作,它可以轻易的嵌入到现有CNN网络中。ImageNet分类与COCO检测任务上的实验验证了所提ODConv的优异性:即可提升大模型的性能,又可提升轻量型模型的性能,实乃万金油是也!值得一提的是,受益于其改进的特征提取能力,ODConv搭配一个卷积核时仍可取得与现有多核动态卷积相当甚至更优的性能。

原文地址:Omni-Dimensional Dynamic Convolution

ODConv结构图
代码实现:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd
from models.common import Conv, autopad

class Attention(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, groups=1, reduction=0.0625, kernel_num=4, min_channel=16):
        super(Attention, self).__init__()
        attention_channel = max(int(in_planes * reduction), min_channel)
        self.kernel_size = kernel_size
        self.kernel_num = kernel_num
        self.temperature = 1.0

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = Conv(in_planes, attention_channel, act=nn.ReLU(inplace=True))

        self.channel_fc = nn.Conv2d(attention_channel, in_planes, 1, bias=True)
        self.func_channel = self.get_channel_attention

        if in_planes == groups and in_planes == out_planes:  # depth-wise convolution
            self.func_filter = self.skip
        else:
            self.filter_fc = nn.Conv2d(attention_channel, out_planes, 1, bias=True)
            self.func_filter = self.get_filter_attention

        if kernel_size == 1:  # point-wise convolution
            self.func_spatial = self.skip
        else:
            self.spatial_fc = nn.Conv2d(attention_channel, kernel_size * kernel_size, 1, bias=True)
            self.func_spatial = self.get_spatial_attention

        if kernel_num == 1:
            self.func_kernel = self.skip
        else:
            self.kernel_fc = nn.Conv2d(attention_channel, kernel_num, 1, bias=True)
            self.func_kernel = self.get_kernel_attention

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def update_temperature(self, temperature):
        self.temperature = temperature

    @staticmethod
    def skip(_):
        return 1.0

    def get_channel_attention(self, x):
        channel_attention = torch.sigmoid(self.channel_fc(x).view(x.size(0), -1, 1, 1) / self.temperature)
        return channel_attention

    def get_filter_attention(self, x):
        filter_attention = torch.sigmoid(self.filter_fc(x).view(x.size(0), -1, 1, 1) / self.temperature)
        return filter_attention

    def get_spatial_attention(self, x):
        spatial_attention = self.spatial_fc(x).view(x.size(0), 1, 1, 1, self.kernel_size, self.kernel_size)
        spatial_attention = torch.sigmoid(spatial_attention / self.temperature)
        return spatial_attention

    def get_kernel_attention(self, x):
        kernel_attention = self.kernel_fc(x).view(x.size(0), -1, 1, 1, 1, 1)
        kernel_attention = F.softmax(kernel_attention / self.temperature, dim=1)
        return kernel_attention

    def forward(self, x):
        x = self.avgpool(x)
        x = self.fc(x)
        return self.func_channel(x), self.func_filter(x), self.func_spatial(x), self.func_kernel(x)


class ODConv2d(nn.Module):
    def __init__(self, in_planes, out_planes, k, s=1, p=None, g=1, act=True, d=1,
                 reduction=0.0625, kernel_num=1):
        super(ODConv2d, self).__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.kernel_size = k
        self.stride = s
        self.padding = autopad(k, p)
        self.dilation = d
        self.groups = g
        self.kernel_num = kernel_num
        self.attention = Attention(in_planes, out_planes, k, groups=g,
                                   reduction=reduction, kernel_num=kernel_num)
        self.weight = nn.Parameter(torch.randn(kernel_num, out_planes, in_planes//g, k, k),
                                   requires_grad=True)
        self._initialize_weights()
        self.bn = nn.BatchNorm2d(out_planes)
        self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

        if self.kernel_size == 1 and self.kernel_num == 1:
            self._forward_impl = self._forward_impl_pw1x
        else:
            self._forward_impl = self._forward_impl_common

    def _initialize_weights(self):
        for i in range(self.kernel_num):
            nn.init.kaiming_normal_(self.weight[i], mode='fan_out', nonlinearity='relu')

    def update_temperature(self, temperature):
        self.attention.update_temperature(temperature)

    def _forward_impl_common(self, x):
        # Multiplying channel attention (or filter attention) to weights and feature maps are equivalent,
        # while we observe that when using the latter method the models will run faster with less gpu memory cost.
        channel_attention, filter_attention, spatial_attention, kernel_attention = self.attention(x)
        batch_size, in_planes, height, width = x.size()
        x = x * channel_attention
        x = x.reshape(1, -1, height, width)
        aggregate_weight = spatial_attention * kernel_attention * self.weight.unsqueeze(dim=0)
        aggregate_weight = torch.sum(aggregate_weight, dim=1).view(
            [-1, self.in_planes // self.groups, self.kernel_size, self.kernel_size])
        output = F.conv2d(x, weight=aggregate_weight, bias=None, stride=self.stride, padding=self.padding,
                          dilation=self.dilation, groups=self.groups * batch_size)
        output = output.view(batch_size, self.out_planes, output.size(-2), output.size(-1))
        output = output * filter_attention
        return output

    def _forward_impl_pw1x(self, x):
        channel_attention, filter_attention, spatial_attention, kernel_attention = self.attention(x)
        x = x * channel_attention
        output = F.conv2d(x, weight=self.weight.squeeze(dim=0), bias=None, stride=self.stride, padding=self.padding,
                          dilation=self.dilation, groups=self.groups)
        output = output * filter_attention
        return output

    def forward(self, x):
        return self.act(self.bn(self._forward_impl(x)))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1025253.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

全局异常处理器@RestControllerAdvice解析 Springboot项目异常处理 JavaWeb @ExceptionHandler

RestControllerAdvice public class GlobalExceptionHandler {ExceptionHandler(Exception.class)//指定捕获异常类型:所有public Result ex(Exception ex){ex.printStackTrace();return Result.error("对不起,出现异常,请联系管理员");}}RestControllerAdvice注解在…

基于Yolov8的野外烟雾检测(1)

目录 1.Yolov8介绍 2.野外火灾烟雾数据集介绍 2.1数据集划分 1.2 通过voc_label.py得到适合yolov8需要的txt 2.3生成内容如下 3.训练结果分析 4.系列篇 1.Yolov8介绍 Ultralytics YOLOv8是Ultralytics公司开发的YOLO目标检测和图像分割模型的最新版本。YOLOv8是一种尖端的…

1千听歌猜歌名疯狂猜歌ACCESS\EXCEL数据库

就是从今年开始,各类的“猜”游戏开始火爆,先是猜图,比如看图猜明星、看图猜成语、看图猜电影、看图猜电视剧、看图猜背景、看图猜游戏、看图猜影视人物、看图猜景点等。然后又开始猜音频,猜音频最多的是歌。甚至现在的《一站到底…

每日一题:请解释什么是闭包(Closure)?并举一个实际的例子来说明。(前端初级)

今天继续在前端初级笔试题中被AI虐: 碱面的答案,问题:初级,回答:初级https://bs.rongapi.cn/1702510598371151872/14我的回答如下: 闭包是指由大括号包裹的一个区域,这个区域代表了一个变量生效…

【数据分享】我国六普的乡镇(街道)人口数据(免费获取)

人口数据是我们在各项研究中都经常使用的数据!人口数据的主要来源是人口普查,全国性的人口普查每十年进行一次。最近三次的人口普查分别是:2000年的第五次全国人口普查,简称五普;2010年的第六次全国人口普查&#xff0…

海外网红营销安全指南:品牌必须遵守的10大法律法规

随着互联网的普及和社交媒体的崛起,品牌们越来越倾向于与海外网红合作,以扩大其在全球市场的影响力。然而,这一战略并非没有风险,因为在不同国家和地区,存在着各种各样的法律法规,可能会影响品牌与海外网红…

JavaScript小案例-tab栏切换(可移除item)

gif效果图&#xff1a; 代码&#xff1a; <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>tab栏切…

使用JQ获取并渲染三级联动分类数据

数据JSON格式 代码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Document</title> </he…

idea 使用 groovyScript 获取方法参数列表生成方法注释模板遇到的问题。

1、网上好多使用groovyScript来设置获取方法列表生成注释模板的代码&#xff0c;我这篇文章的是想讨论下这种方式存在的一个问题&#xff0c;希望有大佬能提供一个解决方案。 2、设置步骤什么的就省略了直接描述问题。 3、groovyScript代码段如下&#xff1a; groovyScript(…

SpringBoot统一返回处理遇到cannot be cast to java.lang.String问题

ResponseBodyAdvice 接口概述 1、ResponseBodyAdvice 接口允许在执行 ResponseBody 或 ResponseEntity 控制器方法之后&#xff0c;但在使用 HttpMessageConverter 写入响应体之前自定义响应&#xff0c;进行功能增强。通常用于 加密&#xff0c;签名&#xff0c;统一数据格式…

从零开始—【Mac系统】MacOS配置Java环境变量

系统环境说明 Apple M1 macOS Ventura 版本13.5.2 1.下载JDK安装包 Oracle官网下载地址 JDK下载【注&#xff1a;推荐下载JDK8 Oracle官网JDK8下载】 关于JDK、JRE、JVM的关系说明 JDK(Java Development Kit&#xff0c;Java开发工具包) &#xff0c;是整个JAVA的核心&#…

Mybatis学习笔记10 高级映射及延迟加载

Mybatis学习笔记9 动态SQL_biubiubiu0706的博客-CSDN博客 无论简单映射(前面所学的单表和对象之间的映射关系)还是高级映射 说到底都是java对象和数据库表记录之间的映射关系 准备数据库表:一个班级对应多个学生.班级表:t_class 学生表:s_stu(自增) 新建模块 项目整体结构 …

minio文件上传

1.代码 大佬仓库&#xff1a;https://gitee.com/Gary2016/minio-upload?_fromgitee_search 关于这个代码的讲解&#xff1a;来自b站 2.准备minio 参考&#xff1a;[1]、[2] 2.1 下载 官网&#xff1a;https://min.io/download#/windows 2.2 启动 ①准备一个data文件夹…

Vue.js模板语法[下](事件处理,表单综合案例,自定义组件)---详细讲解

一&#xff0c;事件处理 1. .stop&#xff1a;阻止事件冒泡。使用该修饰符可以阻止事件向父元素传播 2. .prevent&#xff1a;阻止默认事件。使用该修饰符可以阻止事件的默认行为。 3. .capture&#xff1a;使用事件捕获模式。默认情况下&#xff0c;事件是在冒泡阶段处理的&am…

第七章 查找 三、折半查找(二分查找)

一、代码实现 此代码只能用于查找有序的顺序表 typedef struct {int *e;int len; }SSTable;int Search_Seq(SSTable st,int t){int i0,jst.len-1,mid;while (i<j){mid(ij)>>2;if (t>st.e[mid]){imid1;} else if (t<st.e[mid]){jmid-1;} else{return mid;}}ret…

数字孪生行业相关政策梳理--工业领域相关政策(可下载)

自2021年国家“十四五”规划纲要提出“探索建设数字孪生城市”以来&#xff0c;国家发展和改革委员会、工业和信息化部、住房和城乡建设部、水利部、农业农村部等部门纷纷出台政策&#xff0c;大力推动数字孪生在千行百业的落地发展。这些政策不仅为数字孪生的应用提供了广阔的…

从烹饪一道菜看面向过程与面向对象编程

在编程世界中&#xff0c;面向过程和面向对象是两种主要的编程范式。它们各有优点&#xff0c;适用于不同的场景。让我们通过烹饪一道菜的例子来理解这两种编程范式。 面向过程编程 面向过程编程是一种基于过程的编程范式&#xff0c;它强调的是程序的执行顺序。在这种范式中…

E. Moment of Bloom

Problem - E - Codeforces 思路&#xff1a;这个题看到之后想到了不可能的情况&#xff0c;就是如果度为奇数就一定不可能实现都是偶数&#xff0c;但是后面就不知道怎么搞了。正解是欧拉定理的应用把算是&#xff0c;首先对于给定的q个要求&#xff0c;我们从a->b连一条边&…

win10 win11 停止系统自动更新方法

目录 方法一&#xff1a;使用注册表更改 1. 进入注册表 2. 进入如下目录 3. 新建 DWOED(32-位)值 4. 双击 FlightSettingsMaxPauseDays&#xff0c;选择十进制&#xff0c;左侧输入9999 5. 开头的天数已经变为9999天 方法二&#xff1a;停止自动更新的服务 1. 查询服务…

许少辉八一新著《乡村振兴战略下传统村落文化旅游设计》安徽站——2023学生开学季辉少许

许少辉八一新著《乡村振兴战略下传统村落文化旅游设计》安徽站——2023学生开学季辉少许