yolov8obb角度预测原理解析

news2024/12/23 10:50:51

预测头

ultralytics/nn/modules/head.py

class OBB(Detect):
    """YOLOv8 OBB detection head for detection with rotation models."""

    def __init__(self, nc=80, ne=1, ch=()):
        """Initialize OBB with number of classes `nc` and layer channels `ch`."""
        super().__init__(nc, ch)
        self.ne = ne  # number of extra parameters

        c4 = max(ch[0] // 4, self.ne)
        self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.ne, 1)) for x in ch)

    def forward(self, x):
        """Concatenates and returns predicted bounding boxes and class probabilities."""
        bs = x[0].shape[0]  # batch size
        angle = torch.cat([self.cv4[i](x[i]).view(bs, self.ne, -1) for i in range(self.nl)], 2)  # OBB theta logits
        # NOTE: set `angle` as an attribute so that `decode_bboxes` could use it.
        angle = (angle.sigmoid() - 0.25) * math.pi  # [-pi/4, 3pi/4]
        # angle = angle.sigmoid() * math.pi / 2  # [0, pi/2]
        if not self.training:
            self.angle = angle
        x = Detect.forward(self, x)
        if self.training:
            return x, angle
        # return torch.cat([x, angle], 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle))

        return torch.cat([x, angle], 1).permute(0, 2, 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle))

forward 输入值
在这里插入图片描述

self.cv4网路结构

ModuleList(
  (0): Sequential(
    (0): Conv(
      (conv): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): SiLU(inplace=True)
    )
    (1): Conv(
      (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): SiLU(inplace=True)
    )
    (2): Conv2d(16, 1, kernel_size=(1, 1), stride=(1, 1))
  )
  (1): Sequential(
    (0): Conv(
      (conv): Conv2d(128, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): SiLU(inplace=True)
    )
    (1): Conv(
      (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): SiLU(inplace=True)
    )
    (2): Conv2d(16, 1, kernel_size=(1, 1), stride=(1, 1))
  )
  (2): Sequential(
    (0): Conv(
      (conv): Conv2d(256, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): SiLU(inplace=True)
    )
    (1): Conv(
      (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): SiLU(inplace=True)
    )
    (2): Conv2d(16, 1, kernel_size=(1, 1), stride=(1, 1))
  )

angle维度14,1,8400

损失函数

pred_angle = pred_angle.permute(0, 2, 1).contiguous()
维度变为14 8400 1

将预测结果转为bboxes

pred_bboxes = self.bbox_decode(anchor_points, pred_distri, pred_angle)  # xyxy, (b, h*w, 4)

计算回归损失

loss[0], loss[2] = self.bbox_loss(
                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
            )

这里的bbox_loss指的是:

self.bbox_loss = RotatedBboxLoss(self.reg_max - 1, use_dfl=self.use_dfl).to(self.device)

接来下看RotatedBboxLoss

    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
        """IoU loss."""
        weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
        iou = probiou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
        loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum

        # DFL loss
        if self.use_dfl:
            target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.reg_max)
            loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weight
            loss_dfl = loss_dfl.sum() / target_scores_sum
        else:
            loss_dfl = torch.tensor(0.0).to(pred_dist.device)

        return loss_iou, loss_dfl

两个旋转矩形如何计算IOU:

def probiou(obb1, obb2, CIoU=False, eps=1e-7):
    """
    Calculate the prob IoU between oriented bounding boxes, https://arxiv.org/pdf/2106.06072v1.pdf.

    Args:
        obb1 (torch.Tensor): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.
        obb2 (torch.Tensor): A tensor of shape (N, 5) representing predicted obbs, with xywhr format.
        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.

    Returns:
        (torch.Tensor): A tensor of shape (N, ) representing obb similarities.
    """
    x1, y1 = obb1[..., :2].split(1, dim=-1)
    x2, y2 = obb2[..., :2].split(1, dim=-1)
    a1, b1, c1 = _get_covariance_matrix(obb1)
    a2, b2, c2 = _get_covariance_matrix(obb2)

    t1 = (
        ((a1 + a2) * (y1 - y2).pow(2) + (b1 + b2) * (x1 - x2).pow(2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)
    ) * 0.25
    t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)) * 0.5
    t3 = (
        ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2))
        / (4 * ((a1 * b1 - c1.pow(2)).clamp_(0) * (a2 * b2 - c2.pow(2)).clamp_(0)).sqrt() + eps)
        + eps
    ).log() * 0.5
    bd = (t1 + t2 + t3).clamp(eps, 100.0)
    hd = (1.0 - (-bd).exp() + eps).sqrt()
    iou = 1 - hd
    if CIoU:  # only include the wh aspect ratio part
        w1, h1 = obb1[..., 2:4].split(1, dim=-1)
        w2, h2 = obb2[..., 2:4].split(1, dim=-1)
        v = (4 / math.pi**2) * ((w2 / h2).atan() - (w1 / h1).atan()).pow(2)
        with torch.no_grad():
            alpha = v / (v - iou + (1 + eps))
        return iou - v * alpha  # CIoU
    return iou

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1877314.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

(笔记)Error: qemu-virgl: Failed to download resource “qemu-virgl--test-image“解决方法

错误: > Downloading https://www.ibiblio.org/pub/micro/pc-stuff/freedos/files/distributions/1.2/FD12FLOPPY.zip curl: (22) The requested URL returned error: 404Error: qemu-virgl: Failed to download resource "qemu-virgl--test-image" D…

基于QT开发的气体成分检测数据记录软件

1、软件概述 气体成分检测数据记录软件用于实现多种气体分析仪及相关设备实时数据的获取、存储和传送。目前支持的设备主要有气体分析仪、多通道进样阀箱、冷阱处理系统和气体采样处理系统。   气体成分检测数据记录软件可以根据实际应用需要进行配置,以实现不同应…

【PyQt】20-QTimer(动态显示时间、定时关闭)

QTimer 前言一、QTimer介绍二、动态时间展示2.1 代码2.2 运行结果 三、定时关闭3.1 介绍他的两种用法1、使用函数或Lambda表达式2、带有定时器类型(高级) 3.2 代码3.3 运行结果 总结 前言 好久没学习了。 一、QTimer介绍 pyqt里面的多线程可以有两种实…

使用 MongoDB 剖析开放银行:技术挑战和解决方案

开放银行(或开放金融)在银行业掀起了一股颠覆性浪潮,它迫使金融机构(银行、保险公司、金融科技公司、企业甚至政府机构)迎接一个透明、协作和创新的新时代。这种模式转变要求银行与第三方提供商(TPP&#x…

RAID0、RAID1、RAID5、RAID10、软RAID

硬盘 连续空间 无法 扩容 每个raid对应每个raid卡,没有阵列卡就不能用raid lvm 非连续空间 可以动态扩容 raid 备份, 提高读写性能,不能扩容 raid 是磁盘的集合,按照排列组合的方法不 一,给 raid 去了不同的名字…

Webpack: 构建微前端应用

Module Federation 通常译作“模块联邦”,是 Webpack 5 新引入的一种远程模块动态加载、运行技术。MF 允许我们将原本单个巨大应用按我们理想的方式拆分成多个体积更小、职责更内聚的小应用形式,理想情况下各个应用能够实现独立部署、独立开发(不同应用甚…

[OtterCTF 2018]Closure

既然你从内存中提取了密码,你能解密rick的文件吗? 密码是知道了,加密文件 ? flag 文件?dump 出来 已知这个勒索软件为HiddenTear,直接在网上找到解密程序HiddenTearDecrypter先将加密文件的末尾多余的0去掉…

javaScript利用indexOf()查找字符串的某个字符出现的位置

1 创建字符串 2 利用indexof()查询字符串的字符 3 利用while循环判断indexOf是否等于-1,不等于-1就打印一次并且索引号1去查下一个字符 //创建字符串var str1234567812311231;var indexstr.indexOf(1);//查询该字符while(index !-1)//indexOf()没有查到会返回-1{…

右键新建没有TXT文本文档的解决办法

电脑右键新建,发现没有txt了,我查网上办法都有点复杂,诸如注册表的,但是其实很简单,重启windows资源管理器就可以了。 点击重新启动,之后新建就有txt文档了。

基于Spring Boot的药房信息管理系统

1 项目介绍 1.1 研究的背景及意义 随着社会的飞速进步和药房行业竞争的白热化,传统的手工管理模式已难以适应药房信息管理的现代化需求。在计算机科学技术日臻完善的背景下,药房信息管理者们日益认识到运用计算机技术进行信息管理的迫切性和重要性。计…

昇思MindSpore学习总结五——网络构建

1、网络构建 神经网络模型是由神经网络层和Tensor操作构成的,mindspore.nn提供了常见神经网络层的实现,在MindSpore中,Cell类是构建所有网络的基类,也是网络的基本单元。一个神经网络模型表示为一个Cell,它由不同的子C…

创新探析:我国AIGC产业规模有望在2030年破万亿,创意设计行业或迎来全面革新

在科技日新月异的今天,人工智能生成内容(AIGC)与创意设计行业的结合正以前所未有的速度推动着产业变革。随着技术的不断突破和市场需求的日益增长,我国AIGC产业规模有望在2030年突破万亿元大关,这一宏伟目标不仅是对技…

VUE-CLI脚手架项目的初步创建与配置

首先创建一个VUE项目,注意选择版本为 2.6.10 打开APP.vue文件,并且删除APP.vue中多余的代码 创建index.vue文件 在此文件中写入如下图片中的代码来初步创建页面 创建router目录,并且创建index.js 文件如下 在终端输入npm run serve 运行 然后…

代码随想录-Day43

52. 携带研究材料(第七期模拟笔试) 小明是一位科学家,他需要参加一场重要的国际科学大会,以展示自己的最新研究成果。他需要带一些研究材料,但是他的行李箱空间有限。这些研究材料包括实验设备、文献资料和实验样本等…

[OtterCTF 2018]Recovery

里克必须找回他的文件!用于加密文件的随机密码是什么 恢复他的文件 ,感染的文件 ? vmware-tray.ex 前面导出的3720.dmp 查找一下 搜索主机 strings -e l 3720.dmp | grep “WIN-LO6FAF3DTFE” 主机名 后面跟着一串 代码 aDOBofVYUNVnmp7 是不…

c++类和对象(三)日期类

类和对象 一.拷贝构造函数定义二.拷贝构造函数特征三.const成员函数权限权限的缩小权限的缩放大 四.隐式类型转换 一.拷贝构造函数定义 拷贝构造函数:只有单个形参,该形参是对本类类型对象的引用(一般常用const修饰),在用已存 在的类类型对象…

韩顺平0基础学Java——第33天

p653-674 坦克大战 继续上回游戏 将每个敌人的信息,恢复成Node对象,放进Vector里面。 播放音乐 使用一个播放音乐的类。 第二阶段结束了 网络编程 相关概念 (权当是复习计网了) 网络 1.概念:两台或多台设备通过一定物理设备连…

Java基础知识-集合类

1、HashMap 和 Hashtable 的区别? HashMap 和 Hashtable是Map接口的实现类,它们大体有一下几个区别: 1. 继承的父类不同。HashMap是继承自AbstractMap类,而HashTable是继承自Dictionary类。 2. 线程安全性不同。Hashtable 中的方…

【Nginx】源码安装

nginx官网:nginx: download 选择文档版本安装即可 1.安装依赖包 //一键安装上面四个依赖 yum -y install gcc zlib zlib-devel pcre-devel openssl openssl-devel 2.下载并解压安装包 //创建一个文件夹 cd /usr/local mkdir nginx cd nginx //将下载的nginx压缩…

蓝卓出席“2024C?O大会”,探讨智能工厂建设新路径

6月29日,“2024C?O大会”在金华成功举办。此次大会由浙江省企业信息化促进会主办,与以往CIO峰会不同,“C?O”代表了企业数字化中的核心决策者群体,包括传统的CIO、CEO、CDO等。 本次大会围绕C?O、AIGC与制造业、数据价值、未来…