[mmdetection series] mmdetection loss explained


Contents

1. configs

2. Implementation

3. Invocation

3.1 Registration

3.2 Invocation


The configuration lives under the configs/_base_/models directory, and the concrete implementations live under the mmdet/models/losses directory.

1. configs

Sometimes the losses are written explicitly as parameters in the head config; other times the head simply falls back to its internal defaults.

Here we take YOLOX as the example (its config does not spell out any loss parameters directly, so the defaults defined in the head are used):

https://github.com/open-mmlab/mmdetection/blob/master/configs/yolox/yolox_s_8x8_300e_coco.py#L17
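
For reference, when the losses are written explicitly in a head config they look roughly like the snippet below. This is a sketch based on the YOLOXHead defaults in mmdetection 2.x (written from memory, so check the linked config and head file for the authoritative values):

model = dict(
    bbox_head=dict(
        type='YOLOXHead',
        num_classes=80,
        in_channels=128,
        feat_channels=128,
        # Each loss dict is looked up in the LOSSES registry by its 'type' key.
        loss_cls=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            reduction='sum',
            loss_weight=1.0),
        loss_bbox=dict(
            type='IoULoss',
            mode='square',
            eps=1e-16,
            reduction='sum',
            loss_weight=5.0),
        loss_obj=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            reduction='sum',
            loss_weight=1.0),
        loss_l1=dict(type='L1Loss', reduction='sum', loss_weight=1.0)))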

2. Implementation

The head file defines four losses (loss_cls, loss_bbox, loss_obj and loss_l1).

Take IoULoss as an example; the concrete implementation:

mmdetection/iou_loss.py at master · open-mmlab/mmdetection · GitHub

# Imports this excerpt relies on (as in mmdet/models/losses/iou_loss.py);
# `iou_loss` below is the @weighted_loss-decorated functional form defined
# earlier in the same file.
import warnings

import torch
import torch.nn as nn

from ..builder import LOSSES


@LOSSES.register_module()
class IoULoss(nn.Module):
    """IoULoss.
    Computing the IoU loss between a set of predicted bboxes and target bboxes.
    Args:
        linear (bool): If True, use linear scale of loss else determined
            by mode. Default: False.
        eps (float): Eps to avoid log(0).
        reduction (str): Options are "none", "mean" and "sum".
        loss_weight (float): Weight of loss.
        mode (str): Loss scaling mode, including "linear", "square", and "log".
            Default: 'log'
    """

    def __init__(self,
                 linear=False,
                 eps=1e-6,
                 reduction='mean',
                 loss_weight=1.0,
                 mode='log'):
        super(IoULoss, self).__init__()
        assert mode in ['linear', 'square', 'log']
        if linear:
            mode = 'linear'
            warnings.warn('DeprecationWarning: Setting "linear=True" in '
                          'IOULoss is deprecated, please use "mode=`linear`" '
                          'instead.')
        self.mode = mode
        self.linear = linear
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Forward function.
        Args:
            pred (torch.Tensor): The prediction.
            target (torch.Tensor): The learning target of the prediction.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None. Options are "none", "mean" and "sum".
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if (weight is not None) and (not torch.any(weight > 0)) and (
                reduction != 'none'):
            if pred.dim() == weight.dim() + 1:
                weight = weight.unsqueeze(1)
            return (pred * weight).sum()  # 0
        if weight is not None and weight.dim() > 1:
            # TODO: remove this in the future
            # reduce the weight of shape (n, 4) to (n,) to match the
            # iou_loss of shape (n,)
            assert weight.shape == pred.shape
            weight = weight.mean(-1)
        loss = self.loss_weight * iou_loss(
            pred,
            target,
            weight,
            mode=self.mode,
            eps=self.eps,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss
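
A minimal usage sketch (not from the original post; the box coordinates are made up, and boxes are in [tl_x, tl_y, br_x, br_y] format):

import torch

# Instantiate directly; inside mmdetection it would normally be built from a
# config dict via build_loss (see Section 3).
loss_fn = IoULoss(mode='log', eps=1e-6, reduction='mean', loss_weight=1.0)

pred = torch.tensor([[10., 10., 50., 50.],
                     [20., 20., 60., 60.]])
target = torch.tensor([[12., 12., 48., 48.],
                       [20., 20., 60., 60.]])

loss = loss_fn(pred, target)  # scalar: mean of -log(IoU) over the two boxes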

3. Invocation

3.1 Registration

Registration adds the class to a name-to-class dictionary (the LOSSES registry), keyed by the class name.

@LOSSES.register_module()
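
Conceptually, the registry is just a dictionary from class names to classes, and register_module() is a decorator that inserts the class into it. A stripped-down sketch of the idea (the real implementation is mmcv's Registry class, which does considerably more validation):

class Registry:
    """Stripped-down sketch of mmcv's Registry."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register_module(self):
        def _register(cls):
            # e.g. 'IoULoss' -> <class IoULoss>
            self._module_dict[cls.__name__] = cls
            return cls
        return _register

    def get(self, key):
        return self._module_dict[key]


LOSSES = Registry('loss')

With that in place, decorating IoULoss with @LOSSES.register_module() is what allows a config to refer to it as dict(type='IoULoss', ...).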

3.2 Invocation

The losses are instantiated in the YOLOXHead class:

https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/yolox_head.py#L109

self.loss_cls = build_loss(loss_cls)
self.loss_bbox = build_loss(loss_bbox)
self.loss_obj = build_loss(loss_obj)

self.use_l1 = False  # This flag will be modified by hooks.
self.loss_l1 = build_loss(loss_l1)
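
build_loss is a thin wrapper around mmcv's build_from_cfg: it pops the 'type' key from the config dict, looks that name up in the LOSSES registry, and instantiates the class with the remaining keys as keyword arguments. Roughly (a simplified sketch, not the exact mmcv code):

def build_from_cfg_sketch(cfg, registry):
    """Simplified sketch of what build_loss / build_from_cfg does."""
    args = dict(cfg)
    obj_type = args.pop('type')       # e.g. 'IoULoss'
    obj_cls = registry.get(obj_type)  # class looked up in the LOSSES registry
    return obj_cls(**args)            # IoULoss(mode='square', loss_weight=5.0, ...)

# So build_loss(dict(type='IoULoss', mode='square', loss_weight=5.0))
# is roughly equivalent to IoULoss(mode='square', loss_weight=5.0).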

During training, the forward_train of YOLOX's base class (SingleStageDetector) is called, which in turn calls the head's forward_train:

https://github.com/open-mmlab/mmdetection/blob/31c84958f54287a8be2b99cbf87a6dcf12e57753/mmdet/models/detectors/single_stage.py#L57

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None):
        """
        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): A List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                :class:`mmdet.datasets.pipelines.Collect`.
            gt_bboxes (list[Tensor]): Each item are the truth boxes for each
                image in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): Class indices corresponding to each box
            gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
                boxes can be ignored when computing the loss.
        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        super(SingleStageDetector, self).forward_train(img, img_metas)
        x = self.extract_feat(img)
        losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes,
                                              gt_labels, gt_bboxes_ignore)
        return losses

The head's forward_train is inherited from its base class (BaseDenseHead):

https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/base_dense_head.py#L303

    def forward_train(self,
                      x,
                      img_metas,
                      gt_bboxes,
                      gt_labels=None,
                      gt_bboxes_ignore=None,
                      proposal_cfg=None,
                      **kwargs):
        """
        Args:
            x (list[Tensor]): Features from FPN.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes (Tensor): Ground truth bboxes of the image,
                shape (num_gts, 4).
            gt_labels (Tensor): Ground truth labels of each box,
                shape (num_gts,).
            gt_bboxes_ignore (Tensor): Ground truth bboxes to be
                ignored, shape (num_ignored_gts, 4).
            proposal_cfg (mmcv.Config): Test / postprocessing configuration,
                if None, test_cfg would be used
        Returns:
            tuple:
                losses: (dict[str, Tensor]): A dictionary of loss components.
                proposal_list (list[Tensor]): Proposals of each image.
        """
        outs = self(x)
        if gt_labels is None:
            loss_inputs = outs + (gt_bboxes, img_metas)
        else:
            loss_inputs = outs + (gt_bboxes, gt_labels, img_metas)
        losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
        if proposal_cfg is None:
            return losses
        else:
            proposal_list = self.get_bboxes(
                *outs, img_metas=img_metas, cfg=proposal_cfg)
            return losses, proposal_list

The self.loss called here is not implemented in the base class (it is declared abstract there), so the call resolves to the head's own loss method:

losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)

mmdetection/yolox_head.py at master · open-mmlab/mmdetection · GitHub

    @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'objectnesses'))
    def loss(self,
             cls_scores,
             bbox_preds,
             objectnesses,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute loss of the head.
        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_priors * num_classes.
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_priors * 4.
            objectnesses (list[Tensor], Optional): Score factor for
                all scale level, each is a 4D-tensor, has shape
                (batch_size, 1, H, W).
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.
        """
        num_imgs = len(img_metas)
        featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
        mlvl_priors = self.prior_generator.grid_priors(
            featmap_sizes,
            dtype=cls_scores[0].dtype,
            device=cls_scores[0].device,
            with_stride=True)

        flatten_cls_preds = [
            cls_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                 self.cls_out_channels)
            for cls_pred in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_objectness = [
            objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
            for objectness in objectnesses
        ]

        flatten_cls_preds = torch.cat(flatten_cls_preds, dim=1)
        flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
        flatten_objectness = torch.cat(flatten_objectness, dim=1)
        flatten_priors = torch.cat(mlvl_priors)
        flatten_bboxes = self._bbox_decode(flatten_priors, flatten_bbox_preds)

        (pos_masks, cls_targets, obj_targets, bbox_targets, l1_targets,
         num_fg_imgs) = multi_apply(
             self._get_target_single, flatten_cls_preds.detach(),
             flatten_objectness.detach(),
             flatten_priors.unsqueeze(0).repeat(num_imgs, 1, 1),
             flatten_bboxes.detach(), gt_bboxes, gt_labels)

        # The experimental results show that ‘reduce_mean’ can improve
        # performance on the COCO dataset.
        num_pos = torch.tensor(
            sum(num_fg_imgs),
            dtype=torch.float,
            device=flatten_cls_preds.device)
        num_total_samples = max(reduce_mean(num_pos), 1.0)

        pos_masks = torch.cat(pos_masks, 0)
        cls_targets = torch.cat(cls_targets, 0)
        obj_targets = torch.cat(obj_targets, 0)
        bbox_targets = torch.cat(bbox_targets, 0)
        if self.use_l1:
            l1_targets = torch.cat(l1_targets, 0)

        loss_bbox = self.loss_bbox(
            flatten_bboxes.view(-1, 4)[pos_masks],
            bbox_targets) / num_total_samples
        loss_obj = self.loss_obj(flatten_objectness.view(-1, 1),
                                 obj_targets) / num_total_samples
        loss_cls = self.loss_cls(
            flatten_cls_preds.view(-1, self.num_classes)[pos_masks],
            cls_targets) / num_total_samples

        loss_dict = dict(
            loss_cls=loss_cls, loss_bbox=loss_bbox, loss_obj=loss_obj)

        if self.use_l1:
            loss_l1 = self.loss_l1(
                flatten_bbox_preds.view(-1, 4)[pos_masks],
                l1_targets) / num_total_samples
            loss_dict.update(loss_l1=loss_l1)

        return loss_dict
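
The returned loss_dict is not summed inside the head. The detector's base class later collects every entry whose key contains 'loss' and adds them up in _parse_losses before backpropagation. A simplified sketch of that aggregation (the real method in mmdet/models/detectors/base.py also handles lists of tensors, distributed averaging and logging):

import torch

def parse_losses_sketch(losses):
    """Simplified sketch of BaseDetector._parse_losses."""
    log_vars = {}
    for name, value in losses.items():
        # The real code also accepts a list of tensors per key and averages it.
        log_vars[name] = value.mean() if isinstance(value, torch.Tensor) else value

    # Total loss = sum of every entry whose key contains 'loss',
    # i.e. loss_cls + loss_bbox + loss_obj (+ loss_l1 when use_l1 is on).
    loss = sum(v for k, v in log_vars.items() if 'loss' in k)
    log_vars['loss'] = loss
    return loss, log_vars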

The loss modules called inside this method (self.loss_cls, self.loss_bbox, self.loss_obj, self.loss_l1) are the ones built earlier in __init__:

https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/yolox_head.py#L327

self.loss_cls = build_loss(loss_cls)
self.loss_bbox = build_loss(loss_bbox)
self.loss_obj = build_loss(loss_obj)

self.use_l1 = False  # This flag will be modified by hooks.
self.loss_l1 = build_loss(loss_l1)

The concrete implementation of the bbox loss is the IoULoss class already shown in Section 2: https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/losses/iou_loss.py#L16


 
