F-score 和 Dice Loss 原理及代码

news2024/11/19 21:22:06

文章目录

    • 1. F-score
    • 1. 1 原理
    • 1. 2 代码
    • 2. Dice Loss
      • 2.1 原理
      • 2.2 代码

通过看开源图像语义分割库的源码,发现它对 Dice Loss 的实现方式,是直接调用 F-score 函数,换言之,Dice LossF-score的特殊情况。于是就研究了一下这背后的原理,作文以记之。

1. F-score

1. 1 原理

首先介绍 F-score:
在这里插入图片描述
要理解F-score,就要先回顾一下 PrecisionRecall,首先给出公式:

在这里插入图片描述
两个指标衡量算法的准确性时,通常是相互排斥的。例如,输入一个数据,算法根据数据预测一个分数,现在为该分数设定阈值,大于阈值的预测为真,小于该阈值的预测为假。

  • 如果这个阈值得过低,低到测试集中所有的样本均判定为真,那么此时,FN=0(False negative, 压根就没有预测出来 negative 的样本),代入公式 (2) 得 Recall = 1。但此时,预测为真的样本中,包含大量的 FP,即 False Positive,将会导致 Precision 过低
  • 如果这个阈值设置得过高,使得所有被判定为正的样本都是真的,那么 FP=0,Precision=1,此时将不可避免有很多本应被判定为正的样本,被错误地判定为负,也就是 FN 很大,导致 Recall 过低

不同的应用场景下,对这两个指标的侧重不同。例如新冠感染者检测,就应该尽量提高 Recall,务求没有漏网之鱼。但在检测垃圾邮件时,应该尽量提升 Precision,即每个被判定为垃圾邮件的,都是板上钉钉毫无争议的,防止出现误伤,把正常邮件当成垃圾邮件处理。

F-score 则是将这两个指标综合起来:
在这里插入图片描述

  • β \beta β控制 Precision 和 Recall 的重要程度, 当 β = 1 \beta=1 β=1, 对应 F1-score,此时 Precision 和 Recall 同样重要。

  • β \beta β两个常用的取值是 0.5 2,当取 0.5 时,Precision 对 F-score 的影响更大,当取 2 时,Recall 对 F-score 的影响更大。(可以考虑得更极端一点,当 β → 0 \beta\rightarrow0 β0,公式(3)趋于 Precision;当 β → ∞ \beta\rightarrow\infty β,公式(3)上下同除以分子,易知其将趋于 Recall)

最后,把 (1) (2) 代入 (3) 得:
在这里插入图片描述

1. 2 代码

def f_score(inputs, target, beta=1, smooth = 1e-5, threhold = 0.5):
    n, c, h, w = inputs.size()
    nt, ht, wt, ct = target.size()
    if h != ht and w != wt:
        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
        
    temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c),-1)
    temp_target = target.view(n, -1, ct)

    #--------------------------------------------#
    #   计算dice系数
    #--------------------------------------------#
    temp_inputs = torch.gt(temp_inputs, threhold).float()
    tp = torch.sum(temp_target[...,:-1] * temp_inputs, axis=[0,1])
    fp = torch.sum(temp_inputs                       , axis=[0,1]) - tp
    fn = torch.sum(temp_target[...,:-1]              , axis=[0,1]) - tp

    score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
    score = torch.mean(score)
    return score
  • inputs为分割模型的预测输出,未经过softmax, target为gt
  • temp_target中将channels维度设为num_classes+1,为了方便处理白边,因此在实际计算时需要去掉最后一个channel: temp_target[...,:-1]
    在这里插入图片描述
  • 预测分割图temp_inputs与 GT 分割图的点乘,然后再(n,hw)方向上求和作为tp
    参考自: Dice系数(Dice coefficient)与mIoU与Dice Loss
    在这里插入图片描述
  • 因为预测temp_inputs (pred) = fp+tp, 因此已知temp_inputstp, 就可以求出fp
  • 同理temp_target (gt) = fn+tp, 因此已知temp_targettp, 就可以求出`fn
  • 然后根据F-score的计算公式,在已知tp,fp,fn以及beta系数,就可以计算出F-score值了
    在这里插入图片描述
    score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
    score = torch.mean(score)

2. Dice Loss

2.1 原理

Dice Loss 是语义分割中常用的一种损失,它的计算方法如下:
在这里插入图片描述
即,分子为预测值与真实值的交集元素数目的两倍,分母为两个集合元素数目之和(注意并不是并集,而是和)。而
在这里插入图片描述
因此,(6) 相当于:
1 − 2 T P 2 T P + F P + F N 1-\frac{2TP}{2TP+FP+FN} 12TP+FP+FN2TP

而上式的结果,正是公式 (5) 中 β = 1 \beta =1 β=1的情况,也就是F1 score。因此,

Dice Loss = 1 - F1 score 

2.2 代码

def Dice_loss(inputs, target, beta=1, smooth = 1e-5):
    n, c, h, w = inputs.size()
    nt, ht, wt, ct = target.size()
    if h != ht and w != wt:
        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
        
    temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c),-1)
    temp_target = target.view(n, -1, ct)

    #--------------------------------------------#
    #   计算dice loss
    #--------------------------------------------#
    tp = torch.sum(temp_target[...,:-1] * temp_inputs, axis=[0,1])
    fp = torch.sum(temp_inputs                       , axis=[0,1]) - tp
    fn = torch.sum(temp_target[...,:-1]              , axis=[0,1]) - tp

    score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
    dice_loss = 1 - torch.mean(score)
    return dice_loss
  • 可以看到dice_loss的实现,跟F-score基本上是一模一样的, 将torch.mean(score)求得的F-soce, 然后通过dice_loss = 1- F-score 来实现。
  • 代码中默认 β = 1 \beta=1 β=1, 所以更精确的说: dice_loss = 1- F1-score
  • DIce _loss的在训练损失中的使用如下:
    在这里插入图片描述

参考:

  • F-score 和 Dice Loss
  • https://github.com/bubbliiiing/deeplabv3-plus-pytorch/blob/main/utils/utils_metrics.py

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1384431.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

网站漏洞扫描 awvs 23.11下载 Acunetix Premium build 23.11 for Linux 完美版

Acunetix Premium build 23.11 for Linux 完美版 更新日志: 网站漏洞扫描 awvs 23.11下载 新功能 Java IAST 传感器已更新为支持 Java 17 并删除了对 AspectJWeaver 的要求对管理适用于 Docker 和 Linux 的 Acunetix On-Premises 服务的机制进行了更改&#xff0…

前端js写数据结构与算法

1、什么是数据结构与算法 数据结构:是指数据对象中数据元素之间的相互关系。包括集合结构、线性结构、树形结构、图形结构。 算法:解决问题的思路。 2、时间复杂度 1.是什么? 执行当前算法所“花费的时间” 2.干什么? 在写代码的过程中&#xf…

C# .NET SQL sugar中 IsAny进行根据条件判断数据是否存在 IsAny的使用

SQL sugar 中控制器直接判断数据是否存在 首先确保你的Service层继承的表名 控制器中使用IsAny进行根据条件判断数据是否存在

算法通关村第十五关—继续研究超大规模数据场景的问题(黄金)

继续研究超大规模数据场景的问题 一、对20GB文件进行排序 题目要求:假设你有一个20GB的文件,每行一个字符串,请说明如何对这个文件进行排序?  分析:这里给出大小是20GB,其实面试官就在暗示你不要将所有的文件都装入到…

墙地砖外形检测的技术方案-图像获取

硬件系统 墙地砖外形检测硬件系统主要由工业相机、光源、瓷砖位置检测电路和上位机组成,其结构如图所示。为了提高系统检测精度和稳定性,系统采用的是较高精度的高速工业相机用于抓取墙地砖表面轮廓图像,图像数据通过USB接口向上位机传送&am…

Maven《一》-- 一文带你快速了解Maven

目录 🐶1.1 为什么使用Maven 1. Mavan是一个依赖管理工具 ①jar包的规模 ②jar包的来源问题 ③jar包的导入问题 ④jar包之间的依赖 2. Mavan是一个构建工具 ①你没有注意过的构建 ②脱离IDE环境仍需构建 3. 结论 🐶1.2 什么是Maven &#x…

系列四、Spring Security中的认证 授权(前后端不分离)

一、Spring Security中的认证 & 授权(前后端不分离) 1.1、MyWebSecurityConfigurerAdapter /*** Author : 一叶浮萍归大海* Date: 2024/1/11 21:50* Description:*/ Configuration public class MyWebSecurityConfigurerAdapter extends WebSecuri…

ZZULIOJ 1110: 最近共同祖先(函数专题)

题目描述 如上图所示,由正整数1, 2, 3, ...组成了一棵无限大的二叉树。从某一个结点到根结 点(编号是1 的结点)都有一条唯一的路径,比如从10 到根结点的路径是(10, 5, 2, 1), 从4 到根结点的路径是(4, 2, 1)&#xff0…

x-cmd pkg | qrencode - 二维码生成工具

目录 简介首次用户功能特点竞品和相关作品进一步阅读 简介 qrencode 是一个用于生成二维码的命令行工具。它可以将文本、URL、电话号码等信息转换为二维码图像。生成的二维码图像可以保存为图片文件,方便在电子文档、网页、移动应用等各种场景中使用。 它支持的二维…

python爬虫小练习——爬取豆瓣电影top250

爬取豆瓣电影top250 需求分析 将爬取的数据导入到表格中,方便人为查看。 实现方法 三大功能 1,下载所有网页内容。 2,处理网页中的内容提取自己想要的数据 3,导入到表格中 分析网站结构需要提取的内容 代码 import requests…

Random的使用

作用:生成伪随机数 1.导包:import java.util.Random 2.得到随机数对象:Random r new Random(); 3.调用随机数的功能获取随机数: 这里随机生成一个0-9的整数: int number r.nextInt(10); 实现指定区间的随机数&a…

C语言中关于指针的理解及用法

关于指针意思的参考:https://baike.baidu.com/item/%e6%8c%87%e9%92%88/2878304 指针 指针变量 地址 野指针 野指针就是指针指向的位置是不可知的(随机的,不正确的,没有明确限制的) 以下是导致野指针的原因 1.指针…

利益兑现期越短,积极性越高

在2023年一次部门项目提成时间节点的调整,引发了相关的销售部门 ,项目集成部门,软件开发部门截然不同的工作积极性。 公司案例 公司做项目的时候,采用的是相关部门都可以在项目获取提成 ,之前的提成方式为销售部门为…

maven镜像源设置aliyun提升下载速度

一、打开pom.xml project下在添加 <repositories><repository><id>aliyunmaven</id><name>aliyun</name><url>https://maven.aliyun.com/repository/public</url></repository><repository><id>central2&l…

【Arduino】编程语言:定时函数、数学函数、字符函数(功能、语法格式、参数说明、返回值) | 软件开发环境:安装步骤介绍(EXE安装版、ZIP安装版)

你的负担将变成礼物,你受的苦将照亮你的路。———泰戈尔 🎯作者主页: 追光者♂🔥 🌸个人简介: 💖[1] 计算机专业硕士研究生💖 🌿[2] 2023年城市之星领跑者TOP1(哈尔滨)🌿 🌟[3] 2022年度博客之星人工智能领域TOP4🌟 🏅[4] 阿里云社区…

prometheus常用exporter

一、node-exporter node_exporter&#xff1a;用于监控Linux系统的指标采集器。 未在k8s集群内的linux机器监控 GitHub - prometheus/node_exporter: Exporter for machine metrics 常用指标&#xff1a; •CPU • 内存 • 硬盘 • 网络流量 • 文件描述符 • 系统负载 •…

电子电器架构车载软件 —— 集中化架构软件开发

电子电器架构车载软件 —— 集中化架构软件开发 我是穿拖鞋的汉子&#xff0c;魔都中坚持长期主义的汽车电子工程师。 老规矩&#xff0c;分享一段喜欢的文字&#xff0c;避免自己成为高知识低文化的工程师&#xff1a; 屏蔽力是信息过载时代一个人的特殊竞争力&#xff0c;任…

vulnhub靶场之DC-8

一.环境搭建 1.靶场描述 DC-8 is another purposely built vulnerable lab with the intent of gaining experience in the world of penetration testing. This challenge is a bit of a hybrid between being an actual challenge, and being a "proof of concept&quo…

机器学习入门知识

一、引言 机器学习是当前信息技术中最令人振奋的领域之一。在这门课程中&#xff0c;我们将探索该技术的前沿&#xff0c;并能够亲自实现机器学习的算法。 或许你每天都在不知不觉中使用了机器学习的算法。每次你打开谷歌或必应搜索你需要的内容&#xff0c;正是因为它们拥有出…

如何使用vite框架封装一个js库,并发布npm包

目录 前言介绍 一、创建一个vite项目 1、使用创建命令&#xff1a; 2、选择others 3、 模板选择library 4、选择开发语言 ​编辑 二、安装依赖 三、目录介绍 1、vite.config.ts打包文件 2、package.json命令配置文件 三、发布npm 1、注册npm账号 2、设置npm源 3、登…