YoloV8损失函数篇(代码加理论)

news2025/1/9 19:10:49

首先yolov8中loss的权重可以在ultralytics/cfg/default.yaml修改在这里插入图片描述
损失函数定义ultralytics/utils/loss.py

回归分支的损失函数

  1. DFL(Distribution Focal Loss),计算anchor point的中心点到左上角和右下角的偏移量
  2. IoU Loss,定位损失,采用CIoU loss,只计算正样本的定位损失
target_bboxes /= stride_tensor
          loss[0], loss[2] = self.bbox_loss(
              pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
          )

分类损失:

  1. BCE loss,只计算正样本的分类损失。
loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum

CIOU loss

  • 调用 loss 方法
"""IoU loss."""
weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
iou函数
def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
box1: 表示一个边界框,形状为(1, 4)的Tensor。
box2: 表示n个边界框,形状为(n, 4)的Tensor。
xywh: 如果为True,表示输入的框格式为(x, y, w, h)(中心点坐标和宽高);如果为False,则输入格式为(x1, y1, x2, y2)(左上角和右下角坐标)。
GIoU, DIoU, CIoU: 控制是否计算相应的IoU扩展版本。
eps: 一个小值,用于避免除零错误。
  • 计算交集面积
  1. 计算得到交集 h和w相乘
inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp_(0) * (b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1) ).clamp_(0)
  1. 计算得到并集
union = w1 * h1 + w2 * h2 - inter + eps
  1. 计算IOU(交集/并集)
iou = inter / union
  1. 其中Yolo中使用的CIOU,补充CIOU内容
    • CIOU是IOU的基础上进行的计算IOU部分相同
  • 第一步还是计算IOU
  • 第二步,计算包围两个边界框的最小矩形s的h和w
cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1)  # convex (smallest enclosing box) width
ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1)  # convex height
  • 第三步,计算s的对角线的平方c2,两个边界框中心点之间距离的平方rho2
c2 = cw.pow(2) + ch.pow(2) + eps  # convex diagonal squared
rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2).pow(2) + (b2_y1 + b2_y2 - b1_y1 - b1_y2).pow(2)) / 4  # center dist**2
rho2推理过程:
	中心点cx:cx = (x1 + x2) / 2
	中心点cy:cy = (y1 + y2) / 2
	两个框中心点分别是 (cx1, cy1)(cx2, cy2)距离为:根号下(cx2-cx1)^2 + (cy2-cy1)^2
	距离的平方:cho2 = (cx2-cx1)^2 + (cy2-cy1)^2
	展开得:((x1 + x2) / 2-(x1_2 + x2_2) / 2)^2...
	整理得:1/4 * ((x1 + x2) -(x1_2 + x2_2))与代码一致
  • 第四步,计算一个与边界框宽高比相关的v,根据v计算权重alpha得到最终CIOU
v = (4 / math.pi**2) * ((w2 / h2).atan() - (w1 / h1).atan()).pow(2)
with torch.no_grad():
  alpha = v / (v - iou + (1 + eps))
  return iou - (rho2 / c2 + v * alpha)  # CIoU

这里就完成了iou loss计算过程

DFL loss

在这里插入图片描述

调用 dfl loss

if self.use_dfl:
   target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.reg_max)
   loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weight
   loss_dfl = loss_dfl.sum() / target_scores_sum
  • 第一步,调用bbox2dist将输入的坐标转换为距离中心的四个方向的距离l、t、r、b
    • anchor_points - x1y1 计算的是锚点到左上角 (x1, y1) 的水平距离和垂直距离,即 l 和 t。
      x2y2 - anchor_points 计算的是右下角 (x2, y2) 到锚点的水平距离和垂直距离,即 r 和 b。
dfl loss函数
def _df_loss(pred_dist, target):
        tl = target.long()  # target left
        tr = tl + 1  # target right
        wl = tr - target  # weight left
        wr = 1 - wl  # weight right
        return (F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl + F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr).mean(-1, keepdim=True)
  • tl和tr代表目标值的左右边界
  • wl是左边界的权重,wr是右边界的权重,如真实坐标值为8.3,那么它距离8近给一个大一点的权重0.7,距离9远,给一个小一点的权重0.3
  • 分别计算预测分布与左右目标值交叉熵损失并求和
  • 注:其中pred_dist[fg_mask]代表了目标的布尔分布,假设 self.reg_max = 7,这意味着我们对每个像素点的预测是在 0 到 7 的范围内的一组离散值,总共有8个可能的值。因此 pred_dist[fg_mask] 应该是形状为 (n, 8) 的张量,其中 n 是通过 fg_mask 选出来的前景位置的数量。view(-1, 8) 就是确保张量的每一行对应一个前景位置,并包含了所有 8 个可能的预测值。
  • 注2:代码中的anchor_points代表了一个格子是前景的中心点,l、t、r、b是anchor_points距离整个物体边界的距离
网络输出的pred_dist如何获得
  • 首先要得到anchorsstrides
    self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
def make_anchors(feats, strides, grid_cell_offset=0.5):
feats:输入特征图列表,通常是从不同的特征层中提取的特征图。
strides:特征图的下采样步长列表,对应每个特征图。
grid_cell_offset:用于调整锚点位置的偏移量,默认为 0.5,即锚点位于网格单元的中心。
for i, stride in enumerate(strides):
        _, _, h, w = feats[i].shape
        sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset  # shift x
        sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset  # shift y
        sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
        anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
        stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
    return torch.cat(anchor_points), torch.cat(stride_tensor)
h 和 w 分别表示特征图的高度和宽度。
sx 和 sy 分别是水平和垂直方向上的网格坐标(加上偏移量后,中心在网格单元中心)。
torch.meshgrid 用于生成二维网格坐标(sx 和 sy),这些坐标将构成锚点。
torch.stack((sx, sy), -1).view(-1, 2) 将 sx 和 sy 坐标组合为 (x, y) 坐标对,并展平为一个二维的锚点列表。
torch.full((h * w, 1), stride, dtype=dtype, device=device) 生成一个步长张量,与锚点数量匹配。
  • 对box进行如下卷积
class DFL(nn.Module):
    def __init__(self, c1=16):
        super().__init__()
        self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
        x = torch.arange(c1, dtype=torch.float)
        self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))
        self.c1 = c1

    def forward(self, x):
        """Applies a transformer layer on input tensor 'x' and returns a tensor."""
        b, _, a = x.shape  # batch, channels, anchors
        return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)
        # return self.conv(x.view(b, self.c1, 4, a).softmax(1)).view(b, 4, a)
conv不进行训练,权重是默认的,self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1)),即[0, 1, 2, ..., c1-1]
forward:
这里将输入张量 x 重新调整形状为 (batch_size, 4, c1, num_anchors)4 表示每个锚点的4个回归值(通常是 l, t, r, b,即到边界的距离)。
transpose(2, 1):交换维度,将张量变为 (batch_size, c1, 4, num_anchors) 的形状。这使得通道维度 c1 排在第二位。
.softmax(1):对 c1 维度(即原来的通道维度)应用 softmax 操作。softmax 将每个类别的预测转换为概率分布,这在 DFL 中用于对每个边界框的预测进行更加细粒度的调整。
self.conv(...):使用 1x1 卷积对经过 softmax 的输出进行处理,实际上是对 softmax 结果的加权平均。由于卷积层的权重被初始化为 0 到 c1-1 的线性值,这一步相当于计算 softmax 结果的期望值,输出的每个通道的值可以被解释为最终预测的偏移量。
view(b, 4, a):将最终的输出张量调整回形状 (batch_size, 4, num_anchors),即每个锚点有4个回归值。
先整理到这里,后续补充…

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2078287.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

开源网络安全大模型 - SecGPT

网络安全大模型是指使用大量数据和参数来训练的人工智能模型,它可以理解和生成与网络安全相关的内容,例如漏洞报告、利用代码、攻击场景等。 目前各家网络安全厂商也纷纷跟进在大模型方面的探索,但可供广大从业者研究的特有网络安全大模型…

2013-2023年 中国MOD17A3H植被净初级生产力(NPP)数据

中国MOD17A3H植被净初级生产力(NPP)数据是基于NASA的MODIS卫星遥感数据计算得出的,这些数据对于评估生态系统碳收支、碳循环以及气候变化的影响具有重要意义。NPP数据可以反映植被通过光合作用固定大气中二氧化碳并转化为有机物质的能力&…

OpenStack组件介绍(2)

cinder 提供块存储服务,管理openstack中的块存储资源,为云平台提供持久的块存储服务,通过驱动的方式可以接入不同种类的后端存储。 cinder对接nfs 关闭防火墙和selinux [rootlocalhost yum.repos.d]# systemctl stop firewalld [rootlocal…

对想学习人工智能或者大模型技术从业者的建议

“ 技术的价值在于应用,理论与实践相结合才能事半功倍” 写这个关于AI技术的公众号也有差不多五个月的时间了,最近一段时间基本上都在保持日更状态,而且写的大部分都是关于大模型技术理论和技术方面的东西。‍‍‍‍‍‍‍‍‍ 然后最近一段…

网络安全售前入门04——审计类产品了解

目录 1.前言 2.数据库审计介绍 2.1产品架构功能 2.2应用场景 2.3部署形式 2.4产品价值 2.5选型依据 1.前言 为方便初接触网络安全售前工作的小伙伴了解网安行业情况,我制作一系统售前入门(安全产品,安全服务,法律法规等)文章介绍,希望能给初进网安职场的小伙伴提供…

STL中的stack与queue

前言: stack与queue是STL中的容器适配器,而不是容器。何为适配器?给手机充电的充电器就是一种适配器,将高电压变成低电压。适配器是用来做转化的,不用来直接管理数据,而是在其他容器的基础上去封装转换。 …

WordNet介绍——一个英语词汇数据库

传统语义知识库最常见的更新方法是依赖人工手动更新,使用这种更新方法的语义知识库包括最早的 WordNet、FrameNet和 ILD,以及包含丰富内容的 ConceptNet和 DBPedia。此类语义知识库的特点是以单词作为语义知识库的基本构成元素,以及使用预先设…

C++ | Leetcode C++题解之第376题摆动序列

题目&#xff1a; 题解&#xff1a; class Solution { public:int wiggleMaxLength(vector<int>& nums) {int n nums.size();if (n < 2) {return n;}int prevdiff nums[1] - nums[0];int ret prevdiff ! 0 ? 2 : 1;for (int i 2; i < n; i) {int diff n…

记一次NULL与空字符串导致的分组后产生重复数据

目录 一&#xff0c;场景说明二&#xff0c;实现功能三&#xff0c;修改原实现方法四&#xff0c;说明 一&#xff0c;场景说明 想实现这样一个功能&#xff0c;统计人员信息中不同性别的人的总工资。 实现方式&#xff1a;将数据group by 分组后累加。 二&#xff0c;实现功…

叉车(工业车辆)安全管理系统,云端监管人车信息运营情况方案

近年来&#xff0c;国家和各地政府相继出台了多项政策法规&#xff0c;从政策层面推行叉车智慧监管&#xff0c;加大叉车安全监管力度。同时鼓励各地结合实际&#xff0c;积极探索智慧叉车建设&#xff0c;实现作业人员资格认证、车辆状态认证、安全操作提醒、行驶轨迹监控等&a…

探秘Facebook的人工智能战略:如何用智能技术重塑社交网络

人工智能&#xff08;AI&#xff09;正以前所未有的速度渗透到各个领域&#xff0c;社交网络也不例外。作为全球最大的社交平台之一&#xff0c;Facebook&#xff08;现Meta&#xff09;正利用人工智能技术重塑其网络环境&#xff0c;提升用户体验。本文将深入探讨Facebook的人…

对SpringBoot项目Jar包进行加密防止反编译

最近项目要求部署到其他公司的服务器上,但是又不想将源码泄露出去,要求对正式环境的启动包进行安全性处理,防止客户直接通过反编译工具将代码反编译出来,本文介绍了如何对SpringBoot项目Jar包进行加密防止反编译,需要的朋友可以参考下 场景: 最近项目要求部署到其他公司的服…

华为HCIP-datacom 真题 (2024年下半年最新题库)

备考HCIP-datacom的小伙伴注意啦 2024年下半年8月份最新题库带解析,有需要的小伙伴移动至文章末 1.BGP 邻居建立过程的状态存在以下几种&#xff1a;那么建立一个成功的连接所经历的状态机顺序是 A、3-1-2-5-4 B、1-3-5-2-4 C、3-5-1-2-4 D、3-1-5-2-4 答案&#xff1a;D 解析…

界面控件DevExpress VCL v24.2路线图预览——增强云集成、简化应用程序皮肤等

DevExpress VCL Controls是Devexpress公司旗下老牌的用户界面套包&#xff0c;所包含的控件有&#xff1a;数据录入、图表、数据分析、导航、布局等。该控件能帮助您创建优异的用户体验&#xff0c;提供高影响力的业务解决方案&#xff0c;并利用您现有的VCL技能为未来构建下一…

el-pagination 下拉条目数最后一个样式改成全部

2024.08.27今天我学习了如何把el-pagination的下拉条目数修改&#xff0c;效果如下&#xff1a; 我们需要把最后一条选择换成全部展示&#xff0c;其实传给后端的还是总的数量&#xff0c;只是换了一个content&#xff0c; 通过f12查看元素可以拿到.el-select-dropdown__item …

华为鸿蒙NEXT大揭秘:微信版功能曝光,简洁界面回归

在科技界&#xff0c;每一次操作系统的更新迭代都是一场期待与猜测的盛宴。华为的鸿蒙系统自问世以来&#xff0c;就以其独特的设计理念和强大的功能吸引了全球的目光。而今&#xff0c;随着微信版鸿蒙NEXT的曝光&#xff0c;我们似乎又将迎来一次科技的飞跃。但这一次&#xf…

【多系统萎缩患者必看!】营养补给站,守护健康每一刻✨

Hey小伙伴们~ 今天我们来聊聊一个需要特别关爱的话题——多系统萎缩&#xff08;MSA&#xff09;患者的营养补充秘籍&#xff01;&#x1f31f; MSA是一种复杂的神经系统疾病&#xff0c;它影响我们的多个身体系统&#xff0c;让每一天的生活都充满了挑战。但别担心&#xff0c…

【工具】轻松解锁SQLite数据库,一窥微信聊天记录小秘密

前言 &#x1f34a;缘由 SQLite里藏秘密&#xff0c;微信聊天有痕迹 &#x1f423;闪亮主角 大家好&#xff0c;我是JavaDog程序狗 今天跟大家分享一个开源小工具PyWxDump&#xff0c;是一个用于获取 wx 账户信息&#xff08;昵称/账户/电话/电子邮件/数据库密钥&#xff0…

利用autoDecoder工具在数据包加密+签名验证站点流畅测试

站点是个靶场 https://github.com/0ctDay/encrypt-decrypt-vuls 演示地址http://39.98.108.20:8085/ 不是仅登录位置暴力破解的那种场景&#xff0c;使用autoDecoder&#xff08;https://github.com/f0ng/autoDecoder&#xff09;的好处就是每个请求自动加解密&#xff0c;测…

35岁失业后,这3个AI副业,也能养活自己

不少粉丝曾问我&#xff0c;有没有用AI赚钱的方法。 *眼看就快到35岁中年危机&#xff0c;想提前安排个退路。* 对于大家的焦虑&#xff0c;我很能理解&#xff0c;花钱容易挣钱难&#xff0c;尤其是在当下&#xff0c;大环境不是那么好&#xff0c;很多人进入佛系状态&#…