[yolo系列:YOLOV7改进-添加CoordConv,SAConv.]

news2024/12/23 5:43:43

文章目录

    • 概要
    • CoordConv
    • SAConv

概要

CoordConv(Coordinate Convolution)和SAConv(Spatial Attention Convolution)是两种用于神经网络中的特殊卷积操作,用于处理图像数据或其他多维数据。以下是它们的简要介绍:
CoordConv(Coordinate Convolution)

CoordConv 是由Uber AI Labs的研究人员提出的一种卷积操作,用于处理图像中的坐标信息。在传统的卷积操作中,卷积核在图像上滑动并执行卷积操作,但是它们对于图像中的位置信息是不敏感的。CoordConv 的目标是使卷积操作变得位置敏感,它在输入特征图中加入了位置信息作为额外的通道。这个位置信息可以是像素的坐标,也可以是归一化的坐标值,具体取决于应用的场景。

通过将坐标信息与输入特征图拼接在一起,CoordConv 能够帮助神经网络更好地学习到输入数据中的空间关系,从而提高模型的性能。它在需要考虑输入数据的空间位置信息时,特别有用。
SAConv(Spatial Attention Convolution)

SAConv 是一种引入了空间注意力机制的卷积操作。传统的卷积操作在所有位置都应用相同的卷积核,而SAConv 具有可学习的空间注意力权重,这意味着它能够动态地调整不同位置的卷积核权重。

SAConv 的关键思想是,在进行卷积操作之前,先计算每个位置的空间注意力权重。这些权重由神经网络学习得出,然后被用来加权输入特征图的不同位置,从而生成具有位置敏感性的特征表示。这种机制使得神经网络在处理输入数据时能够更加关注重要的区域,从而提高了模型的感知能力和性能。

总的来说,CoordConv 和 SAConv 都是为了增强神经网络对输入数据的空间信息处理能力而提出的方法。CoordConv 引入了位置信息通道,使得网络对位置信息更敏感,而 SAConv 引入了空间注意力机制,使得网络能够动态地调整卷积核的权重,提高了对不同位置信息的关注度。这两种方法在特定的任务和场景下都能够带来性能的提升。

CoordConv

common.py添加如下

class AddCoords(nn.Module):
    def __init__(self, with_r=False):
        super().__init__()
        self.with_r = with_r

    def forward(self, input_tensor):
        """
        Args:
            input_tensor: shape(batch, channel, x_dim, y_dim)
        """
        batch_size, _, x_dim, y_dim = input_tensor.size()

        xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1)
        yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2)

        xx_channel = xx_channel.float() / (x_dim - 1)
        yy_channel = yy_channel.float() / (y_dim - 1)

        xx_channel = xx_channel * 2 - 1
        yy_channel = yy_channel * 2 - 1

        xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
        yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)

        ret = torch.cat([
            input_tensor,
            xx_channel.type_as(input_tensor),
            yy_channel.type_as(input_tensor)], dim=1)

        if self.with_r:
            rr = torch.sqrt(torch.pow(xx_channel.type_as(input_tensor) - 0.5, 2) + torch.pow(yy_channel.type_as(input_tensor) - 0.5, 2))
            ret = torch.cat([ret, rr], dim=1)

        return ret

class CoordConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, with_r=False):
        super().__init__()
        self.addcoords = AddCoords(with_r=with_r)
        in_channels += 2
        if with_r:
            in_channels += 1
        self.conv = Conv(in_channels, out_channels, k=kernel_size, s=stride)

    def forward(self, x):
        x = self.addcoords(x)
        x = self.conv(x)
        return x

在yolo.py

在这里插入图片描述

# yolov7 head
head:
  [[-1, 1, SPPCSPC, [512]], # 51
  
   [-1, 1, CoordConv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [37, 1, CoordConv, [256, 1, 1]], # route backbone P4
   [[-1, -2], 1, Concat, [1]],
   
   [-1, 1, Conv, [256, 1, 1]],
   [-2, 1, Conv, [256, 1, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
   [-1, 1, Conv, [256, 1, 1]], # 63
   
   [-1, 1, CoordConv, [128, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [24, 1, CoordConv, [128, 1, 1]], # route backbone P3
   [[-1, -2], 1, Concat, [1]],
   
   [-1, 1, Conv, [128, 1, 1]],
   [-2, 1, Conv, [128, 1, 1]],
   [-1, 1, Conv, [64, 3, 1]],
   [-1, 1, Conv, [64, 3, 1]],
   [-1, 1, Conv, [64, 3, 1]],
   [-1, 1, Conv, [64, 3, 1]],
   [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
   [-1, 1, Conv, [128, 1, 1]], # 75
      
   [-1, 1, MP, []],
   [-1, 1, Conv, [128, 1, 1]],
   [-3, 1, Conv, [128, 1, 1]],
   [-1, 1, Conv, [128, 3, 2]],
   [[-1, -3, 63], 1, Concat, [1]],
   
   [-1, 1, Conv, [256, 1, 1]],
   [-2, 1, Conv, [256, 1, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
   [-1, 1, Conv, [256, 1, 1]], # 88
      
   [-1, 1, MP, []],
   [-1, 1, Conv, [256, 1, 1]],
   [-3, 1, Conv, [256, 1, 1]],
   [-1, 1, Conv, [256, 3, 2]],
   [[-1, -3, 51], 1, Concat, [1]],
   
   [-1, 1, Conv, [512, 1, 1]],
   [-2, 1, Conv, [512, 1, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
   [-1, 1, Conv, [512, 1, 1]], # 101
   
   [75, 1, CoordConv, [256, 3, 1]],
   [88, 1, CoordConv, [512, 3, 1]],
   [101, 1, CoordConv, [1024, 3, 1]],

   [[102,103,104], 1, IDetect, [nc, anchors]],   # Detect(P3, P4, P5)
  ]

SAConv

在common.py添加

class ConvAWS2d(nn.Conv2d):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        self.register_buffer('weight_gamma', torch.ones(self.out_channels, 1, 1, 1))
        self.register_buffer('weight_beta', torch.zeros(self.out_channels, 1, 1, 1))

    def _get_weight(self, weight):
        weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,
                                  keepdim=True).mean(dim=3, keepdim=True)
        weight = weight - weight_mean
        std = torch.sqrt(weight.view(weight.size(0), -1).var(dim=1) + 1e-5).view(-1, 1, 1, 1)
        weight = weight / std
        weight = self.weight_gamma * weight + self.weight_beta
        return weight

    def forward(self, x):
        weight = self._get_weight(self.weight)
        return super()._conv_forward(x, weight, None)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        self.weight_gamma.data.fill_(-1)
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)
        if self.weight_gamma.data.mean() > 0:
            return
        weight = self.weight.data
        weight_mean = weight.data.mean(dim=1, keepdim=True).mean(dim=2,
                                       keepdim=True).mean(dim=3, keepdim=True)
        self.weight_beta.data.copy_(weight_mean)
        std = torch.sqrt(weight.view(weight.size(0), -1).var(dim=1) + 1e-5).view(-1, 1, 1, 1)
        self.weight_gamma.data.copy_(std)
    
class SAConv2d(ConvAWS2d):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 s=1,
                 p=None,
                 g=1,
                 d=1,
                 act=True,
                 bias=True):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=s,
            padding=autopad(kernel_size, p),
            dilation=d,
            groups=g,
            bias=bias)
        self.switch = torch.nn.Conv2d(
            self.in_channels,
            1,
            kernel_size=1,
            stride=s,
            bias=True)
        self.switch.weight.data.fill_(0)
        self.switch.bias.data.fill_(1)
        self.weight_diff = torch.nn.Parameter(torch.Tensor(self.weight.size()))
        self.weight_diff.data.zero_()
        self.pre_context = torch.nn.Conv2d(
            self.in_channels,
            self.in_channels,
            kernel_size=1,
            bias=True)
        self.pre_context.weight.data.fill_(0)
        self.pre_context.bias.data.fill_(0)
        self.post_context = torch.nn.Conv2d(
            self.out_channels,
            self.out_channels,
            kernel_size=1,
            bias=True)
        self.post_context.weight.data.fill_(0)
        self.post_context.bias.data.fill_(0)
        
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

    def forward(self, x):
        # pre-context
        avg_x = torch.nn.functional.adaptive_avg_pool2d(x, output_size=1)
        avg_x = self.pre_context(avg_x)
        avg_x = avg_x.expand_as(x)
        x = x + avg_x
        # switch
        avg_x = torch.nn.functional.pad(x, pad=(2, 2, 2, 2), mode="reflect")
        avg_x = torch.nn.functional.avg_pool2d(avg_x, kernel_size=5, stride=1, padding=0)
        switch = self.switch(avg_x)
        # sac
        weight = self._get_weight(self.weight)
        out_s = super()._conv_forward(x, weight, None)
        ori_p = self.padding
        ori_d = self.dilation
        self.padding = tuple(3 * p for p in self.padding)
        self.dilation = tuple(3 * d for d in self.dilation)
        weight = weight + self.weight_diff
        out_l = super()._conv_forward(x, weight, None)
        out = switch * out_s + (1 - switch) * out_l
        self.padding = ori_p
        self.dilation = ori_d
        # post-context
        avg_x = torch.nn.functional.adaptive_avg_pool2d(out, output_size=1)
        avg_x = self.post_context(avg_x)
        avg_x = avg_x.expand_as(out)
        out = out + avg_x
        return self.act(self.bn(out))

然后在yolo.py里面添加
在这里插入图片描述
在这里插入图片描述
和可变形卷积加法一样,但是不建议加太多,也是只替换3x3卷积上面。比普通卷积复杂度高,不建议加太多,推理速度变慢,尽量少用,提高精度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1127324.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【小余送书活动第四期】《Kali Linux高级渗透测试》,不可多的的网安书籍哦!网络安全的朋友抓紧参与活动领书咯!

目录 1.背景介绍 2.读者对象 3.随书资源 4.本书目录 5.本书概览 6.活动参与方式 1.背景介绍 对于企业网络安全建设工作的质量保障,业界普遍遵循PDCA(计划(Plan)、实施(Do)、检查(Check&…

CUDA学习笔记(十一)Memory Access

转载于https://www.cnblogs.com/1024incn/tag/CUDA/ Memory Access Patterns 大部分device一开始从global Memory获取数据,而且,大部分GPU应用表现会被带宽限制。因此最大化应用对global Memory带宽的使用时获取高性能的第一步。也就是说,gl…

vsCode 格式化配置

学习目标: 基于 vsCode 配置格式化工具,提高(React、Vue )开发效率  1. vsCode 安装 prettier 插件并启用  2. 修改配置文件 setting.json setting.json 位置: 依次点击 替换内容:↓ {"git.enab…

强化学习代码实战(2) --- 多臂赌博机

目录 前言 1.Python基础 2.Numpy基础 3.多臂赌博机 参考文献 前言 本文内容来自于南京大学郭宪老师在博文视点学院录制的视频,课程仅9元地址,配套书籍为深入浅出强化学习 编程实战 郭宪地址。 1.Python基础 1. print() 可以用该语句查看当前数据的情…

使用线程时,有哪三种常见的线程安全问题

Java全能学习面试指南:https://javaxiaobear.cn 今天我们学习 3 类线程安全问题。 什么是线程安全 要想弄清楚有哪 3 类线程安全问题,首先需要了解什么是线程安全,线程安全经常在工作中被提到,比如:你的对象不是线程…

程序员必备网站,别说话直接收藏!

俗话说的好,一个程序员,20%靠知识储备,80%靠网络搜索。打开代码,打开Google,开始工作。 那么常用的写码软件,你知道几个呢?下面我们来一起看一下常用的写码软件吧~建议收藏本文,保证…

docsify搭建个人博客——简单公共知识库

整站建设流程:安装docsify > 排错>配置封面> 配搜索> 启动> 放md类的文章> 自动生成目录; 更新文章流程: 把目录文章放到docsify\docs目录下,然后双击docsify-autosidebar.exe即可(它会重新生成目录…

2023版 STM32实战11 SPI总线读写W25Q

SPI全称 英文全称:Serial peripheral Interface 串行外设接口 SPI特点 -1- 串行(逐bit传输) -2- 同步(共用时钟线) -3- 全双工(收发可同时进行) -4- 通信只能由主机发起(一主,多从机) 开发使用习惯和理解 -1- CS片选一般配置为软件控制 -2- 片选低电平有效,从…

JAVA入门总结回顾

1.常用的DOS命令:DOS窗口常用命令-CSDN博客 2.检查jdk是否安装成功:在cmd中输入java -version或者java或者javac。出现相应的对应显示内容。 3.JDK,JRE之间的关系:JDK是JAVA的开发工具包,JRE是JAVA的的运行环境。JRE…

第二证券:指数是什么意思?

跟着经济全球化的加速和商场化进程的深化,指数已成为金融商场重要的风向标和抉择方案参看。指数是依据商场上必定数量的标的股票价格改变而核算的数值,代表了特定股票商场的全体涨跌状况。本文将从多个视点剖析指数的意义和作用。 一、指数的品种和核算…

k8s部署xxl-job后,执行任务提示拒绝连接Connection Refused

一、问题背景 1.1 问题说明 之前由于网络插件flannel安装不成功,导致xxl-job执行任务的时候,提示拒绝服务,如下图所示: 但是安装flannel安装成功后,依然无法联通,还是提示相同问题 1.2 排查网络 通过i…

短视频矩阵系统搭建/源头----源码

一、智能剪辑、矩阵分发、无人直播、爆款文案于一体独立应用开发 抖去推----主要针对本地生活的----移动端(小程序软件系统,目前是全国源头独立开发),开发功能大拆解分享,功能大拆解: 7大模型剪辑法(数学阶乘&#xff…

Golang 数据库操作

文章目录 初始化连接连接池SetMaxOpenConnsSetMaxIdleConnsSetConnMaxIdleTimeSetConnMaxLifetime 查询数据插入数据更新数据删除数据实现账号密码登录功能sqlx的部分用法 首先安装包:Install go get -u github.com/go-sql-driver/mysql // MySQL数据库的包 go get…

原型制作的软件 Experience Design mac( XD ) 中文版软件特色

​XD是一个直观、功能强大的UI/UX开发工具,旨在设计、原型、用户之间共享材料以及通过数字技术进行设计交互。Adobe XD提供了开发网站、应用程序、语音界面、游戏界面、电子邮件模板等所需的一切。xd mac软件特色 体验设计的未来。 使用 Adobe XD 中快速直观、即取即…

女儿的睡衣,蕾丝花边蝴蝶结,太好看了吧

分享女儿的睡衣穿搭 大部分女孩子都喜欢 粉粉嫩小公主风格的衣服 我的宝贝也不例外啦 蕾丝花边和蝴蝶结真的会让女生少女心爆棚 非常厚实软糯的珊瑚绒质地,穿上非常暖和 裤脚和袖口都做了收口设计,真的很赞! 还有麻麻款呢,…

C算法:使用选择排序实现从(大到小/从小到大)排序数组,且元素交换不可使用第三变量。

需求&#xff1a; 使用选择排序实现从(大到小/从小到大)排序&#xff0c;且元素交换不可使用第三变量 (异或交换法) 代码实现&#xff1a; #include <stdio.h> void maopao(int* array,int len,int(*swap)(int a,int b)) {int i,j;for(i0;i<len-1;i){for(ji1;j<…

如何判断LED透明屏质量好坏?

要判断LED透明屏的质量好坏&#xff0c;您可以考虑以下几个关键因素&#xff1a; 焊点品质。焊点饱满的证明焊接工艺好&#xff0c;亮度高的透明屏&#xff0c;证明焊锡用的好&#xff1b;品质不好的是虚焊&#xff0c;容易出现接触不良现象。 灯珠温度。点亮一段时间后&#x…

如何删除重复文件?简单操作法方法盘点!

“我之前传文件的时候好像传了很多重复的&#xff0c;导致这些文件占用了我大量的内存&#xff0c;有什么方法可以快速删除这些重复的文件吗&#xff1f;感谢&#xff01;” 随着时间的推移&#xff0c;我们的电脑中常常会积累大量的重复文件&#xff0c;这不仅占用宝贵的存储空…

竞赛选题 深度学习人脸表情识别算法 - opencv python 机器视觉

文章目录 0 前言1 技术介绍1.1 技术概括1.2 目前表情识别实现技术 2 实现效果3 深度学习表情识别实现过程3.1 网络架构3.2 数据3.3 实现流程3.4 部分实现代码 4 最后 0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; 深度学习人脸表情识别系…

Allegro教学:Assembly层和Silkscreen元器件编号如何处理?

在电子工程中&#xff0c;PCB的设计和制造最为关键&#xff0c;而PCB上有多种层&#xff0c;有信号层、电源层、接地层和机械层&#xff0c;今天我们来聊聊Assembly层。来聊聊Silkscreen元器件编号问题&#xff0c;希望本文对小伙伴们有所帮助。 首先在回答这个问题前&#xff…