基于yolov5模型的目标检测蒸馏(LD+KD)

news2024/11/24 17:39:44

文章目录

  • 前言
  • 一、Distillation理解
    • 1、Knowlege distillation
    • 2、Feature distillation
    • 3、Location distillation
    • 4、其它蒸馏
  • 二、yolov5蒸馏模型构建
    • 1、构建teacher预测模型
    • 2、构建蒸馏loss
    • 3、蒸馏模型代码图示
      • 模型初始化
      • 模型蒸馏
  • 三、蒸馏模型实验
    • 1、工程数据测试
    • 2、voc2012开源数据测试


前言

最近在看有关蒸馏(Distillation)相关的内容,也就是需要大量的计算资源及庞大的数据集去支撑大模型,以蒸馏方式转为小模型,加速推理时间与降低模型内存,有利于部署。为此,我基于yolov5模型框架,修改代码,构建一个LD+KD的蒸馏模型,并公开源码于github,供读者学习。同时,我也正在构建多头蒸馏,后期将公开源码与文章解读。

源码链接:点击这里


一、Distillation理解

蒸馏是模型压缩方法,是通过教师模型知识传授学生模型的方法。一般教师模型是较大模型,效果较好,学生模型是较小模型,直接训练效果较差,
使用蒸馏模型传授教师知识,帮助提高学生模型性能。

1、Knowlege distillation

知识蒸馏(Knowledge Distillation,简记为 KD)是一种经典的模型压缩方法,分类模型论文较多,实际是蒸馏类型信息,通过teacher模型给出软标签给学生更多信息。如下图示意:
在这里插入图片描述

2、Feature distillation

特征蒸馏也是一种经典的模型压缩方法,实际是特征图的知识传递,通过teacher模型给出特征图给学生更多特征提取约束或信息量。如下图示意:

在这里插入图片描述

3、Location distillation

位置蒸馏也是一种经典的模型压缩方法,实际是位置点(如box)的知识传递,通过teacher模型给出预测位置给学生位置信息。该方法学术不多,比较新,如下图示意:
在这里插入图片描述

4、其它蒸馏

也有很多其它蒸馏方式,如通道蒸馏、无监督、对比等蒸馏方式,或最近bert蒸馏等。当然,介于我后期会出多头蒸馏文章,我引入论文图,如下:

在这里插入图片描述

二、yolov5蒸馏模型构建

我是基于yolov5模型蒸馏的,教师模型使用大尺寸模型m,学生模型使用小尺寸模型s。同时,我修改源码构建蒸馏模型结构,接下来我介绍如何构建基于yolov5模型构建蒸馏模型。其结构如下:
在这里插入图片描述

1、构建teacher预测模型

yolov5只需使用训练后的best.pt文件,通过attempt_load即可加载完预测模型初始化,至于attempt_load函数解析,相信很多博客已有说明,我不在解释,其teacher模型构建如下:

def create_teacher_model(weights,device):
    # device = torch.device('cuda:0')
    model=attempt_load(weights, map_location=device).eval()
    stride = int(model.stride.max())  # model stride
    names = model.module.names if hasattr(model, 'module') else model.names  # get class names
    teacher_model={'model':model,
                   'stride':stride,
                   'names':names
                   }
    return teacher_model

2、构建蒸馏loss

我基于yolov5模型构建LD+KD的蒸馏方式,借用yolov5原始模型loss计算方法,teacher模型输出为类的一个序列作为target类别,而原始yolov5模型
gt的target为类别数字非序列。为此,我们修改类别表示方式,使用序列替换数字,该位置在build_targets函数中,我做了大量修改,也将对应解释写在对应代码
中,其详情如下代码:

    def build_targets(self, p, targets):
        # Build targets for compute_loss(), input targets(image_id,class,x,y,w,h)
        na, nt = self.na, targets.shape[0]  #  每个点anchor数量(3), targets(每个batch中的标签个数)
        tcls, tbox, indices, anch ,tconf = [], [], [], [], [] # tcls表示类别,tbox表示box的坐标(x,y,w,h),indices表示图像索引,anch表示选取的anchor的索引
        gain = torch.ones(targets.shape[-1]+1, device=targets.device)  # normalized to gridspace gain
        ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt)  # [na,nt] same as .repeat_interleave(nt)
        targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2)  # append anchor indices
        # targets[image_id,x,y,w,h,conf,...cls,anchor_id]
        g = 0.5  # bias
        off = torch.tensor([[0, 0],
                            [1, 0], [0, 1], [-1, 0], [0, -1],  # j,k,l,m
                            # [1, 1], [1, -1], [-1, 1], [-1, -1],  # jk,jm,lk,lm
                            ], device=targets.device).float() * g  # offsets

        for i in range(self.nl):  # 循环3个特征层
            anchors, shape = self.anchors[i], p[i].shape
            gain[1:5] = torch.tensor(shape)[[3, 2, 3, 2]]  # xyxy gain

            # Match targets to anchors
            t = targets * gain  # shape(3,n,7),在特征图中恢复gt尺寸,[img_id,x,y,w,h,conf,...cls,anchor_id]
            if nt:
                # Matches,选择正负样本方法,通过gt与anchor的wh比列筛选
                r = t[..., 3:5] / anchors[:, None]  # wh ratio
                j = torch.max(r, 1 / r).max(2)[0] < self.hyp['anchor_t']  # compare
                # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t']  # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
                t = t[j]  # filter,通过筛除后获得正样本

                # Offsets 获取选择完成的box的*中心点*坐标-gxy(以图像左上角为坐标原点),并转换为以特征图右下角为坐标原点的坐标-gxi
                gxy = t[:, 1:3]  # grid xy
                gxi = gain[[1, 2]] - gxy  # inverse 特征图右下角为坐标原点
                # 分别判断box的(x,y)坐标是否大于1,并距离网格左上角的距离(准确的说是y距离网格上边或x距离网格左边的距离)小于0.5,
                # 如果(x,y)中满足上述两个条件,则选中.gxy.shape=[182,2],包含x,y,所以判别后转置得到j,k,2个结果
                # 对转换之后的box的(x,y)坐标分别进行判断是否大于1,并距离网格右下角的距离(准确的说是y距离网格下边或x距离网格右边的距离)距离小于0.5,
                # 如果(x,y)中满足上述两个条件,为Ture,
                j, k = ((gxy % 1 < g) & (gxy > 1)).T    # gxy>1,以左上角为坐标原点,表示排除上边与左边边缘格子
                l, m = ((gxi % 1 < g) & (gxi > 1)).T    # gxi>1同理,以右下角为坐标原点,排除右边与下边边缘格子
                j = torch.stack((torch.ones_like(j), j, k, l, m))  # 第一行为自己本身正样本值
                t = t.repeat((5, 1, 1))[j]  # 根据j挑选正样本,但未移动相邻网格
                offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]  # 根据j处理对应正样本偏置(确定移动相邻网格)
            else:
                t = targets[0]
                offsets = 0

            # Define  b=img_id,c=[...cls],conf=conf-->预测置信度 gxy=grid xy, gwh=grid wh, a=anchors_id
            b=t[:,0].long()
            c=t[:,6:-1]
            conf=t[:,5]
            gxy= t[:,1:3]
            gwh=t[:,3:5]
            a=t[:,-1].long()

            gij = (gxy - offsets).long()  # xy与offsets对应
            gi, gj = gij.T  # grid indices

            # Append
            indices.append((b, a, gj.clamp_(0, shape[2] - 1), gi.clamp_(0, shape[3] - 1)))  # image_id, anchor_id,与网格坐标grid_x,grid_y
            tbox.append(torch.cat((gxy - gij, gwh), 1))  # box 获取(x,y)相对于网格点的偏置,以及box的宽高
            anch.append(anchors[a])  # anchors  获得对应的anchor
            tcls.append(c)  # class 获得对应类别
            tconf.append(conf)

        return tcls, tbox, indices, anch,tconf

同时,我们也修改计算类别loss位置的one shot方式,yolov5原模型的target为数字需要转换one shot编码,而teacher模型给的target本身为序列标签,无需转换,因此修改内容如下:

原代码:

# Classification
if self.nc > 1:  # cls loss (only if multiple classes)
    t = torch.full_like(ps[:, 5:], self.cn, device=device)  # targets
    t[range(n), tcls[i]] = self.cp   # 这里将其one-short编码-->也说明类从0开始
    lcls += self.BCEcls(ps[:, 5:], t)  # BCE

修改代码:

lcls += self.BCEcls(ps[:, 5:], tcls[i])  # BCE

3、蒸馏模型代码图示

模型初始化

模型与loss的初始化,如下图示:
在这里插入图片描述

模型蒸馏

学生模型硬标签loss计算、teacher-student的软标签loss计算,如此实现yolov5的KD+LD蒸馏方式,如下图示列:
在这里插入图片描述

三、蒸馏模型实验

1、工程数据测试

教师模型使用yolov5m模型-学生与蒸馏模型使用yolov5s模型,测试结果如下:
在这里插入图片描述
PR曲线图:

在这里插入图片描述
map0.5与map0.5:0.95均比学生模型高一点点。

2、voc2012开源数据测试

进一步实验测试,采用开源数据测试。
教师模型使用yolov5m模型-学生与蒸馏模型使用yolov5s模型,测试结果如下:
在这里插入图片描述
PR曲线图:

在这里插入图片描述
蒸馏模型在map0.5表现较差0.007个点,但map0.5:0.95却高了0.004个点。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/979252.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SpringMVC:从入门到精通,7篇系列篇带你全面掌握--二.SpringMVC常用注解及参数传递

&#x1f973;&#x1f973;Welcome Huihuis Code World ! !&#x1f973;&#x1f973; 接下来看看由辉辉所写的关于SpringMVC的相关操作吧 目录 &#x1f973;&#x1f973;Welcome Huihuis Code World ! !&#x1f973;&#x1f973; 一.关于日志的了解 1.使用日志的好处…

ESD门禁管理系统的组成和主要功能

ESD门禁管理系统是一种用于实现企业或组织对出入口进行管理和控制的系统。ESD代表“电子门禁系统”&#xff0c;它利用先进的技术手段来确保只有授权人员可以进入特定区域&#xff0c;从而提高管理效率。 ESD门禁管理系统通常包括以下组件&#xff1a; 1. 门禁读卡器&#xf…

论文阅读《Robust Monocular Depth Estimation under Challenging Conditions》

论文地址&#xff1a;https://arxiv.org/pdf/2308.09711.pdf 源码地址&#xff1a;https://github.com/md4all/md4all 概述 现有SOTA的单目估计方法在理想的环境下能得到满意的结果&#xff0c;而在一些极端光照与天气的情况下往往会失效。针对模型在极端条件下的表现不佳问题&…

拓展世界 | “秀才”被封,千万粉丝一朝空,数字时代来临,大众情感寄托是否有新的出口?

近日&#xff0c;短视频千万粉丝博主“秀才”因违反平台相关规定被封&#xff0c;引起了不少网友的关注&#xff0c;网络上大家戏称他为“中年妇女收割机”&#xff0c;这次的封杀&#xff0c;网友开玩笑道“这得有多少阿姨伤心欲绝”。 在当今数字时代&#xff0c;网红主播已…

计算机领域期刊会议级别分类

文章目录 一、查询期刊1.1、知网1.2、letpub1.3、ccf 二、CCF2.1、CCF和SCI的区别2.2、国际学术期刊2.3、国内期刊2.4、国际会议2.5、国内会议 三、期刊会议总结 一、查询期刊 1.1、知网 查询中⽂期刊⼀般用知⽹&#xff0c;输入你想了解的期刊然后搜索&#xff0c;可以查看期…

Ab3d.DXEngine 6.0 Crack 2023

Ab3d.DXEngine 不是另一个游戏引擎&#xff08;如Unity&#xff09;&#xff0c;它强迫您使用其游戏编辑器、其架构&#xff0c;并且需要许多技巧和窍门才能在标准 .Net 应用程序中使用。Ab3d.DXEngine 是一个新的渲染引擎&#xff0c;它是从头开始构建的&#xff0c;旨在用于标…

计算机视觉的应用13-基于SSD模型的城市道路积水识别的应用项目

大家好&#xff0c;我是微学AI&#xff0c;今天给大家介绍一下计算机视觉的应用13-基于SSD模型的城市道路积水识别的应用项目。今年第11号台风“海葵”后部云团的影响&#xff0c;福州地区的降雨量突破了历史极值&#xff0c;多出地方存在严重的积水。城市道路积水是造成交通拥…

关于ThreadPoolTaskExecutor线程池的配置

说明&#xff1a; 1、线程池分类、其他 1.1、分类 IO密集型 和 CPU密集型 任务的特点不同&#xff0c;因此针对不同类型的任务&#xff0c;选择不同类型的线程池可以获得更好的性能表现。 1.1. IO密集型任务 ​ IO密集型任务的特点是需要频繁读写磁盘、网络或者其他IO资源&a…

Netty—Channel

文章目录 一、Channel 是什么&#xff1f;&#x1f914;️二、 Channel 的继承体系&#x1f46a;三、Channel 的初始化过程 &#x1f50d;首先&#xff0c;channel() 指定 ChannelFactory 类型其次&#xff0c;Channel 实例化 一、Channel 是什么&#xff1f;&#x1f914;️ …

初识Java 5-1 实现隐藏

目录 库单元&#xff1a;package 代码组织 独一无二的包名 Java访问权限修饰符 包访问权限 接口访问权限&#xff08;public&#xff09; 不可访问&#xff08;private&#xff09; 继承访问权限&#xff08;protected&#xff09; 包访问权限与公共构造器 接口与实现…

基于Java+SpringBoot+Vue前后端分离医疗挂号管理系统设计和实现

博主介绍&#xff1a;✌全网粉丝30W,csdn特邀作者、博客专家、CSDN新星计划导师、Java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专…

pdf怎么转cad?几个简单方法分享给你

pdf怎么转cad&#xff1f;PDF文件转换为CAD文件是一项非常重要的任务&#xff0c;特别是对于那些需要进行工程、建筑和设计的专业人士来说。在过去&#xff0c;这项任务可能需要耗费大量时间和精力&#xff0c;但现在&#xff0c;随着技术的不断发展&#xff0c;已经有很多工具…

华为云云服务器评测| 之性能测试

文章目录 前言软件安装扩展知识 收集服务器负载信息指令解析开始压测后台运行 stress 运行 sysbench 测试网络带宽总结 测试磁盘 I/O 性能I/O 性能评估总结 前言 在当今数字化时代&#xff0c;云计算作为一种高效、灵活的计算方式&#xff0c;正日益受到企业和个人用户的广泛关…

如何远程访问Linux MeterSphere一站式开源持续测试平台

文章目录 前言1. 安装MeterSphere2. 本地访问MeterSphere3. 安装 cpolar内网穿透软件4. 配置MeterSphere公网访问地址5. 公网远程访问MeterSphere6. 固定MeterSphere公网地址 前言 MeterSphere 是一站式开源持续测试平台, 涵盖测试跟踪、接口测试、UI 测试和性能测试等功能&am…

基于Java+SpringBoot+Vue前后端分离校园商铺管理系统设计和实现

博主介绍&#xff1a;✌全网粉丝30W,csdn特邀作者、博客专家、CSDN新星计划导师、Java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专…

LeetCode 15 三数之和

题目链接 力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台 题目解析 // 1. 排序双指针 // 2. 固定一个值nums[i] 然后去剩下的位置去找 两数之和符合nums[j]nums[k]是否等于-nums[i] // 3. 细节问题&#xff1a;由于题目中是不可以包含重复的三元组的…

功率放大器的功能是什么功能

功率放大器是一种电子设备&#xff0c;用于放大输入信号的功率&#xff0c;并输出对应增强后的信号。功率放大器的功能主要包括增强信号的功率、保持信号的形状和质量、提供足够的电流和电压驱动负载&#xff0c;以满足不同应用需求。 功率放大器的主要功能是增强信号的功率。输…

阿里巴巴API接口解析,实现按关键字搜索商品

要解析阿里巴巴API接口并实现按关键字搜索商品&#xff0c;你需要进行以下步骤&#xff1a; 了解阿里巴巴API接口文档&#xff1a;访问阿里巴巴开放平台&#xff0c;找到API文档&#xff0c;了解阿里巴巴提供的API接口以及相关的参数、返回值等信息。注册开发者账号&#xff1…

远传水表和流量计的区别

远传水表和流量计是两种用于测量和控制水流的设备&#xff0c;虽然在某些方面有重叠的功能&#xff0c;但它们之间也有一些区别。下面我们将详细介绍这两种设备的区别。 一、定义和作用 远传水表是一种能够远程传输用水数据的水表&#xff0c;可以通过无线通信技术将数据传输到…

轻松解决Idea中maven无法下载源码

今天在解决问题的时候想要下载源码&#xff0c;突然发现idea无法下载&#xff0c;这是真的蛋疼&#xff0c;没办法查看原因&#xff0c;最后发现问题的原因居然是因为Maven&#xff0c;由于我使用的idea的内置的Bundle3的Maven&#xff0c;之前没有研究过本地安装和内置的区别&…