(超详细)1-YOLOV5改进-Optimal Transport Assignment

news2024/12/26 21:00:18

Optimal Transport Assignment(OTA)是YOLOv5中的一个改进,它是一种更优的目标检测框架,可以在保证检测精度的同时,大幅提升检测速度。
话不多说,直接开始
1、首先在你的yolov5目录下面的找到loss.py文件
如下图:yolov5-master/utils/loss.py
在这里插入图片描述
在这里插入图片描述
2、在loss.py文件最后面加入这段代码,并保存。

import torch.nn.functional as F
from utils.metrics import box_iou
from utils.torch_utils import de_parallel
from utils.general import xywh2xyxy

class ComputeLossOTA:
    # Compute losses
    def __init__(self, model, autobalance=False):
        super(ComputeLossOTA, self).__init__()
        device = next(model.parameters()).device  # get model device
        h = model.hyp  # hyperparameters

        # Define criteria
        BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))
        BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))

        # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
        self.cp, self.cn = smooth_BCE(eps=h.get('label_smoothing', 0.0))  # positive, negative BCE targets

        # Focal loss
        g = h['fl_gamma']  # focal loss gamma
        if g > 0:
            BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)

        det = de_parallel(model).model[-1]  # Detect() module
        self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, .02])  # P3-P7
        self.ssi = list(det.stride).index(16) if autobalance else 0  # stride 16 index
        self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance
        for k in 'na', 'nc', 'nl', 'anchors', 'stride':
            setattr(self, k, getattr(det, k))

    def __call__(self, p, targets, imgs):  # predictions, targets, model   
        device = targets.device
        lcls, lbox, lobj = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device)
        bs, as_, gjs, gis, targets, anchors = self.build_targets(p, targets, imgs)
        pre_gen_gains = [torch.tensor(pp.shape, device=device)[[3, 2, 3, 2]] for pp in p] 
    

        # Losses
        for i, pi in enumerate(p):  # layer index, layer predictions
            b, a, gj, gi = bs[i], as_[i], gjs[i], gis[i]  # image, anchor, gridy, gridx
            tobj = torch.zeros_like(pi[..., 0], device=device)  # target obj

            n = b.shape[0]  # number of targets
            if n:
                ps = pi[b, a, gj, gi]  # prediction subset corresponding to targets

                # Regression
                grid = torch.stack([gi, gj], dim=1)
                pxy = ps[:, :2].sigmoid() * 2. - 0.5
                #pxy = ps[:, :2].sigmoid() * 3. - 1.
                pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
                pbox = torch.cat((pxy, pwh), 1)  # predicted box
                selected_tbox = targets[i][:, 2:6] * pre_gen_gains[i]
                selected_tbox[:, :2] -= grid
                iou = bbox_iou(pbox, selected_tbox, CIoU=True)  # iou(prediction, target)
                if type(iou) is tuple:
                    lbox += (iou[1].detach() * (1 - iou[0])).mean()
                    iou = iou[0]
                else:
                    lbox += (1.0 - iou).mean()  # iou loss

                # Objectness
                tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * iou.detach().clamp(0).type(tobj.dtype)  # iou ratio

                # Classification
                selected_tcls = targets[i][:, 1].long()
                if self.nc > 1:  # cls loss (only if multiple classes)
                    t = torch.full_like(ps[:, 5:], self.cn, device=device)  # targets
                    t[range(n), selected_tcls] = self.cp
                    lcls += self.BCEcls(ps[:, 5:], t)  # BCE

                # Append targets to text file
                # with open('targets.txt', 'a') as file:
                #     [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]

            obji = self.BCEobj(pi[..., 4], tobj)
            lobj += obji * self.balance[i]  # obj loss
            if self.autobalance:
                self.balance[i] = self.balance[i] * 0.9999 + 0.0001 / obji.detach().item()

        if self.autobalance:
            self.balance = [x / self.balance[self.ssi] for x in self.balance]
        lbox *= self.hyp['box']
        lobj *= self.hyp['obj']
        lcls *= self.hyp['cls']
        bs = tobj.shape[0]  # batch size

        loss = lbox + lobj + lcls
        return loss * bs, torch.cat((lbox, lobj, lcls)).detach()

    def build_targets(self, p, targets, imgs):
        indices, anch = self.find_3_positive(p, targets)
        device = torch.device(targets.device)
        matching_bs = [[] for pp in p]
        matching_as = [[] for pp in p]
        matching_gjs = [[] for pp in p]
        matching_gis = [[] for pp in p]
        matching_targets = [[] for pp in p]
        matching_anchs = [[] for pp in p]
        
        nl = len(p)    
    
        for batch_idx in range(p[0].shape[0]):
        
            b_idx = targets[:, 0]==batch_idx
            this_target = targets[b_idx]
            if this_target.shape[0] == 0:
                continue
                
            txywh = this_target[:, 2:6] * imgs[batch_idx].shape[1]
            txyxy = xywh2xyxy(txywh)

            pxyxys = []
            p_cls = []
            p_obj = []
            from_which_layer = []
            all_b = []
            all_a = []
            all_gj = []
            all_gi = []
            all_anch = []
            
            for i, pi in enumerate(p):
                
                b, a, gj, gi = indices[i]
                idx = (b == batch_idx)
                b, a, gj, gi = b[idx], a[idx], gj[idx], gi[idx]                
                all_b.append(b)
                all_a.append(a)
                all_gj.append(gj)
                all_gi.append(gi)
                all_anch.append(anch[i][idx])
                from_which_layer.append((torch.ones(size=(len(b),)) * i).to(device))
                
                fg_pred = pi[b, a, gj, gi]                
                p_obj.append(fg_pred[:, 4:5])
                p_cls.append(fg_pred[:, 5:])
                
                grid = torch.stack([gi, gj], dim=1)
                pxy = (fg_pred[:, :2].sigmoid() * 2. - 0.5 + grid) * self.stride[i] #/ 8.
                #pxy = (fg_pred[:, :2].sigmoid() * 3. - 1. + grid) * self.stride[i]
                pwh = (fg_pred[:, 2:4].sigmoid() * 2) ** 2 * anch[i][idx] * self.stride[i] #/ 8.
                pxywh = torch.cat([pxy, pwh], dim=-1)
                pxyxy = xywh2xyxy(pxywh)
                pxyxys.append(pxyxy)
            
            pxyxys = torch.cat(pxyxys, dim=0)
            if pxyxys.shape[0] == 0:
                continue
            p_obj = torch.cat(p_obj, dim=0)
            p_cls = torch.cat(p_cls, dim=0)
            from_which_layer = torch.cat(from_which_layer, dim=0)
            all_b = torch.cat(all_b, dim=0)
            all_a = torch.cat(all_a, dim=0)
            all_gj = torch.cat(all_gj, dim=0)
            all_gi = torch.cat(all_gi, dim=0)
            all_anch = torch.cat(all_anch, dim=0)
        
            pair_wise_iou = box_iou(txyxy, pxyxys)

            pair_wise_iou_loss = -torch.log(pair_wise_iou + 1e-8)

            top_k, _ = torch.topk(pair_wise_iou, min(10, pair_wise_iou.shape[1]), dim=1)
            dynamic_ks = torch.clamp(top_k.sum(1).int(), min=1)

            gt_cls_per_image = (
                F.one_hot(this_target[:, 1].to(torch.int64), self.nc)
                .float()
                .unsqueeze(1)
                .repeat(1, pxyxys.shape[0], 1)
            )

            num_gt = this_target.shape[0]
            cls_preds_ = (
                p_cls.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
                * p_obj.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
            )

            y = cls_preds_.sqrt_()
            pair_wise_cls_loss = F.binary_cross_entropy_with_logits(
               torch.log(y/(1-y)) , gt_cls_per_image, reduction="none"
            ).sum(-1)
            del cls_preds_
        
            cost = (
                pair_wise_cls_loss
                + 3.0 * pair_wise_iou_loss
            )

            matching_matrix = torch.zeros_like(cost, device=device)

            for gt_idx in range(num_gt):
                _, pos_idx = torch.topk(
                    cost[gt_idx], k=dynamic_ks[gt_idx].item(), largest=False
                )
                matching_matrix[gt_idx][pos_idx] = 1.0

            del top_k, dynamic_ks
            anchor_matching_gt = matching_matrix.sum(0)
            if (anchor_matching_gt > 1).sum() > 0:
                _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
                matching_matrix[:, anchor_matching_gt > 1] *= 0.0
                matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0
            fg_mask_inboxes = (matching_matrix.sum(0) > 0.0).to(device)
            matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
        
            from_which_layer = from_which_layer[fg_mask_inboxes]
            all_b = all_b[fg_mask_inboxes]
            all_a = all_a[fg_mask_inboxes]
            all_gj = all_gj[fg_mask_inboxes]
            all_gi = all_gi[fg_mask_inboxes]
            all_anch = all_anch[fg_mask_inboxes]
        
            this_target = this_target[matched_gt_inds]
        
            for i in range(nl):
                layer_idx = from_which_layer == i
                matching_bs[i].append(all_b[layer_idx])
                matching_as[i].append(all_a[layer_idx])
                matching_gjs[i].append(all_gj[layer_idx])
                matching_gis[i].append(all_gi[layer_idx])
                matching_targets[i].append(this_target[layer_idx])
                matching_anchs[i].append(all_anch[layer_idx])

        for i in range(nl):
            if matching_targets[i] != []:
                matching_bs[i] = torch.cat(matching_bs[i], dim=0)
                matching_as[i] = torch.cat(matching_as[i], dim=0)
                matching_gjs[i] = torch.cat(matching_gjs[i], dim=0)
                matching_gis[i] = torch.cat(matching_gis[i], dim=0)
                matching_targets[i] = torch.cat(matching_targets[i], dim=0)
                matching_anchs[i] = torch.cat(matching_anchs[i], dim=0)
            else:
                matching_bs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
                matching_as[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
                matching_gjs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
                matching_gis[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
                matching_targets[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
                matching_anchs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)

        return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs           

    def find_3_positive(self, p, targets):
        # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
        na, nt = self.na, targets.shape[0]  # number of anchors, targets
        indices, anch = [], []
        gain = torch.ones(7, device=targets.device).long()  # normalized to gridspace gain
        ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt)  # same as .repeat_interleave(nt)
        targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2)  # append anchor indices

        g = 0.5  # bias
        off = torch.tensor([[0, 0],
                            [1, 0], [0, 1], [-1, 0], [0, -1],  # j,k,l,m
                            # [1, 1], [1, -1], [-1, 1], [-1, -1],  # jk,jm,lk,lm
                            ], device=targets.device).float() * g  # offsets

        for i in range(self.nl):
            anchors = self.anchors[i]
            gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]]  # xyxy gain

            # Match targets to anchors
            t = targets * gain
            if nt:
                # Matches
                r = t[:, :, 4:6] / anchors[:, None]  # wh ratio
                j = torch.max(r, 1. / r).max(2)[0] < self.hyp['anchor_t']  # compare
                # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t']  # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
                t = t[j]  # filter

                # Offsets
                gxy = t[:, 2:4]  # grid xy
                gxi = gain[[2, 3]] - gxy  # inverse
                j, k = ((gxy % 1. < g) & (gxy > 1.)).T
                l, m = ((gxi % 1. < g) & (gxi > 1.)).T
                j = torch.stack((torch.ones_like(j), j, k, l, m))
                t = t.repeat((5, 1, 1))[j]
                offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
            else:
                t = targets[0]
                offsets = 0

            # Define
            b, c = t[:, :2].long().T  # image, class
            gxy = t[:, 2:4]  # grid xy
            gwh = t[:, 4:6]  # grid wh
            gij = (gxy - offsets).long()
            gi, gj = gij.T  # grid xy indices

            # Append
            a = t[:, 6].long()  # anchor indices
            indices.append((b, a, gj.clamp_(0, gain[3] - 1), gi.clamp_(0, gain[2] - 1)))  # image, anchor, grid indices
            anch.append(anchors[a])  # anchors

        return indices, anch

3、第三步,找到train.py文件。
(1)将58行的ComputeLoss改成ComputeLossOTA(大家根据自己的代码行数看,有时候行数对不上,但找到对应代码就可以)
在这里插入图片描述
在这里插入图片描述
(2)再将258行的ComputeLoss改成ComputeLossOTA(大家根据自己的代码行数看,有时候行数对不上,但找到对应代码就可以)
在这里插入图片描述
(3)第三个要改的地方是313行的,加一个imgs。
在这里插入图片描述
上面就是train.py文件三个要改的地方

4、找到val.py
在这里插入图片描述
将214行加一个im,具体如图,之后要保存!
在这里插入图片描述
之后运行,结果涨了0.02!!!!!哈哈哈,涨点,确实涨了点啊啊啊!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1374055.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

BGP公认任意属性——MED(二)

BGP公认任意属性有两个&#xff0c;分别是&#xff1a;Local-preference 和 MED&#xff0c;本期介绍MED。 点赞关注&#xff0c;持续更新&#xff01;&#xff01;&#xff01; MED 特点 MED &#xff08;多出口鉴别器&#xff09;&#xff0c;也称为BGP COST&#xff0c;…

设计模式之空对象模式

目录 1.简介 2.结构图 3.实例 4.优缺点 1.简介 空对象模式也是我们平时编程用的比较多的一种行为型设计模式&#xff0c;它的宗旨在解决空对象引起的异常报错问题&#xff1b;在空对象模式&#xff08;Null Object Pattern&#xff09;中&#xff0c;一个空对象取代 NULL 对…

Redis常用连接工具

RedisInsight 官网地址&#xff1a; RedisInsight | The Best Redis GUI Redis Desktop Manager 官网地址&#xff1a; RedisInsight | The Best Redis GUI 样式&#xff1a; QuickRedis 官网地址&#xff1a; QuickOfficial - QuickRedis 样式&#xff1a; AnotherRed…

SSL证书链是什么?SSL证书链如何工作?

SSL证书链作为公钥基础设施&#xff08;PKI&#xff09;的一项关键功能&#xff0c;它支持许多与安全相关的服务&#xff0c;包括数据机密性、数据完整性和最终实体身份验证&#xff0c;它使得互联网上的安全在线通信成为可能。那么SSL证书链是什么&#xff1f;SSL证书链如何工…

使用Python打造一个爱奇艺热播好剧提前搜系统

目录 一、系统功能设计 二、数据获取与处理 三、搜索功能实现 四、用户界面设计 五、系统部署与维护 六、总结 随着互联网的普及和人们对于娱乐需求的增加&#xff0c;视频网站成为了人们观看电视剧、电影等视频内容的主要渠道。爱奇艺作为国内知名的视频网站之一&#x…

自动化控制面板-1Panel

一、1Panel自动化控制面板 官网地址 1Panel 可以实现&#xff1a; 快速建站、高效管理、安全可靠、一键备份、应用商店 快速建站&#xff1a;深度集成 Wordpress 和 Halo&#xff0c;域名绑定、SSL 证书配置等一键搞定&#xff1b;高效管理&#xff1a;通过 Web 端轻松管理 …

Linux习题7

解析&#xff1a;du命令用于显示目录或文件的大小&#xff0c;du会显示指定的目录或文件所占用的磁盘空间。df命令用于显示目前在Linux系统上的文件系统磁盘使用情况统计。 解析&#xff1a;www是80&#xff0c;ftp是20,21 解析&#xff1a;光盘安装 (常规情况) 硬盘安装 (无光…

面试算法111:计算除法

题目 输入两个数组equations和values&#xff0c;其中&#xff0c;数组equations的每个元素包含两个表示变量名的字符串&#xff0c;数组values的每个元素是一个浮点数值。如果equations[i]的两个变量名分别是Ai和Bi&#xff0c;那么Ai/Bivalues[i]。再给定一个数组queries&am…

哪种小型洗衣机好用?高性价比的小型洗衣机推荐

大型洗衣机作为家居必备小家电&#xff0c;对生活品质的提升十分显著&#xff0c;在很多人的认知中&#xff0c;这种大型洗衣机主要是用来清洁大件的衣服和外套的&#xff0c;不方便将内衣裤都放入到里面&#xff0c;内衣裤的材质和尺寸都是比较特殊&#xff0c;若是直接将其放…

轻量化神奇!看3D模型格式转换工具HOOPS Exchange如何轻松实现减面操作?

现在很多CAD模型都比较复杂&#xff0c;有时候为了一些特殊用途&#xff08;轻量化显示、布尔运算、CAE网格剖分等&#xff09;&#xff0c;需要到对原始模型进行减面操作。在HOOPS Exchange中&#xff0c;就提供了对模型进行减面操作支持&#xff0c;以下内容就是HOOPS Exchan…

用TF-IDF处理文本数据

计算机擅长处理数字&#xff0c;但不擅长处理文本数据&#xff0c;TF-IDF是处理文本数据最广泛使用的技术之一&#xff0c;本文对它的工作原理以及它的特性进行介绍。 根据直觉&#xff0c;我们认为在文本数据分析中出现频率更高的单词应该具有更大的权重&#xff0c;但事实并…

python使用广度优先搜索算法解决二叉树最大、最小深度

对于广度优先搜索算法的一个经典应用问题&#xff0c;也就是对二叉树求其最大深度、最小深度问题。对于给定的二叉树的最大深度即为二叉树的根节点到最远的叶子结点之间的高度&#xff0c;而相应的最小深度就是根节点与离根节点最近的叶子节点之间的高度。 添加图片注释&#x…

MyBatisPlus学习笔记一

1、简介 MyBatisPlus&#xff08;简称MP&#xff09;是一个MyBatis的增强工具&#xff0c;在MyBatisMyBatisMyBatis的的基础上只做增强不做改变&#xff0c;为简化开发&#xff0c;提高效率而生。 官网&#xff1a;MyBatis-Plus mybatisplus通过扫描实体类&#xff0c;并基于…

贪心算法(思路)

最近在cf上做了很多贪心的题&#xff0c;写篇博客来总结一下 Problem - C - Codeforces 看第一道题 不难看出&#xff0c;我们需要在数组中找到一段奇偶相间的序列&#xff0c;要使他们的和最大&#xff0c; 在图中我们假设[1,2]和[3,4]是奇偶相间的序列&#xff0c;我们在在…

如何在 Microsoft Edge 浏览器中启用自动刷新

你是否经常发现自己在使用 Microsoft Edge 时点击刷新按钮&#xff1f;如果您需要一个网页以设定的时间间隔自动更新&#xff0c;那么请接着往下看。 在这篇博文中&#xff0c;我们探讨如何在 Microsoft Edge 浏览器中启用和管理自动刷新功能。 为什么选择自动刷新&#xff1…

【分布式】分布式链路跟踪技术

为什么需要分布式链路追踪 提到分布式链路追踪&#xff0c;我们要先提到微服务。相信很多人都接触过微服务。微服务是一种开发软件的架构和组织方法&#xff0c;它侧重将服务解耦&#xff0c;服务之间通过API通信。使应用程序更易于扩展和更快地开发&#xff0c;从而加速新功能…

使用requests库测试post请求 操作流程

第一步 谷歌f12或其他抓包工具抓包&#xff0c;这里随机抓一个post请求 url&#xff1a;https://eva2.csdn.net/v3/06981375190026432f77c01bfca33e32/lts/groups/dadde766-b087-42da-8e67-d2499a520ee7/streams/a0119567-bf91-4314-ab75-f683ba6c0c0a/logs 第二步 导包 impo…

国科大计算机体系结构期末考试——停更,手写更快

题型一、第二章的画图 给一个逻辑表达式&#xff0c;画出晶体管级别的电路图 cmos电路的基本电路&#xff1a; 与非门的功能是对多个输入信号进行逻辑与操作&#xff0c;然后对结果进行取反。 或非门的功能是对多个输入信号进行逻辑或操作&#xff0c;然后对结果进行取反。 …

适用于安防 音响 车载等产品中中的音频接口选型分析

在人工智能兴起之后&#xff0c;安防市场就成为了其全球最大的市场&#xff0c;也是成功落地的最主要场景之一。对于安防应用而言&#xff0c;智慧摄像头、智慧交通、智慧城市等概念的不断涌现&#xff0c;对于芯片产业催生出海量需求。今天&#xff0c;我将为大家梳理GLOBALCH…

怎么使用EIDE进行调试STM32单片机?

cortex-debug 用法 - Blog - Embedded IDE Forum (em-ide.com) 【VScode Embedded IDE】Keil工程导入VScode&#xff0c;与Keil协同开发MCU_vscode编辑keil工程-CSDN博客 Vscode EIDECortex Debug搭建STM32开发仿真环境_vscode cortex-debug-CSDN博客 可以结合一下上述三位大…