YOLOv5、YOLOv8改进:Decoupled Head解耦头

news2024/10/5 12:21:56

目录

 

1.Decoupled Head介绍

 2.Yolov5加入Decoupled_Detect

2.1 DecoupledHead加入common.py中:

2.2 Decoupled_Detect加入yolo.py中:

2.3修改yolov5s_decoupled.yaml 


1.Decoupled Head介绍

Decoupled Head是一种图像分割任务中常用的网络结构,用于提取图像特征并预测每个像素的类别。传统的图像分割网络通常将特征提取和像素预测过程集成在同一个网络中,而Decoupled Head则将这两个过程进行解耦,分别处理。

Decoupled Head的核心思想是通过引入额外的分支网络来进行像素级的预测。这个分支网络通常被称为“头”(head),因此得名Decoupled Head。具体而言,Decoupled Head网络在主干网络的特征图上添加一个或多个额外的分支,用于预测像素的类别。

Decoupled Head的优势在于可以更好地处理不同尺度和精细度的语义信息。通过将像素级的预测与特征提取分开,可以更好地利用底层和高层特征之间的语义信息,从而提高分割的准确性和细节保留能力。

Decoupled Head的优点:

  1. 分离特征提取和像素预测:Decoupled Head将特征提取和像素级预测分离开来,使得网络可以更加灵活地处理不同尺度和语义信息。

  2. 多尺度特征融合:通过在主干网络的不同层级添加分支,Decoupled Head可以融合来自不同尺度的特征信息,从而提高对多尺度目标的分割能力。

  3. 更好的像素级预测:由于Decoupled Head将像素级的预测作为独立的任务进行处理,可以更好地保留细节和边缘信息,提高分割的精确性。

  4. 可扩展性:Decoupled Head结构可以根据需要进行扩展和修改,例如添加更多的分支或调整分支的结构,以适应不同的任务和数据集需求。

 YOLOv6 采用了解耦检测头(Decoupled Head)结构,同时综合考虑到相关算子表征能力和硬件上计算开销这两者的平衡,采用 Hybrid Channels 策略重新设计了一个更高效的解耦头结构,在维持精度的同时降低了延时,缓解了解耦头中 3x3 卷积带来的额外延时开销。

原始 YOLOv5 的检测头是通过分类和回归分支融合共享的方式来实现的,因此加入 Decoupled Head。

为什么要用到解耦头?

因为分类和定位的关注点不同;
分类更关注目标的纹理内容;
定位更关注目标的边缘信息

 2.Yolov5加入Decoupled_Detect

2.1 DecoupledHead加入common.py中:

#======================= 解耦头=============================#
class DecoupledHead(nn.Module):
    def __init__(self, ch=256, nc=80,  anchors=()):
        super().__init__()
        self.nc = nc  # number of classes
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.merge = Conv(ch, 256 , 1, 1)
        self.cls_convs1 = Conv(256 , 256 , 3, 1, 1)
        self.cls_convs2 = Conv(256 , 256 , 3, 1, 1)
        self.reg_convs1 = Conv(256 , 256 , 3, 1, 1)
        self.reg_convs2 = Conv(256 , 256 , 3, 1, 1)
        self.cls_preds = nn.Conv2d(256 , self.nc * self.na, 1) # 一个1x1的卷积,把通道数变成类别数,比如coco 80类(主要对目标框的类别,预测分数)
        self.reg_preds = nn.Conv2d(256 , 4 * self.na, 1)       # 一个1x1的卷积,把通道数变成4通道,因为位置是xywh
        self.obj_preds = nn.Conv2d(256 , 1 * self.na, 1)       # 一个1x1的卷积,把通道数变成1通道,通过一个值即可判断有无目标(置信度)
 
    def forward(self, x):
        x = self.merge(x)
        x1 = self.cls_convs1(x)
        x1 = self.cls_convs2(x1)
        x1 = self.cls_preds(x1)
        x2 = self.reg_convs1(x)
        x2 = self.reg_convs2(x2)
        x21 = self.reg_preds(x2)
        x22 = self.obj_preds(x2)
        out = torch.cat([x21, x22, x1], 1) # 把分类和回归结果按channel维度,即dim=1拼接
        return out
 
class Decoupled_Detect(nn.Module):
    stride = None  # strides computed during build
    onnx_dynamic = False  # ONNX export parameter
    export = False  # export mode
 
    def __init__(self, nc=80, anchors=(), ch=(), inplace=True):  # detection layer
        super().__init__()
 
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [torch.zeros(1)] * self.nl  # init grid
        self.anchor_grid = [torch.zeros(1)] * self.nl  # init anchor grid
        self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2))  # shape(nl,na,2)
        self.m = nn.ModuleList(DecoupledHead(x, nc, anchors) for x in ch)
        self.inplace = inplace  # use in-place ops (e.g. slice assignment)
 
    def forward(self, x):
        z = []  # inference output
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
 
            if not self.training:  # inference
                if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)
 
                y = x[i].sigmoid()
                if self.inplace:
                    y[..., 0:2] = (y[..., 0:2] * 2 + self.grid[i]) * self.stride[i]  # xy
                    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
                    xy, wh, conf = y.split((2, 2, self.nc + 1), 4)  # y.tensor_split((2, 4, 5), 4)  # torch 1.8.0
                    xy = (xy * 2 + self.grid[i]) * self.stride[i]  # xy
                    wh = (wh * 2) ** 2 * self.anchor_grid[i]  # wh
                    y = torch.cat((xy, wh, conf), 4)
                z.append(y.view(bs, -1, self.no))
 
        return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x)
 
    def _make_grid(self, nx=20, ny=20, i=0):
        d = self.anchors[i].device
        t = self.anchors[i].dtype
        shape = 1, self.na, ny, nx, 2  # grid shape
        y, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t)
        if check_version(torch.__version__, '1.10.0'):  # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility
            yv, xv = torch.meshgrid(y, x, indexing='ij')
        else:
            yv, xv = torch.meshgrid(y, x)
        grid = torch.stack((xv, yv), 2).expand(shape) - 0.5  # add grid offset, i.e. y = 2.0 * x - 0.5
        anchor_grid = (self.anchors[i] * self.stride[i]).view((1, self.na, 1, 1, 2)).expand(shape)
        return grid, anchor_grid

2.2 Decoupled_Detect加入yolo.py中:

class BaseModel(nn.Module):

    def _apply(self, fn):
        # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
        self = super()._apply(fn)
        m = self.model[-1]  # Detect()
        if isinstance(m, (Detect, Segment,Decoupled_Detect)):
            m.stride = fn(m.stride)
            m.grid = list(map(fn, m.grid))
            if isinstance(m.anchor_grid, list):
                m.anchor_grid = list(map(fn, m.anchor_grid))
        return self

class DetectionModel(BaseModel):

    def _initialize_dh_biases(self, cf=None):  # initialize biases into Detect(), cf is class frequency
        # https://arxiv.org/abs/1708.02002 section 3.3
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
        m = self.model[-1]  # Detect() module
        for mi, s in zip(m.m, m.stride):  # from
            # reg_bias = mi.reg_preds.bias.view(m.na, -1).detach()
            # reg_bias += math.log(8 / (640 / s) ** 2)
            # mi.reg_preds.bias = torch.nn.Parameter(reg_bias.view(-1), requires_grad=True)
 
            # cls_bias = mi.cls_preds.bias.view(m.na, -1).detach()
            # cls_bias += math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # cls
            # mi.cls_preds.bias = torch.nn.Parameter(cls_bias.view(-1), requires_grad=True)
            b = mi.b3.bias.view(m.na, -1)
            b.data[:, 4] += math.log(8 / (640 / s) ** 2)  # obj (8 objects per 640 image)
            mi.b3.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
            b = mi.c3.bias.data
            b += math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # cls
            mi.c3.bias = torch.nn.Parameter(b, requires_grad=True)
  if isinstance(m, (Detect, Segment,ASFF_Detect)):
            s = 256  # 2x min stride
            m.inplace = self.inplace
            forward = lambda x: self.forward(x)[0] if isinstance(m, Segment) else self.forward(x)
            m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))])  # forward
            check_anchor_order(m)
            m.anchors /= m.stride.view(-1, 1, 1)
            self.stride = m.stride
            self._initialize_biases()  # only run once
        elif isinstance(m, Decoupled_Detect):
            s = 256  # 2x min stride
            m.inplace = self.inplace
            m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])  # forward
            check_anchor_order(m)  # must be in pixel-space (not grid-space)
            m.anchors /= m.stride.view(-1, 1, 1)
            self.stride = m.stride
            self._initialize_dh_biases()  # only run once

def parse_model(d, ch):  # model_dict, input_channels(3)

        elif m in {Detect, Segment,Decoupled_Detect}:
            args.append([ch[x] for x in f])
            if isinstance(args[1], int):  # number of anchors
                args[1] = [list(range(args[1] * 2))] * len(f)
            if m is Segment:
                args[3] = make_divisible(args[3] * gw, 8)

2.3修改yolov5s_decoupled.yaml 

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
 
# Parameters
nc: 1  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32
 
# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]
 
# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 13
 
   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)
 
   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)
 
   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
 
   [[17, 20, 23], 1, Decoupled_Detect, [nc, anchors]],  # Detect(P3, P4, P5),解耦
  ]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1032908.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

中科驭数DPU芯片K2斩获2023年“中国芯”优秀技术创新产品奖

2023年9月20日,中科驭数DPU芯片K2在2023年琴珠澳集成电路产业促进峰会暨第十八届“中国芯”颁奖仪式上荣获“中国芯”优秀技术创新产品奖。 “中国芯”集成电路优秀产品榜单是由国家工信部门指导、中国电子信息产业发展研究院举办的行业权威评选活动。自2006年以来…

oracle截取字符串前几位用substr函数如何操作?

随着社会的发展,it行业越来越受到人们的追捧,oracle软件作为一款数据库开发软件,更是受到it人士的钦懒,它是为数据存储和管理构建出的数据库管理系统,主要应用于商业智能管理、通信业务、工作流程管理等方面&#xff0…

springboot整合MeiliSearch轻量级搜索引擎

一、Meilisearch与Easy Search点击进入官网了解&#xff0c;本文主要从小微型公司业务出发&#xff0c;选择meilisearch来作为项目的全文搜索引擎&#xff0c;还可以当成来mongodb来使用。 二、starter封装 1、项目结构展示 2、引入依赖包 <dependencies><dependenc…

【操作系统笔记十四】科普:POSIX 是什么

注&#xff1a;本文转载自该文章posix是什么都不知道&#xff0c;还好意思说你懂Linux&#xff1f; Linux开发者越来越多&#xff0c;但是仍然有很多人整不明白POSIX是什么。本文就带着大家来了解一下到底什么是POSIX&#xff0c;了解他的历史和重要性。 一、什么是 POSIX&…

windows使用小技巧之windows照片查看器无法显示此图片

碰到过好几次了&#xff0c;以前没有理会&#xff0c;今天特意去查了一下解决方法&#xff0c;不然确实不太方便。 1、打开“颜色管理”-“高级”&#xff1a; 2、将“设备配置文件”选择为“Agfa&#xff1a;Swop Standard” 3、关闭&#xff0c;重新打开图片&#xff0c;好…

火花塞工作原理

1.红旗H9轿车2023款发布 2023年元旦过后&#xff0c;红旗汽车在人民大会堂举办了红旗H9的新车发布会&#xff0c;一汽红旗全新的H9豪华轿车终于出炉了全套的配置参数&#xff0c;红旗H9的车身长度达到5137mm&#xff0c;宽度1904mm&#xff0c;轴距3060mm&#xff0c;总高则控…

分享一个基于Python的电子产品销售系统可视化销量统计java版本相同(源码+调试+开题+lw)

&#x1f495;&#x1f495;作者&#xff1a;计算机源码社 &#x1f495;&#x1f495;个人简介&#xff1a;本人七年开发经验&#xff0c;擅长Java、Python、PHP、.NET、微信小程序、爬虫、大数据等&#xff0c;大家有这一块的问题可以一起交流&#xff01; &#x1f495;&…

windows 部署 mindspore GPU 开发环境

基础环境 windows 环境&#xff1a; Windows 11 版本&#xff1a;22H2操作系统版本&#xff1a;22621.2283 wsl2&#xff1a; 1.2.5.0 Docker Desktop&#xff1a; Docker Desktop 4.23.0 CUDA driver for WSL 版本&#xff1a; 535.104.07 宿主机上的 nvidia 环境如下所示&a…

【李沐深度学习笔记】矩阵计算(4)

课程地址和说明 线性代数实现p4 本系列文章是我学习李沐老师深度学习系列课程的学习笔记&#xff0c;可能会对李沐老师上课没讲到的进行补充。 本节是第四篇&#xff0c;由于CSDN限制&#xff0c;只能被迫拆分 矩阵计算 矩阵的导数运算 向量对向量求导的基本运算规则 已知…

如何实现线程池之间的数据透传 ?

如何实现线程池之间的数据透传 &#xff1f; 引言transmittable-thread-local概览capture如何 capture如何保存捕获的数据 save 和 replayrestore 小结 引言 当我们涉及到数据的全链路透传场景时&#xff0c;通常会将数据存储在线程的本地缓存中&#xff0c;如: 用户认证信息透…

灾备系统中的多线程传输功能

多线程传输是指同时使用多个线程进行文件传输&#xff0c;使多个数据包可以同时传输&#xff0c;从而充分利用网络带宽的最大值&#xff0c;提高传输速度。 正常的IE页面文件下载与上传都只有一个线程&#xff0c;有些软件可以实现多线程文件传输&#xff0c;就好像在传输文件…

JDK21你可以不用,新特性还是要了解的

大家好&#xff0c;我是风筝 今年6月份的时候&#xff0c;写过一篇JDK21引入协程&#xff0c;再也不用为并发而头疼了&#xff0c;那时候只是预览版&#xff0c;终于&#xff0c;前两天&#xff08;2023年9月19日&#xff09;发布了 JDK21 正式版。 老早就在 YouTube 上订阅了…

在电脑上怎么分类管理笔记?支持分类整理的电脑云笔记软件

对于大多数上班族而言&#xff0c;在使用电脑办公时&#xff0c;随手记录工作笔记是一个非常常见的场景。无论是会议纪要、工作总结还是项目计划&#xff0c;记录下每一次思考和灵感是提高工作效率的关键。然而&#xff0c;随着时间的推移&#xff0c;电脑上记录的笔记内容逐渐…

OceanMind海睿思入选弯弓研究院《2023中国营销技术生态图谱8.0》

近日&#xff0c;由国内MarTech领域知名机构 弯弓研究院 主办的第五届营销数字化大会暨营销科技MarTech交易展在广州成功召开。 本次大会发布了《2023中国营销技术生态图谱8.0版》 (以下简称“弯弓图谱8.0”)&#xff0c;中新赛克海睿思 凭借成熟的技术实力成功入选弯弓图谱8.0…

K-最近邻算法

一、说明 KNN算法是一个分类算法&#xff0c;基本数学模型是距离模型。K-最近邻是一种超级简单的监督学习算法。它可以应用于分类和回归问题。虽然它是在 1950 年代引入的&#xff0c;但今天仍在使用。然而如何实现&#xff0c;本文将给出具体描述。 来源&#xff1a;维基百科 …

Docker 安装Redis(集群)

3主3从redis集群配置 1、新建6个docker容器 redis 实例 docker run -d --name redis-node-1 --net host --privilegedtrue -v /data/redis/share/redis-node-1:/data redis:6.0.8 --cluster-enabled yes --appendonly yes --port 6381 docker run -d --name redis-node-2 --ne…

Fiddler抓取Https请求配置

官网&#xff1a;https://www.telerik.com/fiddler 配置抓取https包 1.Tools->Options->Https&#xff0c;勾选下面。 2.Actions ->Trust Root Certificate.安装证书到本地 3.在手机端设置代理&#xff1a;本机ip如&#xff1a;192.168.1.168 端口号:8888。 4.手机…

有一个新工具,能让程序员变成高手,优雅撸它!

不知道从什么时候开始&#xff0c;程序员这个职位变得家喻户晓&#xff0c;对程序员的印象也从以前的高深莫测变成如今的加班代名词。对于程序员加班&#xff0c;不懂有话要说。 作为大厂的一枚螺丝钉&#xff0c;接到任务的第一时间需要缕清底层逻辑&#xff0c;并随时关注部门…

【2603. 收集树中金币】

来源&#xff1a;力扣&#xff08;LeetCode&#xff09; 描述&#xff1a; 给你一个 n 个节点的无向无根树&#xff0c;节点编号从 0 到 n - 1 。给你整数 n 和一个长度为 n - 1 的二维整数数组 edges &#xff0c;其中 edges[i] [ai, bi] 表示树中节点 ai 和 bi 之间有一条…

基于SpringBoot的网上超市系统的设计与实现

目录 前言 一、技术栈 二、系统功能介绍 管理员功能实现 用户功能实现 三、核心代码 1、登录模块 2、文件上传模块 3、代码封装 前言 网络技术和计算机技术发展至今&#xff0c;已经拥有了深厚的理论基础&#xff0c;并在现实中进行了充分运用&#xff0c;尤其是基于计…