YOLOv7改进:结合CotNet Transformer结构

news2024/11/30 4:58:11

 

1.简介

京东AI研究院提出的一种新的注意力结构。将CoT Block代替了ResNet结构中的3x3卷积,在分类检测分割等任务效果都出类拔萃

 

论文:Contextual Transformer Networks for Visual Recognition
论文地址:https://arxiv.org/abs/2107.12292

 

有自注意力的Transformer引发了自然语言处理领域的革命,最近还激发了Transformer式架构设计的出现,并在众多计算机视觉任务中取得了具有竞争力的结果。

大多数现有设计直接在2D特征图上使用自注意力来获得基于每个空间位置的独立查询和键对的注意力矩阵,但未充分利用相邻键之间的丰富上下文。在今天分享的工作中,研究者设计了一个新颖的Transformer风格的模块,即Contextual Transformer (CoT)块,用于视觉识别。这种设计充分利用输入键之间的上下文信息来指导动态注意力矩阵的学习,从而增强视觉表示能力。从技术上讲,CoT块首先通过3×3卷积对输入键进行上下文编码,从而产生输入的静态上下文表示。

 

上图a是传统的self-attention仅利用孤立的查询-键对来测量注意力矩阵,但未充分利用键之间的丰富上下文。 b就是CoT块

研究者进一步将编码的键与输入查询连接起来,通过两个连续的1×1卷积来学习动态多头注意力矩阵。学习到的注意力矩阵乘以输入值以实现输入的动态上下文表示。静态和动态上下文表示的融合最终作为输出。CoT块很吸引人,因为它可以轻松替换ResNet架构中的每个3 × 3卷积,产生一个名为Contextual Transformer Networks (CoTNet)的Transformer式主干。通过对广泛应用(例如图像识别、对象检测和实例分割)的大量实验,验证了CoTNet作为更强大的主干的优越性

Attention注意力机制与self-attention自注意力机制

为什么要注意力机制?

在Attention诞生之前,已经有CNN和RNN及其变体模型了,那为什么还要引入attention机制?主要有两个方面的原因,如下:

(1)计算能力的限制:当要记住很多“信息“,模型就要变得更复杂,然而目前计算能力依然是限制神经网络发展的瓶颈。

(2)优化算法的限制:LSTM只能在一定程度上缓解RNN中的长距离依赖问题,且信息“记忆”能力并不高。

什么是注意力机制

在介绍什么是注意力机制之前,先让大家看一张图片。当大家看到下面图片,会首先看到什么内容?当过载信息映入眼帘时,我们的大脑会把注意力放在主要的信息上,这就是大脑的注意力机制。

同样,当我们读一句话时,大脑也会首先记住重要的词汇,这样就可以把注意力机制应用到自然语言处理任务中,于是人们就通过借助人脑处理信息过载的方式,提出了Attention机制。

self attention是注意力机制中的一种,也是transformer中的重要组成部分。自注意力机制是注意力机制的变体,其减少了对外部信息的依赖,更擅长捕捉数据或特征的内部相关性。自注意力机制在文本中的应用,主要是通过计算单词间的互相影响,来解决长距离依赖问题。

传统的自注意力很好地触发了不同空间位置的特征交互,具体取决于输入本身。然而,在传统的自注意力机制中,所有成对的查询键关系都是通过孤立的查询键对独立学习的,而无需探索其间的丰富上下文。这严重限制了自注意力学习在2D特征图上进行视觉表示学习的能力。

为了缓解这个问题,研究者构建了一个新的Transformer风格的构建块,即上图 (b)中的 Contextual Transformer (CoT) 块,它将上下文信息挖掘和自注意力学习集成到一个统一的架构中。

2.YOLOv7改进

 2.1yaml配置文件修改

# YOLOv7 🚀, GPL-3.0 license
# parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 1.0  # layer channel multiple

# anchors
anchors:
  - [12,16, 19,36, 40,28]  # P3/8
  - [36,75, 76,55, 72,146]  # P4/16
  - [142,110, 192,243, 459,401]  # P5/32

# yolov7 backbone by yoloair
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [32, 3, 1]],  # 0
   [-1, 1, Conv, [64, 3, 2]],  # 1-P1/2
   [-1, 1, Conv, [64, 3, 1]],
   [-1, 1, Conv, [128, 3, 2]],  # 3-P2/4 
   [-1, 1, C3HB, [128]], 
   [-1, 1, Conv, [256, 3, 2]], 
   [-1, 1, MP, []],
   [-1, 1, Conv, [128, 1, 1]],
   [-3, 1, Conv, [128, 1, 1]],
   [-1, 1, Conv, [128, 3, 2]],
   [[-1, -3], 1, Concat, [1]],  # 16-P3/8
   [-1, 1, Conv, [128, 1, 1]],
   [-2, 1, Conv, [128, 1, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [[-1, -3, -5, -6], 1, Concat, [1]],
   [-1, 1, Conv, [512, 1, 1]],
   [-1, 1, MP, []],
   [-1, 1, Conv, [256, 1, 1]],
   [-3, 1, Conv, [256, 1, 1]],
   [-1, 1, Conv, [256, 3, 2]],
   [[-1, -3], 1, Concat, [1]],
   [-1, 1, Conv, [256, 1, 1]],
   [-2, 1, Conv, [256, 1, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [[-1, -3, -5, -6], 1, Concat, [1]],
   [-1, 1, Conv, [1024, 1, 1]],          
   [-1, 1, MP, []],
   [-1, 1, Conv, [512, 1, 1]],
   [-3, 1, Conv, [512, 1, 1]],
   [-1, 1, Conv, [512, 3, 2]],
   [[-1, -3], 1, Concat, [1]],
   [-1, 1, C3HB, [1024]],
   [-1, 1, Conv, [256, 3, 1]],
  ]

# yolov7 head by yoloair
head:
  [[-1, 1, SPPCSPC, [512]],
   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [31, 1, Conv, [256, 1, 1]],
   [[-1, -2], 1, Concat, [1]],
   [-1, 1, CoT3, [128]],
   [-1, 1, Conv, [128, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [18, 1, Conv, [128, 1, 1]],
   [[-1, -2], 1, Concat, [1]],
   [-1, 1, CoT3, [128]],
   [-1, 1, MP, []],
   [-1, 1, Conv, [128, 1, 1]],
   [-3, 1, Conv, [128, 1, 1]],
   [-1, 1, Conv, [128, 3, 2]],
   [[-1, -3, 44], 1, Concat, [1]],
   [-1, 1, CoT3, [256]], 
   [-1, 1, MP, []],
   [-1, 1, Conv, [256, 1, 1]],
   [-3, 1, Conv, [256, 1, 1]],
   [-1, 1, Conv, [256, 3, 2]], 
   [[-1, -3, 39], 1, Concat, [1]],
   [-1, 3, CoT3, [512]],

# 检测头 -----------------------------
   [49, 1, RepConv, [256, 3, 1]],
   [55, 1, RepConv, [512, 3, 1]],
   [61, 1, RepConv, [1024, 3, 1]],

   [[62,63,64], 1, IDetect, [nc, anchors]],   # Detect(P3, P4, P5)
  ]

2.1common.py配置

./models/common.py文件增加以下模块

class CoT3(nn.Module):
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1)  # act=FReLU(c2)
        self.m = nn.Sequential(*[CoTBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
        # self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)])

    def forward(self, x):
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))

class CoTBottleneck(nn.Module):
    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
        super(CoTBottleneck, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = CoT(c_, 3)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))

class CoT(nn.Module):
    # Contextual Transformer Networks https://arxiv.org/abs/2107.12292
    def __init__(self, dim=512,kernel_size=3):
        super().__init__()
        self.dim=dim
        self.kernel_size=kernel_size

        self.key_embed=nn.Sequential(
            nn.Conv2d(dim,dim,kernel_size=kernel_size,padding=kernel_size//2,groups=4,bias=False),
            nn.BatchNorm2d(dim),
            nn.ReLU()
        )
        self.value_embed=nn.Sequential(
            nn.Conv2d(dim,dim,1,bias=False),
            nn.BatchNorm2d(dim)
        )

        factor=4
        self.attention_embed=nn.Sequential(
            nn.Conv2d(2*dim,2*dim//factor,1,bias=False),
            nn.BatchNorm2d(2*dim//factor),
            nn.ReLU(),
            nn.Conv2d(2*dim//factor,kernel_size*kernel_size*dim,1)
        )


    def forward(self, x):
        bs,c,h,w=x.shape
        k1=self.key_embed(x) #bs,c,h,w
        v=self.value_embed(x).view(bs,c,-1) #bs,c,h,w

        y=torch.cat([k1,x],dim=1) #bs,2c,h,w
        att=self.attention_embed(y) #bs,c*k*k,h,w
        att=att.reshape(bs,c,self.kernel_size*self.kernel_size,h,w)
        att=att.mean(2,keepdim=False).view(bs,c,-1) #bs,c,h*w
        k2=F.softmax(att,dim=-1)*v
        k2=k2.view(bs,c,h,w)
        
        return k1+k2

2.3yolo.py配置修改

然后找到./models/yolo.py文件下里的parse_model函数,将加入的模块名CoT3加入进去
在 models/yolo.py文件夹下

定位到parse_model函数中
for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):内部
对应位置 下方只需要增加 CoT3模块

 修改完成!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1051946.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C进阶--数据的存储

⚙ 1. 数据类型介绍 1.1基本内置类型 ⭕ 整形: char(char又叫短整型)unsigned charsigned charshortunsigned short[int]signed short [int]intunsigned intsigned intlongunsigned long [int]signed long [int] ⭕ 浮点数: float(单精度浮…

Linux下C语言操作网卡的几个代码实例?特别实用

前面写了一篇关于网络相关的文章:如何获取当前可用网口。 《简简单单教你如何用C语言列举当前所有网口!》 那么如何使用C语言直接操作网口? 比如读写IP地址、读写MAC地址等。 一、原理 主要通过系统用socket()、ioctl()、实现 int sock…

基于arduino的土壤湿度检测

1.总体设计框图 本浇花系统总体上分为硬件和软件两大组成部分。硬件部分包括Arduino UNO开发板、温湿度传感器、通信模块、浇水执行系统和液晶显示等。软件部分包括Android客户端。系统结构如图1所示 本浇花系统总体上分为硬件和软件两大组成部分。硬件部分包括Arduino UN…

LeetCode算法题---第2天

注:大佬解答来自LetCode官方题解 80.删除有序数组的重复项Ⅱ 1.题目 2.个人解答 var removeDuplicates function (nums) {let res [];for (let index 0; index < nums.length; index) {let num 0;if (res.includes(nums[index])) {for (let i 0; i < res.length; …

Python2020年06月Python二级 -- 编程题解析

题目一 数字转汉字 用户输入一个1~9&#xff08;包含1和9&#xff09;之间的任一数字&#xff0c;程序输出对应的汉字。 如输入2&#xff0c;程序输出“二”。可重复查询。 答案: 方法一 list1[一,二,三,四,五,六,七,八,九] while True:n int(input(请输入1~9之间任意一个数字…

Windows 安装CMake

CMake 简介 CMake是一个开源的、跨平台的自动化构建系統&#xff0c;用來管理软件构建的过程。 其用途主要包括&#xff1a; 1. 跨平台编译&#xff1a;CMake支援Windows&#xff0c;Mac OS&#xff0c;Linux等多种操作系統&#xff0c;且支援多数主流编译器如GCC&#xff0…

如何在 Elasticsearch 中使用 Openai Embedding 进行语义搜索

随着强大的 GPT 模型的出现&#xff0c;文本的语义提取得到了改进。 在本文中&#xff0c;我们将使用嵌入向量在文档中进行搜索&#xff0c;而不是使用关键字进行老式搜索。 什么是嵌入 - embedding&#xff1f; 在深度学习术语中&#xff0c;嵌入是文本或图像等内容的数字表示…

centos 7.9同时安装JDK1.8和openjdk11两个版本

1.使用的原因 在服务器上&#xff0c;有些情况因为有一些系统比较老&#xff0c;所以需要使用JDK8版本&#xff0c;但随着时间的发展&#xff0c;新的软件出来&#xff0c;一般都会使用比较新的JDK版本。所以就出现了我们标题的需求&#xff0c;一个系统内同时安装两个不同的版…

Bartende:Mac菜单栏图标管理软件

Bartender 是一款可以帮助用户更好地管理和组织菜单栏图标的 macOS 软件。它允许用户隐藏和重新排列菜单栏图标&#xff0c;从而减少混乱和杂乱。 以下是 Bartender 的主要特点&#xff1a; 菜单栏图标隐藏&#xff1a;Bartender 允许用户隐藏菜单栏图标&#xff0c;只在需要时…

【Vue】数据监视输入绑定

hello&#xff0c;我是小索奇&#xff0c;精心制作的Vue系列持续发放&#xff0c;涵盖大量的经验和示例&#xff0c;如有需要&#xff0c;可以收藏哈 本章给大家讲解的是数据监视&#xff0c;前面的章节已经更新完毕&#xff0c;后面的章节持续输出&#xff0c;有任何问题都可以…

孤举者难起,众行者易趋,openGauss 5.1.0版本正式发布!

&#x1f4e2;&#x1f4e2;&#x1f4e2;&#x1f4e3;&#x1f4e3;&#x1f4e3; 哈喽&#xff01;大家好&#xff0c;我是【IT邦德】&#xff0c;江湖人称jeames007&#xff0c;10余年DBA及大数据工作经验 一位上进心十足的【大数据领域博主】&#xff01;&#x1f61c;&am…

基于SSM的教师办公管理的设计与实现(有报告)。Javaee项目。

演示视频&#xff1a; 基于SSM的教师办公管理的设计与实现&#xff08;有报告&#xff09;。Javaee项目。 项目介绍&#xff1a; 采用M&#xff08;model&#xff09;V&#xff08;view&#xff09;C&#xff08;controller&#xff09;三层体系结构&#xff0c;通过Spring S…

Selenium Webdriver自动化测试框架

最近正在编写selenium webdriver自动化框架&#xff0c;经过几天的努力&#xff0c;目前基本已经实现了一套即能满足数据驱动、又能满足Web关键字驱动的自动化框架&#xff08;主要基于 antjenkinstestngselenium webdriverjxl实现&#xff09;。通过这次的自动化框架开发&…

[CISCN2019 华北赛区 Day2 Web1]Hack World 布尔注入

正确的值 错误的值 我们首先fuzz一下 发现空格被过滤了 我们首先测试 (1)(1) (1)(2) 确定了是布尔注入了 我们写一下查询语句 (select(ascii(mid(flag,1,1))>1)from(flag))(select(ascii(mid(flag,1,1))102)from(flag)) 确定了f 开头 我们开始写脚本 import string …

喜获殊荣!迅镭激光获评“2023年苏州市质量奖”!

近日&#xff0c;苏州市质量奖评定委员会公示2023年苏州市质量奖评定结果&#xff0c;经过层层严格评审&#xff0c;迅镭激光从众多企业中脱颖而出&#xff0c;成功获评“苏州市质量奖”称号! 苏州市质量奖是苏州市政府设立&#xff0c;授予在经营质量上表现优秀的苏州企业的专…

交换机之间配置手动|静态链路聚合

两台交换机&#xff0c;配置链路聚合&#xff1a; 1、禁止自动协商速率&#xff0c;配置固定速率 int G0/0/1 undo negotiation auto speed 100int G0/0/2 undo negotiation auto speed 100 2、配置eth-trunk int eth-trunk 1 mode manual | lacp-staticint G0/0/1 eth-trun…

notion + nextjs搭建博客

SaaS可以通过博客来获得SEO流量&#xff0c;之前我自己在nextjs上&#xff0c;基于MarkDown Cloudfare来构建博客&#xff0c;很快我就了解到更优雅的方案&#xff1a;notion nextjs搭建博客&#xff0c;之前搭建了过&#xff0c;没有记录&#xff0c;这次刚好又要弄&#xf…

Yolov8-pose关键点检测:模型轻量化创新 | OREPA结合c2f,节省70%的显存!训练速度提高2倍! | CVPR2022

💡💡💡本文解决什么问题:浙大&阿里提出在线卷积重新参数化OREPA,节省70%的显存!训练速度提高2倍! OREPA | GFLOPs从9.6降低至8.2, mAP50从0.921提升至0.931 Yolov8-Pose关键点检测专栏介绍:https://blog.csdn.net/m0_63774211/category_12398833.html ✨✨…

平台登录页面实现(一)

文章目录 一、实现用户名、密码、登录按钮、记住用户表单1、全局css代码定义在asserts/css/global.css 二、用户名、密码、记住用户的双向绑定三、没有用户&#xff0c;点击注册功能实现四、实现输入用户名、密码、点击登录按钮进行登录操作五、实现表单项校验六、提交表单预验…

[python 刷题] 153 Find Minimum in Rotated Sorted Array

[python 刷题] 153 Find Minimum in Rotated Sorted Array 题目&#xff1a; Suppose an array of length n sorted in ascending order is rotated between 1 and n times. For example, the array nums [0,1,2,4,5,6,7] might become: [4,5,6,7,0,1,2] if it was rotated 4…