【yolov1】yoloLoss.py

news2024/11/25 12:26:17

在这里插入图片描述

在这里插入图片描述

1.计算预测中心点与真实中心点的损失。
2.计算预测的宽高与真实宽高的损失。

用根号,是使得小框对误差更敏感。

第三项负责计算置信度的误差

标签值是预测框真实框的IOU,作为标签值。

第四项是不负责检测目标的框,让它们的Loss值越小越好。让他们的权重小一些,因为他们比较多。

第五项:负责检测物体那个框的分类误差。比如真实框类别标注是狗,那么预测的类别是狗的概率让它越来越接近1。

在这里插入图片描述

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import warnings

warnings.filterwarnings('ignore')  # 忽略警告消息
CLASS_NUM = 20    # (使用自己的数据集时需要更改)

class yoloLoss(nn.Module):
    def __init__(self, S, B, l_coord, l_noobj):
        # 一般而言 l_coord = 5 , l_noobj = 0.5
        super(yoloLoss, self).__init__()
        self.S = S  # S = 7
        self.B = B  # B = 2
        self.l_coord = l_coord
        self.l_noobj = l_noobj

    def compute_iou(self, box1, box2):  # box1(2,4)  box2(1,4)
        N = box1.size(0)  # 2
        M = box2.size(0)  # 1

        lt = torch.max(  # 返回张量所有元素的最大值
            # [N,2] -> [N,1,2] -> [N,M,2]
            box1[:, :2].unsqueeze(1).expand(N, M, 2),
            # [M,2] -> [1,M,2] -> [N,M,2]
            box2[:, :2].unsqueeze(0).expand(N, M, 2),
        )

        rb = torch.min(
            # [N,2] -> [N,1,2] -> [N,M,2]
            box1[:, 2:].unsqueeze(1).expand(N, M, 2),
            # [M,2] -> [1,M,2] -> [N,M,2]
            box2[:, 2:].unsqueeze(0).expand(N, M, 2),
        )

        wh = rb - lt  # [N,M,2]
        wh[wh < 0] = 0  # clip at 0
        inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]  重复面积

        area1 = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])  # [N,]
        area2 = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])  # [M,]
        area1 = area1.unsqueeze(1).expand_as(inter)  # [N,] -> [N,1] -> [N,M]
        area2 = area2.unsqueeze(0).expand_as(inter)  # [M,] -> [1,M] -> [N,M]

        iou = inter / (area1 + area2 - inter)
        return iou  # [2,1]

    def forward(self, pred_tensor, target_tensor):
        '''
        pred_tensor: (tensor) size(batchsize,7,7,30)
        target_tensor: (tensor) size(batchsize,7,7,30) --- ground truth
        '''
        N = pred_tensor.size()[0]  # batchsize
        coo_mask = target_tensor[:, :, :, 4] > 0  # 具有目标标签的索引值 true batchsize*7*7
        noo_mask = target_tensor[:, :, :, 4] == 0  # 不具有目标的标签索引值 false batchsize*7*7
        coo_mask = coo_mask.unsqueeze(-1).expand_as(target_tensor)  # 得到含物体的坐标等信息,复制粘贴 batchsize*7*7*30
        noo_mask = noo_mask.unsqueeze(-1).expand_as(target_tensor)  # 得到不含物体的坐标等信息 batchsize*7*7*30

        coo_pred = pred_tensor[coo_mask].view(-1, int(CLASS_NUM + 10))  # view类似于reshape
        box_pred = coo_pred[:, :10].contiguous().view(-1, 5)  # 塑造成X行5列(-1表示自动计算),一个box包含5个值
        class_pred = coo_pred[:, 10:]  # [n_coord, 20]

        coo_target = target_tensor[coo_mask].view(-1, int(CLASS_NUM + 10))
        box_target = coo_target[:, :10].contiguous().view(-1, 5)
        class_target = coo_target[:, 10:]

        # 不包含物体grid ceil的置信度损失
        noo_pred = pred_tensor[noo_mask].view(-1, int(CLASS_NUM + 10))
        noo_target = target_tensor[noo_mask].view(-1, int(CLASS_NUM + 10))
        noo_pred_mask = torch.cuda.ByteTensor(noo_pred.size()).bool()
        noo_pred_mask.zero_()
        noo_pred_mask[:, 4] = 1
        noo_pred_mask[:, 9] = 1
        noo_pred_c = noo_pred[noo_pred_mask]  # noo pred只需要计算 c 的损失 size[-1,2]
        noo_target_c = noo_target[noo_pred_mask]
        nooobj_loss = F.mse_loss(noo_pred_c, noo_target_c, size_average=False)  # 均方误差

        # compute contain obj loss
        coo_response_mask = torch.cuda.ByteTensor(box_target.size()).bool()  # ByteTensor 构建Byte类型的tensor元素全为0
        coo_response_mask.zero_()  # 全部元素置False                            bool:将其元素转变为布尔值

        no_coo_response_mask = torch.cuda.ByteTensor(box_target.size()).bool()  # ByteTensor 构建Byte类型的tensor元素全为0
        no_coo_response_mask.zero_()  # 全部元素置False                            bool:将其元素转变为布尔值

        box_target_iou = torch.zeros(box_target.size()).cuda()

        # box1 = 预测框  box2 = ground truth
        for i in range(0, box_target.size()[0], 2):  # box_target.size()[0]:有多少bbox,并且一次取两个bbox
            box1 = box_pred[i:i + 2]  # 第一个grid ceil对应的两个bbox
            box1_xyxy = Variable(torch.FloatTensor(box1.size()))
            box1_xyxy[:, :2] = box1[:, :2] / float(self.S) - 0.5 * box1[:, 2:4]  # 原本(xc,yc)为7*7 所以要除以7
            box1_xyxy[:, 2:4] = box1[:, :2] / float(self.S) + 0.5 * box1[:, 2:4]
            box2 = box_target[i].view(-1, 5)
            box2_xyxy = Variable(torch.FloatTensor(box2.size()))
            box2_xyxy[:, :2] = box2[:, :2] / float(self.S) - 0.5 * box2[:, 2:4]
            box2_xyxy[:, 2:4] = box2[:, :2] / float(self.S) + 0.5 * box2[:, 2:4]
            iou = self.compute_iou(box1_xyxy[:, :4], box2_xyxy[:, :4])
            max_iou, max_index = iou.max(0)
            max_index = max_index.data.cuda()
            coo_response_mask[i + max_index] = 1  # IOU最大的bbox
            no_coo_response_mask[i + 1 - max_index] = 1  # 舍去的bbox
            # confidence score = predicted box 与 the ground truth 的 IOU
            box_target_iou[i + max_index, torch.LongTensor([4]).cuda()] = max_iou.data.cuda()

        box_target_iou = Variable(box_target_iou).cuda()
        # 置信度误差(含物体的grid ceil的两个bbox与ground truth的IOU较大的一方)
        box_pred_response = box_pred[coo_response_mask].view(-1, 5)
        box_target_response_iou = box_target_iou[coo_response_mask].view(-1, 5)
        # IOU较小的一方
        no_box_pred_response = box_pred[no_coo_response_mask].view(-1, 5)
        no_box_target_response_iou = box_target_iou[no_coo_response_mask].view(-1, 5)
        no_box_target_response_iou[:, 4] = 0  # 保险起见置0(其实原本就是0)

        box_target_response = box_target[coo_response_mask].view(-1, 5)

        # 含物体grid ceil中IOU较大的bbox置信度损失
        contain_loss = F.mse_loss(box_pred_response[:, 4], box_target_response_iou[:, 4], size_average=False)
        # 含物体grid ceil中舍去的bbox损失
        no_contain_loss = F.mse_loss(no_box_pred_response[:, 4], no_box_target_response_iou[:, 4], size_average=False)
        # bbox坐标损失
        loc_loss = F.mse_loss(box_pred_response[:, :2], box_target_response[:, :2], size_average=False) + F.mse_loss(
            torch.sqrt(box_pred_response[:, 2:4]), torch.sqrt(box_target_response[:, 2:4]), size_average=False)

        # 类别损失
        class_loss = F.mse_loss(class_pred, class_target, size_average=False)

        return (self.l_coord * loc_loss + contain_loss + self.l_noobj * (nooobj_loss + no_contain_loss) + class_loss) / N

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1040522.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java函数式接口(Consumer、Function、Predicate、Supplier)详解及代码示例

函数式接口 java.util.function : Consumer :消费型函数接口 void accept(T t) Function :函数型接口 R apply(T t) Predicate :判断型接口 boolean test(T t) Supplier :供给型接口 T get() Consumer - 消费型函数接口 该接口代表了一个接受一个参数并且不返回结果的操作。…

vivado乘法器IP核进行无符号与有符号数相乘问题的验证

本文验证乘法器IP核Multiplier进行无符号(unsigned)与有符号数(signed)相乘的正确性&#xff0c;其中也遇到了一些问题&#xff0c;做此记录。 配套工程&#xff1a;https://download.csdn.net/download/weixin_48412658/88354179 文章目录 问题的讨论验证过程IP核配置例化乘…

工业AI视觉检测优势显著,深眸科技为工业自动化发展注入更强动力

随着工业自动化的不断发展&#xff0c;工业机器视觉检测技术日趋成熟&#xff0c;能够对制造生产线上的产品进行识别、定位、检测、测量等功能&#xff0c;使得工业生产更加高效和精准。 同时机器视觉检测也是一种基于图像处理和模式识别的技术&#xff0c;能够通过高清晰度工…

00-MySQL数据库的使用-下

一 多表查询 多表查询简介 笛卡尔乘积 笛卡尔乘积 &#xff1a; 当一个连接条件无效或被遗漏时&#xff0c;其结果是一个笛卡尔乘积 (Cartesian product)&#xff0c;其中所有行的组合都被显示。第一个表中的所 有行连接到第二个表中的所有行。一个笛卡尔乘积会产生大量的 行…

麦肯锡:中国生成式AI市场现状和未来发展趋势

本文来自《麦肯锡中国金融业CEO季刊》&#xff0c;版权归麦肯锡所有。该季刊主要围绕生成式AI&#xff08;以下简称“GenAI”&#xff09;主题&#xff0c;通过4大章节共8篇文章&#xff0c;全面深入分析了GenAI对各主要行业的影响、价值链投资机会、中国GenAI市场现状和未来趋…

YTM32的LINFlexD实现UART功能详解

文章目录 引言简介原理与机制同UART模式相关的寄存器时钟与波特率数据缓冲区发送过程接收过程 软件参考文献 引言 初看YTM32B1ME的手册时&#xff0c;一眼看上去&#xff0c;竟然没有找到UART模块的章节&#xff0c;心想这车规MCU的产品定义也太激进了&#xff0c;直接把工业和…

阿里云服务器使用教程(从购买到配置再到搭建自己的网站)

阿里云服务器使用教程包括云服务器购买、云服务器配置选择、云服务器开通端口号、搭建网站所需Web环境、安装网站程序、域名解析到云服务器公网IP地址&#xff0c;最后网站上线全流程&#xff0c;阿小云分享阿里云服务器详细使用教程&#xff1a; 目录 阿里云服务器使用教程 …

Python爬虫爬取豆瓣电影短评(爬虫入门,Scrapy框架,Xpath解析网站,jieba分词)

声明&#xff1a;以下内容仅供学习参考&#xff0c;禁止用于任何商业用途 很久之前就想学爬虫了&#xff0c;但是一直没机会&#xff0c;这次终于有机会了 主要参考了《疯狂python讲义》的最后一章 首先安装Scrapy&#xff1a; pip install scrapy 然后创建爬虫项目&#…

广东MES系统实现设备管理的方法与功能

在生产车间中&#xff0c;可以借助MES系统来完成设备管理。下面来看看借助MES系统实现设备管理比较常见的具体方法与功能&#xff1a; 1.在线监控和数据采集&#xff1a;MES系统能够与车间设备相连接&#xff0c;在线实时监控设备的运行状态和运行指标。凭借传感器、物联网产品…

【教程】Ubuntu自动查看有哪些用户名与密码相同的账户,并统一修改密码

转载请注明出处&#xff1a;小锋学长生活大爆炸[xfxuezhagn.cn] 目录 背景说明 开始操作 修改密码 背景说明 有些用户为了图方便或者初始创建用户默认设置等原因&#xff0c;会将密码设置为与用户名相同&#xff0c;但这就使得非常不安全。甚至如果该用户具有sudo权限&#…

力扣26:删除有序数组中的重复项

26. 删除有序数组中的重复项 - 力扣&#xff08;LeetCode&#xff09; 题目&#xff1a; 给你一个 非严格递增排列 的数组 nums &#xff0c;请你 原地 删除重复出现的元素&#xff0c;使每个元素 只出现一次 &#xff0c;返回删除后数组的新长度。元素的 相对顺序 应该保持 …

virtualbox安装的linux虚拟机安装并启动Tomcat过程(结合idea操作)记录,并使用宿主机访问页面

virtualbox安装的linux虚拟机安装并启动Tomcat过程&#xff08;结合idea操作&#xff09;记录&#xff0c;并使用宿主机访问页面 参考教程地址linux版本Tomcat下载地址上传解压 启动TomcatVirtualBox虚拟机本地可访问宿主机尚未可以访问关闭防火墙宿主机可以访问 参考教程地址 …

构建自己的物料解决方案——构建物料库,实现前端设计

01: 数据拦截简化数据获取流程 /** * 响应拦截器&#xff1a; * 服务端返回数据之后&#xff0c;前端 .then 之前被调用 */ service.interceptors.response.use((response) > {const { success, message, data } response.dataif (success) {return data}// TODO&#xff…

顺序读写函数的介绍:fscanf fprintf

目录 函数介绍&#xff1a; fprintf&#xff1a; 将结构体变量s的成员列表内容写入文件中&#xff1a; 文件效果&#xff1a;已经进行了格式化&#xff0c;3.140000是最明显的效果&#xff0c;因为float需要补齐0来补充精度 和printf的对比&#xff1a; 不同之处&#xff…

基础排序算法

插入排序&#xff08;insertion sort&#xff09; 插入排序每次循环将一个元素放置在适当的位置。像抓牌一样。手里的排是有序的&#xff0c;新拿一张牌&#xff0c;与手里的牌进行比较将其放在合适的位置。 插入排序要将待排序的数据分成两部分&#xff0c;一部分有序&#…

ddd扬帆

简介 读取了一篇“产品代码都给你看了&#xff0c;可别再说不会DDD”较为清晰了解了&#xff1a;领域驱动设计&#xff0c;整洁架构和事件驱动架构的架构思想落地实践&#xff0c;特做记录读后感&#xff0c;可以直接跳到正文阅读原文 正文 以下正文... 构建自己的软件大厦 …

【C++】C++11——可变参数模板和emplace

可变参数模板的定义方式可变参数模板的传值计算可变参数模板参数个数参数包展开方式递归展开参数包逗号表达式展开参数包 emplace插入 可变参数模板是C11新增的最强大的特性之一&#xff0c;它对参数高度泛化&#xff0c;能够让我们创建可以接受可变参数的函数模板和类模板。 在…

(a)Spring注解式开发,注册组件的@Repository,@Service,@Controller,@Component使用及说明

注解扫描原理 通过反射机制获取注解 Target(value {ElementType.TYPE})// 设置Component注解可以出现的位置&#xff0c;以上代表表示Component注解只能用在类和接口上 Retention(value RetentionPolicy.RUNTIME)// 设置Component注解的保持性策略&#xff0c;以上代表Comp…

社区团购美团和多多买菜小程序购物车

概述 微信小程序购物车列表demo 详细 需求 显示食物名称、价格、数量。 点击相应商品增加按钮,购买数量增加1,点击食物减少按钮,购买数量减一 显示购买总数和总金额 查看当前购买的商品 效果图(数据来自本地模拟) 目录结构 实现过程 主要wxml <view classfoods>…

工具篇 | H2数据库的使用和入门

引言 1.1 H2数据库概述 1.1.1 定义和特点 H2数据库是一款以 Java编写的轻量级关系型数据库。由于其小巧、灵活并且易于集成&#xff0c;H2经常被用作开发和测试环境中的便利数据库解决方案。除此之外&#xff0c;H2也适合作为生产环境中的嵌入式数据库。它不仅支持标准的SQL…