【知识蒸馏】YOLO object detection 逻辑蒸馏

news2024/11/22 21:27:21

YOLO检测蒸馏

和分类和分割蒸馏的差异:

由于YOLOv3检测框的位置输出为正无穷到负无穷的连续值,和上面将的分类离散kdloss不同,而且由于yolo是基于anchor的one stage模型,head out中99%都是背景预测。
Object detection at 200 Frames Per Second论文中指出,
直接在Yolo算法中引入distillation loss会有一些问题,因为目前的network distillation算法主要是针对RCNN系列的object detection算法(或者叫two stage系列)。对于two stage的object detection算法而言,其最后送给检测网络的ROI数量是很少的(默认是128个),而且大部分都是包含object的bbox,因此针对这些bbox引入distillation loss不会有太大问题。但是对于Yolo这类one stage算法而言,假设feature map大小是1313,每个grid cell预测5个bbox,那么一共就会生成1313*5=845个bbox,而且大部分都是背景(background)。如果将大量的背景区域传递给student network,就会导致网络不断去回归这些背景区域的坐标以及对这些背景区域做分类,这样训练起来模型很难收敛。因此,作者利用Yolo网络输出的objectness对distillation loss做一定限定,换句话说,只有teacher network的输出objectness较高的bbox才会对student network的最终损失函数产生贡献,这就是objectness scaled distillation。

原来Yolo算法的损失函数,包含3个部分(公式1):1、objectness loss,表示一个bbox是否包含object的损失;2、classification loss,表示一个bbox的分类损失;3、regression loss,表示一个bbox的坐标回归损失。

Yolo损失:回归损失+目标损失+分类损失,核心的算法如下图:
在这里插入图片描述

code

def distillation_output_MSEloss(outs, soft_outs):
    lambda_pi = 10
    loss_distillation = 0
    # pi = []
    # t_pi = []
    t_lcls , t_lbox, t_lobj = 0, 0, 0
    DboxLoss = nn.MSELoss(reduction="none")
    DclsLoss = nn.MSELoss(reduction="none")
    DobjLoss = nn.MSELoss(reduction="none")
    for index in range(len(outs[0])):
        num_grid_h = outs[0][index].size(2)
        num_grid_w = outs[0][index].size(3)
        pi = outs[0][index].view(-1,3,13,num_grid_h,num_grid_w).permute(0, 1, 3, 4, 2).contiguous()
        t_pi = soft_outs[0][index].view(-1,3,13,num_grid_h,num_grid_w).permute(0, 1, 3, 4, 2).contiguous()
        t_obj_scale = t_pi[..., 4].sigmoid()

        # BBox
        b_obj_scale = t_obj_scale.unsqueeze(-1).repeat(1, 1, 1, 1, 4)
        t_lbox += torch.mean(DboxLoss(pi[..., :4], t_pi[..., :4]) * b_obj_scale)

        # Class
        c_obj_scale = t_obj_scale.unsqueeze(-1).repeat(1, 1, 1, 1, 8)
        t_lcls += torch.mean(DclsLoss(pi[..., 5:], t_pi[..., 5:]) * c_obj_scale)

        #objectness
        t_lobj += torch.mean(DobjLoss(pi[..., 4], t_pi[..., 4]) * t_obj_scale)
    loss_distillation = t_lbox + t_lcls + t_lobj
    loss_distillation = lambda_pi * loss_distillation
    return loss_distillation

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1937327.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【论文阅读笔记】Hierarchical Neural Coding for Controllable CAD Model Generation

摘要 作者提出了一种CAD的创新生成模型,该模型将CAD模型的高级设计概念表示为从全局部件排列到局部曲线几何的三层神经代码的层级树,并且通过指定目标设计的代码树来控制CAD模型的生成或完成。具体而言,一种带有“掩码跳过连接”的向量量化变…

【BUG】已解决:To update, run: python.exe -m pip install --upgrade pip

To update, run: python.exe -m pip install --upgrade pip 目录 To update, run: python.exe -m pip install --upgrade pip 【常见模块错误】 解决办法: 欢迎来到英杰社区https://bbs.csdn.net/topics/617804998 欢迎来到我的主页,我是博主英杰&…

「MQTT over QUIC」与「MQTT over TCP」与 「TCP 」通信测试报告

一、结论 在实车5G测试中「MQTT Over QUIC」整体表现优于「TCP」,可在系统架构升级时采用MQTT Over QUIC替换原有的TCP通讯;从实现原理上基于QUIC比基于TCP在弱网、网络抖动导致频繁重连场景延迟更低。 二、测试方案 网络类型:实车5G、实车…

FPGA-计数器

前言 之前一直说整理点FPGA控制器应用的内容,今天就从计数器这个在时序逻辑中比较重要的内容开始总结一下,主要通过还是通过让一个LED闪烁这个简单例子来理解。 寄存器 了解计数器之前先来认识一下寄存器。寄存器是时序逻辑设计的基础。时序逻辑能够避…

Android C++系列:Linux信号(三)

可重入函数 不含全局变量和静态变量是可重入函数的一个要素可重入函数见man 7 signal在信号捕捉函数里应使用可重入函数在信号捕捉函数里禁止调用不可重入函数例如:strtok就是一个不可重入函数,因为strtok内部维护了一个内部静态指针,保存上一 次切割到的位置,如果信号的捕捉…

android Invalid keystore format

签名的时候提示:Invalid keystore format. 点击info查看更多日志 再点击一次 stactrace 查看更多提示 提示:javaio异常 基本是jdk版本的问题,高jdk版本打的key,在低版本jdk开发环境上无法使用。 查看自己的key信息 keytool -list -v -keys…

Redis实现用户会话

1.分布式会话 (1)什么是会话 会话Session代表的是客户端与服务器的一次交互过程,这个过程可以是连续也可以是时断时续的。曾经的Servlet时代(jsp),一旦用户与服务端交互,服务器tomcat就会为用户创建一个session&#…

【C++】深入理解函数重载:C语言与C++的对比

文章目录 前言1. 函数重载:概念与条件1.1 什么是函数重载1.2 函数重载的条件1.3 函数重载的注意点 2. 函数重载的价值2.1 书写函数名方便2.2 类中构造函数的实现2.3 模板的底层实现 3. C语言与C的对比3.1 C语言不支持函数重载的原因3.2 C支持函数重载的原因 4. Linu…

PostgreSQL的引号、数据类型转换和数据类型

一、单引号和双引号(重要): 1、在mysql没啥区别 2、在pgsql中,实际字符串用单引号,双引号相当于mysql的,用来包含关键字; -- 单引号,表示user_name的字符串实际值 insert into t_user(user_nam…

AP ERP与汉得SRM系统集成案例(制药行业)

一、项目环境 江西某医药集团公司,是一家以医药产业为主营、资本经营为平台的大型民营企业集团。公司成立迄今,企业经营一直呈现稳健、快速发展的态势, 2008 年排名中国医药百强企业前 20 强,2009年集团总销售额约38亿元人民币…

vscode搭建PyQt + Quick开发环境

VScode搭建PyQt Quick开发环境 目录 环境准备 🔔安装必要的Python包 🔔🔎 PyQt5和PySide2的区别💾 安装PyQt5💾 安装PySide2 配置VScode 🔔💻 安装扩展 代码示例 🔔✔ Python调用Qt…

【JavaScript 算法】滑动窗口:处理子数组问题

🔥 个人主页:空白诗 文章目录 一、算法原理二、算法实现示例问题1:最长无重复字符子串示例问题2:长度最小的子数组注释说明: 三、应用场景四、总结 滑动窗口(Sliding Window)是一种高效解决数组…

Java多线程-----线程安全问题(详解)

目录 🍇一.线程安全问题的引入: 🍒二.线程安全问题产生的原因: 🍌三.如何解决线程安全问题: 🎉1.synchronized关键字: 🦉sychronized关键字的特性: ✨2.volatile关键字: &#…

03 Git的基本使用

第3章:Git的基本使用 一、创建版本仓库 一)TortoiseGit ​ 选择项目地址,右键,创建版本库 ​ 初始化git init版本库 ​ 查看是否生成.git文件(隐藏文件) 二)Git ​ 选择项目地址&#xff0c…

数据隔离级别查询一致导致重复退款

Transactionalpublic void updateAfsState() {String no "500001880002";OrderReturn orderReturnDb orderReturnModel.getOrderReturnByAfsSn(no);log.info("1.该售后单状态:{}" , orderReturnDb.getState());if(orderReturnDb.getState().e…

【人工智能】机器学习 -- 贝叶斯分类器

目录 一、使用Python开发工具,运行对iris数据进行分类的例子程序NaiveBayes.py,熟悉sklearn机器实习开源库。 1. NaiveBayes.py 2. 运行结果 二、登录https://archive-beta.ics.uci.edu/ 三、使用sklearn机器学习开源库,使用贝叶斯分类器…

vue使用了代理跨域,部署上线,使用Nginx配置出现问题,访问不到后端接口

1、如果路由的mode是history模式的要加上框框里的哪句,然后配置下面的location router location / {root /usr/local/app/dist/; #vue文件dist的完整路径try_files $uri $uri/ router;index index.html index.htm;}#error_page 500 502 503 504 /50x.html;lo…

缓存弊处的体验:异常

缓存(cache),它是什么东西,有神马用,在学习内存的时候理解它作为一个存储器,来对接cpu和内存,来调节cpu与内存的速度不匹配的问题。 缓存,一个偶尔可以听到的专业名词,全…

深入理解FFmpeg--软/硬件解码流程

FFmpeg是一款强大的多媒体处理工具,支持软件和硬件解码。软件解码利用CPU执行解码过程,适用于各种平台,但可能对性能要求较高。硬件解码则利用GPU或其他专用硬件加速解码,能显著降低CPU负载,提升解码效率和能效。FFmpe…

Leetcode双指针法应用

1.双指针法 文章目录 1.双指针法1.1什么是双指针法?1.2解题思路1.3扩展 1.1什么是双指针法? 双指针算法是一种在数组或序列上操作的技巧,实际上是对暴力枚举算法的一种优化,通常涉及到两个索引(或指针)从两…