【SSD 代码精读】之 model (Backbone) loss

news2024/9/21 14:44:18

model

  • 1、Backbone
    • 1)ResNet-50
    • 2)截取 ResNet-50 的前半部分作为 backbone
  • 2、Module
  • 3、Loss Function
    • 1)location loss
    • 2)confidence loss
    • 3)整体 loss
    • 4)loss 代码


1、Backbone

这里介绍使用 ResNet-50 作为 backbone (原论文使用的 backbone 是 VGG-16)

1)ResNet-50

https://blog.csdn.net/weixin_37804469/article/details/111773914

2)截取 ResNet-50 的前半部分作为 backbone

  • 截取到 layer 3 的 block 1 ,后面的丢弃不用
  • layer 3 的 block 1 要稍做修改,resnet-50 这里原本要做 downsample 的(图片尺寸减小一倍),要修改成不做downsample了。也就是 stride由之前的2 修改为 1, 如图
  • 按照原文,输入图片为 (3,300,300),那么,到 layer 3 的 block 1这里,输出的 feature map 尺寸为 (1024, 38, 38)。这个 feature map 会做为 多个特征层中的第一层。

在这里插入图片描述

class Backbone(nn.Module):
    def __init__(self, pretrain_path=None):
        super(Backbone, self).__init__()
        net = resnet50()
        self.out_channels = [1024, 512, 512, 256, 256, 256]

        if pretrain_path is not None:
            net.load_state_dict(torch.load(pretrain_path))

        self.feature_extractor = nn.Sequential(*list(net.children())[:7])

        conv4_block1 = self.feature_extractor[-1][0]

        # 修改conv4_block1的步距,从2->1
        conv4_block1.conv2.stride = (1, 1)
        conv4_block1.downsample[0].stride = (1, 1)

    def forward(self, x):
        x = self.feature_extractor(x)
        return x

到此,我们的backbone 就构建好了,backbone 就是 resent-50 前半截,下面我们在backbone 的基础上,继续搭建起 model


2、Module

在 backbone 的基础上,重新设计后半截的网络,构成完整的网络,用于特征层的输出,如下图
后面几层的特征层输出分别为 (512,19,19)、(512,10,10)、(256,5,5)、(256,3,3)(256,1,1)
在这里插入图片描述

def _build_additional_features(self, channels):
    additional_blocks = []

    # channels = [1024, 512, 512, 256, 256, 256]
    middle_channels = [256, 256 ,128, 128, 128]

    for i, (input_ch, output_ch, middle_ch) in enumerate(zip(channels[:-1], channels[1:], middle_channels)):
        padding, stride = (1, 2) if i < 3 else (0, 1)
        layer = nn.Sequential(
            nn.Conv2d(input_ch, middle_ch, kernel_size=1, bias=False),
            nn.BatchNorm2d(middle_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(middle_ch, output_ch, kernel_size=3, padding=padding, stride=stride, bias=False),
            nn.BatchNorm2d(output_ch),
            nn.ReLU(inplace=True)
        )
        additional_blocks.append(layer)
        self.additional_blocks = nn.ModuleList(additional_blocks)


将 6 个 feature map 进一步进行位置提取 和 置信度提取 (location extractor & confidence extractor)
在这里插入图片描述

1) location extractor

从 6个 特征层中 提取 对应的 default box 的位置信息,其中 :

  • 5776、2166、600、150、36、4 分别表示每个特征层所对应的 default box 的个数
  • 4 就代表4个坐标 (ctr_x, ctr_y, width, height)的坐标参数的值
# confidence extractor
(1) Conv2d(1024, 4 * 4, kernel_size=3, padding=1))  ===>> (16, 38, 38)   ==view==>> (4, 5776)
(2) Conv2d(512, 6 * 4, kernel_size=3, padding=1))   ===>> (24, 19, 19)   ==view==>> (4, 2166)
(3) Conv2d(512, 6 * 4, kernel_size=3, padding=1))   ===>> (24, 10, 10)   ==view==>> (4, 600)
(4) Conv2d(256, 6 * 4, kernel_size=3, padding=1))   ===>> (24, 5, 5)     ==view==>> (4, 150)
(5) Conv2d(256, 4 * 4, kernel_size=3, padding=1))   ===>> (16, 3, 3)     ==view==>> (4, 36)
(6) Conv2d(256, 4 * 4, kernel_size=3, padding=1))   ===>> (16, 1, 1)     ==view==>> (4, 4)

====>> concatenate 得 (4, 8732)  : 表示所有 8732 个 default box 的位置参数

2)confidence extractor

从 6个 特征层中 提取 对应的 default box 中有object 的置信度,其中 :

  • 5776、2166、600、150、36、4 分别表示每个特征层所对应的 default box 的个数
  • 21 代表21个分类的置信度
# confidence extractor
(0): Conv2d(1024, 84, kernel_size=3, stride=1, padding=1) ===>> (84, 38, 38)   ==view==>> (21, 5776)
(1): Conv2d(512, 126, kernel_size=3, stride= 1, padding=1)===>> (126, 19, 19)   ==view==>> (21, 2166)
(2): Conv2d(512, 126, kernel_size=3, stride=1, padding=1)===>> (126, 10, 10)   ==view==>> (21, 600)
(3): Conv2d(256, 126, kernel_size= 3, stride=1, padding=1)===>> (126, 5, 5)   ==view==>> (21, 150)
(4): Conv2d(256, 84, kernel_size=3, stride=1, padding=1)===>> (84, 3, 3)   ==view==>> (21, 36)
(5): Conv2d(256, 84, kernel_size=3, stride=1, padding=1)===>> (84, 1, 1)   ==view==>> (21, 4)

====>> concatenate 得 (21, 8732):表示这 8732 个 default box 分别为 21个分类的概率

3、Loss Function

在得到了预测的 boxes 的坐标参数 和 置信度(分类概率),我门就要计算 loss 了。 Loss 的计算分为两个部分:

  • 坐标回归参数:坐标回归参数的 loss function 用的是 SmoothL1Loss
  • 分类:分类的 loss function 用的是 CrossEntropy
    \quad

1)location loss

1、将 gt boxes 的 (ctr_x, ctr_y, w, h) 形式的坐标 转换为 其相对于 default boxes 的回归参数

def _location_vec(self, loc):
	# (1) self.scale_xy = 10.0  ,  self.scale_wh = 5.0
	gxy = self.scale_xy * (loc[:, :2, :] - self.dboxes[:, :2, :]) / self.dboxes[:, 2:, :]  # Nx2x8732
	gwh = self.scale_wh * (loc[:, 2:, :] / self.dboxes[:, 2:, :]).log()  # Nx2x8732
	return torch.cat((gxy, gwh), dim=1).contiguous()

vec_gd = self._location_vec(gloc)

2、计算 预测回归参数:ploc 和 上一步转换出的 ground truth boxes 回归参数: vec_gd 的 SmoothL1 Loss
SmoothL1 Loss 的介绍在这里

vec_gd = self._location_vec(gloc)   # vec_gd shape=[N, 4, 8732]
loc_loss = nn.SmoothL1Loss(reduction='none')(ploc, vec_gd)   # loc_loss shape=[N, 4, 8732]

3、累加 4个 位置的参数 ,即 将每个 box 的 ctr_x、ctr_y、w、h 的 loss 进行相加

loc_loss = loc_loss.sum(dim=1)   # loc_loss shape=[N, 8732]

3、只提取出 正样本的 location loss ,即 只提取前景图的 location loss)

loc_loss = (mask.float() * loc_loss).sum(dim=1)   # loc_loss shape=[N]

\quad
\quad

2)confidence loss

选取负样本中 confidence loss 最大的 前k个

  • 负样本 :label=0 的背景图
  • 前k个指:k 由该图像中正样本的数量决定,要求选取图像中正样本数量3倍 的负样本
  • confidence loss 最大的负样本,即 在做 Hard negative mining,挖掘最难分类的负样本
       # hard negative mining Tenosr: [N, 8732]
        con = self.confidence_loss(plabel, glabel)

        # positive mask will never selected
        # 获取负样本
        con_neg = con.clone()
        con_neg[mask] = 0.0
        # 按照confidence_loss降序排列 con_idx(Tensor: [N, 8732])
        _, con_idx = con_neg.sort(dim=1, descending=True)
        _, con_rank = con_idx.sort(dim=1)  # 这个步骤比较巧妙

        # number of negative three times positive
        # 用于损失计算的负样本数是正样本的3倍(在原论文Hard negative mining部分),
        # 但不能超过总样本数8732
        neg_num = torch.clamp(3 * pos_num, max=mask.size(1)).unsqueeze(-1)
        neg_mask = torch.lt(con_rank, neg_num)  # (lt: <) Tensor [N, 8732]

        # confidence最终loss使用选取的正样本loss+选取的负样本loss
        con_loss = (con * (mask.float() + neg_mask.float())).sum(dim=1)  # Tensor [N]

        # avoid no object detected
        # 避免出现图像中没有GTBOX的情况
        total_loss = loc_loss + con_loss
        # eg. [15, 3, 5, 0] -> [1.0, 1.0, 1.0, 0.0]
        num_mask = torch.gt(pos_num, 0).float()  # 统计一个batch中的每张图像中是否存在正样本
        pos_num = pos_num.float().clamp(min=1e-6)  # 防止出现分母为零的情况
        ret = (total_loss * num_mask / pos_num).mean(dim=0)  # 只计算存在正样本的图像损失

3)整体 loss

整体 loss = location loss + confidence loss

total_loss = loc_loss + con_loss

计算 batch 中 N 张图像的 loss 平均值
(只计算存在正样本的图像损失,即:如果 batch 中存在 没有正样本的图像,则该图像不参与计算)

num_mask = torch.gt(pos_num, 0).float()  # 统计一个batch中的每张图像中是否存在正样本
pos_num = pos_num.float().clamp(min=1e-6)  # 防止出现分母为零的情况
ret = (total_loss * num_mask / pos_num).mean(dim=0)  # 只计算存在正样本的图像损失

4)loss 代码

class Loss(nn.Module):
    """
        Implements the loss as the sum of the followings:
        1. Confidence Loss: All labels, with hard negative mining
        2. Localization Loss: Only on positive labels
        Suppose input dboxes has the shape 8732x4
    """
    def __init__(self, dboxes):
        super(Loss, self).__init__()
        # Two factor are from following links
        # http://jany.st/post/2017-11-05-single-shot-detector-ssd-from-scratch-in-tensorflow.html
        self.scale_xy = 1.0 / dboxes.scale_xy  # scale_xy = 1 / 0.1 = 10,
        self.scale_wh = 1.0 / dboxes.scale_wh  # scale_wh = 1 / 0.2 = 5

        self.location_loss = nn.SmoothL1Loss(reduction='none')
        # [num_anchors, 4] -> [4, num_anchors] -> [1, 4, num_anchors]
        self.dboxes = nn.Parameter(dboxes(order="xywh").transpose(0, 1).unsqueeze(dim=0),
                                   requires_grad=False)

        self.confidence_loss = nn.CrossEntropyLoss(reduction='none')

    def _location_vec(self, loc):
        # type: (Tensor) -> Tensor
        """
        Generate Location Vectors
        :param :
            (1) self.scale_xy = 10.0  ,  self.scale_wh = 5.0
            (2) default 匹配到的 gt box,    self.dboxes 就是row default box
        :return: ground truth相对anchors的回归参数
        """
        gxy = self.scale_xy * (loc[:, :2, :] - self.dboxes[:, :2, :]) / self.dboxes[:, 2:, :]  # Nx2x8732
        gwh = self.scale_wh * (loc[:, 2:, :] / self.dboxes[:, 2:, :]).log()  # Nx2x8732
        return torch.cat((gxy, gwh), dim=1).contiguous()

    def forward(self, ploc, plabel, gloc, glabel):
        # type: (Tensor, Tensor, Tensor, Tensor) -> Tensor
        """
            ploc, plabel: Nx4x8732, Nxlabel_numx8732
                predicted location and labels

            gloc, glabel: Nx4x8732, Nx8732
                ground truth location and labels
        """
        # 获取正样本的mask  Tensor: [N, 8732]
        mask = torch.gt(glabel, 0)  # (gt: >)
        # mask1 = torch.nonzero(glabel)
        # 计算一个batch中的每张图片的正样本个数 Tensor: [N]
        pos_num = mask.sum(dim=1)

        # 计算gt的location回归参数 Tensor: [N, 4, 8732]
        vec_gd = self._location_vec(gloc)

        # sum on four coordinates, and mask
        # 计算定位损失(只有正样本)
        loc_loss = self.location_loss(ploc, vec_gd).sum(dim=1)  # Tensor: [N, 8732]
        loc_loss = (mask.float() * loc_loss).sum(dim=1)  # Tenosr: [N]

        # hard negative mining Tenosr: [N, 8732]
        con = self.confidence_loss(plabel, glabel)

        # positive mask will never selected
        # 获取负样本
        con_neg = con.clone()
        con_neg[mask] = 0.0
        # 按照confidence_loss降序排列 con_idx(Tensor: [N, 8732])
        _, con_idx = con_neg.sort(dim=1, descending=True)
        _, con_rank = con_idx.sort(dim=1)  # 这个步骤比较巧妙

        # number of negative three times positive
        # 用于损失计算的负样本数是正样本的3倍(在原论文Hard negative mining部分),
        # 但不能超过总样本数8732
        neg_num = torch.clamp(3 * pos_num, max=mask.size(1)).unsqueeze(-1)
        neg_mask = torch.lt(con_rank, neg_num)  # (lt: <) Tensor [N, 8732]

        # confidence最终loss使用选取的正样本loss+选取的负样本loss
        con_loss = (con * (mask.float() + neg_mask.float())).sum(dim=1)  # Tensor [N]

        # avoid no object detected
        # 避免出现图像中没有GTBOX的情况
        total_loss = loc_loss + con_loss
        # eg. [15, 3, 5, 0] -> [1.0, 1.0, 1.0, 0.0]
        num_mask = torch.gt(pos_num, 0).float()  # 统计一个batch中的每张图像中是否存在正样本
        pos_num = pos_num.float().clamp(min=1e-6)  # 防止出现分母为零的情况
        ret = (total_loss * num_mask / pos_num).mean(dim=0)  # 只计算存在正样本的图像损失
        return ret
        
compute_loss = Loss(default_box)
loss = compute_loss(locs, confs, bboxes_out, labels_out)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/357886.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

思维经验 | ​如何刻意练习提升用户思维?

小飞象交流会哪里有什么捷径&#xff0c;那些个“一步登天”的哪个不是在前面就打好了“地基”的。内部交流│20期思维经验分享如何刻意练习提升用户思维&#xff1f;data analysis●●●●分享人&#xff1a;大江老师‍数据部门和运营部门做了大量的用户标签和用户分层工作。为…

基于GIS的地下水脆弱性评价

&#xff08;一&#xff09;行政边界数据、土地利用数据和土壤类型数据 本文所用到的河北唐山行政边界数据、土地利用数据和土壤类型数据均来源于中国科学院资源环境科学与数据中心&#xff08;https://www.resdc.cn/Default.aspx&#xff09;。 &#xff08;二&#xff09;地…

小孩扁桃体肿大3度能自愈吗?6岁小孩扁桃体肥大怎么治效果好?

12月7日&#xff0c;四川眉山市民唐先生说&#xff0c;他刚出生的儿子在妇产医院分娩中心住了20天后感染了败血症。据唐先生介绍&#xff0c;哈子出院时各项指标正常。他在分娩中心住了半个月左右&#xff0c;孩子喝牛奶很生气&#xff0c;第二天就开始发烧了。同一天&#xff…

新版bing(集成ChatGPT)申请通过后在谷歌浏览器(Chrome)上的使用方法

大家好,我是herosunly。985院校硕士毕业,现担任算法研究员一职,热衷于机器学习算法研究与应用。曾获得阿里云天池比赛第一名,科大讯飞比赛第三名,CCF比赛第四名。拥有多项发明专利。对机器学习和深度学习拥有自己独到的见解。曾经辅导过若干个非计算机专业的学生进入到算法…

Flink-运行时架构介绍

文章目录Flink 运行时架构系统架构整体架构作业管理器&#xff08;JobManager&#xff09;任务管理器&#xff08;TaskManager&#xff09;作业提交流程高层级抽象视角独立模式&#xff08;Standalone&#xff09;YARN 集群数据流图&#xff08;Dataflow Graph&#xff09;并行…

用户画像——如何构建用户画像系统

为什么需要用户画像 如果你是用户,当你使用抖音、今日头条的时候,如果平台推荐给你的内容都是你感兴趣的,能够为你节省大量搜索内容的时间。 如果你是商家,当你投放广告的时候,如果平台推送的用户都是你的潜在买家,能够使你花更少的钱,带来更大的收益。 这两者背后都…

Linux内核驱动之efi-rtc

Linux内核驱动之efi-rtc1. UEFI与BIOS概述1.1. BIOS 概述1.1.1. BIOS缺点&#xff1a;1.1.2. BIOS的启动流程1.2 UEFI 概述1.2.1 Boot Sevices&#xff1a;1.2.2. Runtime Service&#xff1a;1.2.3. UEFI优点&#xff1a;1.2.4. UEFI启动过程&#xff1a;1.3 Legacy和UEFI1.4 …

【第31天】SQL进阶-写优化- 插入优化(SQL 小虚竹)

回城传送–》《31天SQL筑基》 文章目录零、前言一、练习题目二、SQL思路&#xff1a;SQL进阶-写优化-插入优化解法插入优化禁用索引语法如下适用数据库引擎非空表&#xff1a;禁用索引禁用唯一性检查语法如下适用数据库引擎禁用外键检查语法如下适用数据库引擎批量插入数据语法…

W806(一)模拟IIC驱动0.96OLED[移植]

前言平头哥内核的国产开发板&#xff0c;资源丰富&#xff0c;按照官方的描述是能够吊打STM32F103C8T6的&#xff0c;22年刚发布的时候就买了&#xff0c;但是当时忙于考研&#xff0c;而且开发板的SDK不够完善&#xff0c;所以23年来填一下坑&#xff0c;今年我在官方群里找到…

ChatGPT原理与技术演进剖析

—— 要抓住一个风口&#xff0c;你得先了解这个风口的内核究竟是什么。本文作者&#xff1a;黄佳 &#xff08;著有《零基础学机器学习》《数据分析咖哥十话》&#xff09; ChatGPT相关文章已经铺天盖地&#xff0c;剖析&#xff08;现阶段或者只能说揣测&#xff09;其底层原…

为啥程序会有bug?

这是一篇半娱乐性的吐槽文章&#xff0c;权当给广大技术人员解解闷&#xff1a;&#xff09;。哈哈哈&#xff0c;然后我要开始讲一个经常在发生的事实了。&#xff08;程序员们可能会感到一些不适&#xff09;99.999999999%做技术的都会被问到或者被吐槽到&#xff1a;“你的程…

PPT和回放来了 | 中国PostgreSQL数据库生态大会圆满落幕

2月17-19日&#xff0c;中国PostgreSQL数据库生态大会在北京中科院软件所和CSDN平台以线下线上结合方式盛大召开&#xff01;本届大会由中国开源软件推进联盟PostgreSQL分会主办。作为自2017年后我们举办的第六届年度行业大会&#xff0c;延续了传播技术&#xff0c;发展产业生…

代码质量与安全 | 一文了解高级驾驶辅助系统(ADAS)及其开发中需要遵循的标准

高级驾驶辅助系统&#xff08;ADAS&#xff09;有助于提高车内每个人的安全性&#xff0c;帮助他们安全抵达目的地。这项技术功能非常重要&#xff0c;因为大多数的严重车祸都是人为错误造成的。 本篇文章将讨论什么是高级驾驶辅助系统&#xff08;ADAS&#xff09;&#xff0…

Linux基础-新手必备命令

概述系统工作系统状态检测文件与目录管理文件内容查阅与编辑打包压缩与搜索常见命令图解参考资料概述常见执行 Linux 命令的格式是这样的&#xff1a;命令名称 [命令参数] [命令对象]注意&#xff0c;命令名称、命令参数、命令对象之间请用空格键分隔。命令对象一般是指要处理的…

第三章 Kafka生产问题总结及性能优化实践

第三章 Kafka生产问题总结及性能优化实践 1、线上环境规划 JVM参数设置 kafka 是 scala 语言开发&#xff0c;运行在 JVM 上&#xff0c;需要对 JVM 参数合理设置&#xff0c;参看 JVM 调优专题 修改 bin/kafka-start-server.sh 中的 JVM 设置&#xff0c;假设机器是 32G 内…

Java有几种文件拷贝方式?哪一种最高效?

第12讲 | Java有几种文件拷贝方式&#xff1f;哪一种最高效&#xff1f; 我在专栏上一讲提到&#xff0c;NIO 不止是多路复用&#xff0c;NIO 2 也不只是异步 IO&#xff0c;今天我们来看看 Java IO 体系中&#xff0c;其他不可忽略的部分。 今天我要问你的问题是&#xff0c;…

一个巨型的ESP8266模块,围观围观

作者&#xff1a;晓宇&#xff0c;排版&#xff1a;晓宇微信公众号&#xff1a;芯片之家&#xff08;ID&#xff1a;chiphome-dy&#xff09;01 巨型ESP8266ESP8266几乎无人不知&#xff0c;无人不晓了吧&#xff0c;相当一部分朋友接触物联网都是从ESP8266开始的&#xff0c;…

软考中级-嵌入式系统设计师(二)

1、逻辑电路&#xff1a;组合逻辑单路、时序逻辑电路。根据电路是否有存储功能判断。 2、组合逻辑电路 指该电路在任一时刻的输出&#xff0c;仅取决于该时刻的输入信号&#xff0c;而与输入信号作用前电路的状态无关。一般由门电路组成&#xff0c;不含记忆元器件&#xff0…

XD文件转换为sketch的三种方法

XD文件如何转化为Sketch文件&#xff0c;作为竞品的两个产品&#xff0c;如果要互通到可以彼此转换为彼此的文件格式&#xff0c;还是有点难的。所以&#xff0c;今天我总结了 3 个方法&#xff0c;其中最后一个方法是最好用的&#xff01; XD 和 Sketch 算是竞品&#xff0c;想…

论文笔记:TIMESNET: TEMPORAL 2D-VARIATION MODELINGFOR GENERAL TIME SERIES ANALYSIS

ICLR 2023 1 intro 时间序列一般是连续记录的&#xff0c;每个时刻只会记录一些标量 之前的很多工作着眼于时间维度的变化&#xff0c;以捕捉时间依赖关系 ——>可以反映出、提取出时间序列的很多内在特征&#xff0c;比如连续性、趋势、周期性等但是现实时间序列数据中的…