语义分割性能提升---通过优化损失改进分割效果

news2024/12/23 5:12:17

本文主要总结最近的调研调试结果,介绍通过改进损失来提升语义分割的分割效果;当然还有其他途径,比如蒸馏(提升分割效果)、剪枝(提升fps),之前博客有总结,此处不做介绍。

一、Dice损失

语义分割一般绕不开Dice损失,公式如下:

DiceLoss=1- 2X∩Y+SmoothX+Y+Smooth

集合相似度度量函数,

        通常用于计算两个样本的相似度,属于metric learning。X为真实目标mask,Y为预测目标mask,我们总是希望X和Y交集尽可能大,占比尽可能大,但是loss需要逐渐变小,所以在比值前面添加负号。
        另外,该损失可以缓解样本中前景背景(面积)不平衡带来的消极影响,前景背景不平衡也就是说图像中大部分区域是不包含目标的,只有一小部分区域包含目标。Dice Loss训练更关注对前景区域的挖掘,即保证有较低的FN,但会存在损失饱和问题,而CE Loss是平等地计算每个像素点的损失。因此单独使用Dice Loss往往并不能取得较好的结果,需要进行组合使用,比如Dice Loss+CELoss或者Dice Loss+Focal Loss等。

如下为Dice_Loss的torch实现:

def Dice_loss(inputs, target, beta=1, smooth = 1e-5):
    n, c, h, w = inputs.size()
    nt, ht, wt, ct = target.size()
    if h != ht and w != wt:
        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
        
    temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c),-1)
    temp_target = target.view(n, -1, ct)

    #--------------------------------------------#
    #   计算dice loss
    #--------------------------------------------#
    tp = torch.sum(temp_target[...,:-1] * temp_inputs, axis=[0,1])
    fp = torch.sum(temp_inputs                       , axis=[0,1]) - tp
    fn = torch.sum(temp_target[...,:-1]              , axis=[0,1]) - tp

    score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
    dice_loss = 1 - torch.mean(score)
    return dice_loss

二、Tversky Loss

公式为:

T(X,Y)=X∩YX∩YX-Y+β|Y-X|

其中,|X|表示预测的分割图像,|Y|表示标签的分割图像。Tversky系数是Dice系数和 Jaccard 系数(就是IOU系数)的广义系数。

        如果分割任务更加关注召回率(高灵敏度),即真实mask尽可能都被预测出来,不太关注预测mask有没有多预测。Y为真实mask,X为预测mask。 其中α和β可以影响召回率和准确率,若想目标有较高的召回率,那么我们可以选择较高的beta;相反,若想有较高的准确率,则可选择较高的α。

        当设置α=β=0.5,此时Tversky系数就是Dice系数。而当设置α=β=1时,此时Tversky系数就是Jaccard系数。其中|X-Y|表示FP(假阳性),|Y-X|表示FN(假阴性),α,β分别控制假阴性和假阳性。通过调整 α 和 β 这两个超参数可以控制这两者之间的权衡,进而影响召回率等指标。

如下为Tiversky代码:

只用了一个参数alpha,另一个参数通过1-alpha来控制。

class tversky(nn.Module):
    def __init__(self, smooth=1):
        super(tversky, self).__init__()
        self.smooth = smooth


    def forward(self, logits, label,alpha=0.7):
        '''
        args: logits: tensor of shape (1, H, W)
        args: label: tensor of shape (1, H, W)
        '''
        probs = torch.sigmoid(logits)
        # print("logits:",probs.shape)
        # print("label:",label.shape)
        true_pos = torch.sum(label * probs)
        false_neg = torch.sum(label * (1 - probs))
        false_pos = torch.sum((1 - label) * probs)
        loss = (true_pos + self.smooth)/(true_pos + alpha*false_neg + (1-alpha)*false_pos + self.smooth)

        return 1-loss

三、boundary_loss边界损失

作者实验表明,boundary_loss结合Dice_Loss效果非常好,一个利用距离,一个利用边界。作者对这两个loss的用法是给他们一个权重,训练初期dice loss很高,随着训练进行,Boundary loss比例增加,也就是说越到训练后期越关注边界的准确,边界处理得更细一些。

大致原理为将distance map当做权重来作为某类loss的权重系数训练,详细公式解释请看以下链接:

https://zhuanlan.zhihu.com/p/72783363

源码中,核心部分的distance map变换代码如下:

def one_hot2dist(seg):
    res = np.zeros_like(seg)
    for i in range(len(seg)):
        posmask = seg[i].astype(np.bool)
        if posmask.any():
            negmask = ~posmask
            # print('negmask:', negmask)
            # print('distance(negmask):', distance(negmask))
            res[i] = distance(negmask) * negmask - (distance(posmask) - 1) * posmask
            # print('res[c]', res[c])
    return res

假设某一类别的mask如下:

则得到的distance map为:

可以看到,边界处的权重为0,mask内部为负值,背景区域离边界越远权重值越大。

根据上面的distance map的可视化结果看出,如果预测边界小于或者完全符合真实边界并被真实边界包围,这时候loss为负。根据实测,一般训练到最后,Boundary loss会为负值。

如下为我调通的boundary_loss的GitHub链接:

active-boundary-loss/abl.py at main · wangchi95/active-boundary-loss · GitHubOfficial repository for Active Boundary Loss for Semantic Segmentation. - active-boundary-loss/abl.py at main · wangchi95/active-boundary-lossicon-default.png?t=O83Ahttps://github.com/wangchi95/active-boundary-loss/blob/main/abl.py切记:传入的pred和gt形状,需满足如下要求:

        Input:
            - pred: the output from model (before softmax)
                    shape (N, C, H, W)
            - gt: ground truth map
                    shape (N, H, w)

若不满足,需要进行转换。

如下为我自己的调用代码:

            Boundary = True
            if Boundary:
                
                # 我的输入为:outputs.shape=(n,c,h,w);labels.shape=(n,h,w,c)
                n, c, h, w = outputs.size()
                nt, ht, wt, ct = labels.size()

                # 进行维度转换
                temp_labels = torch.flatten(labels[...,:-1].transpose(2, 3).transpose(1, 2), 0, 1)  # labels.shape: n h w 4 --> n*3 h w
                # print('temp_labels.shape:',temp_labels.shape) 
                
                 # 调用
                abl = ABL()
                main_boundaryloss = abl(outputs,temp_labels)
                
                # loss为原损失(dice_loss),在此处也可以加一个系数,来控制dice_loss和边界损失的占比
                loss = loss + main_boundaryloss

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2167921.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数字化AI直播系统领创者:赋能无人直播新动力,永久告别假AI!

数字化AI直播系统领创者:赋能无人直播新动力,永久告别假AI! 在数字化浪潮汹涌的今天,AI技术正以前所未有的速度渗透并重塑着各行各业,而直播行业作为数字经济的重要组成部分,更是迎来了AI技术深度融合的崭…

C++类和对象第一关

一:类的定义 (一)类的定义 (1)类的定义格式: class name{ // 类成员变量 // 类方法(函数) }; class是定义类的关键字,name为定义的类的名字,后面的花括号…

【高中数学/对数函数/零点】已知函数f(x)=1/x-log(2,x),在下列区间中,包含f(x)零点的区间是?

【题目】 已知函数f(x)1/x-log(2,x),在下列区间中,包含f(x)零点的区间是? A.(0,1) B.(2,3) C.(3,∞) D.(1,2) 【出处】 《高考数学极致解题大招》P136 第二题 中原教研工作室编著 【解答】 零点即01/x-log(2,x),移项得1/xlog(2,x) 两曲线y1/x…

【C++习题】2.双指针_移动零

文章目录 题目链接:题目描述:解法(快排的思想:数组划分区间 - 数组分两块):C 算法代码:图解 题目链接: 283.移动零 题目描述: 解法(快排的思想:数…

聚铭下一代智慧安全运营中心荣获CNNVD兼容性资质证书

近日,聚铭网络旗下安全产品——聚铭下一代智慧安全运营中心正式通过了国家信息安全漏洞库(CNNVD)兼容性认证测试,荣获国家信息安全漏洞库兼容性资质证书。 关于CNNVD兼容性 国家信息安全漏洞库(CNNVD)是…

Iceberg 基本操作和快速入门

安装 Iceberg 是一种适用于大型分析表的高性能工具,通过spark启动并运行iceberg,文章是通过docker来进行安装并测试的 新建一个docker-compose.yml文件 文件内容 version: "3" services: spark-iceberg: image: tabulario/spark-iceberg co…

干部画像如何精准科学识别优秀干部

干部画像作为一种精准、科学的评价工具,在识别优秀干部方面发挥着关键作用。通过全面、深入、系统地收集、整理和分析干部的多维度信息,形成一幅反映干部综合素质和能力的立体画卷,为组织部门提供了详实可靠的依据。以下是干部画像在精准、科…

Colorful/七彩虹将星X15 AT22 2022 Win11原厂OEM系统 带COLORFUL一键还原

安装完毕自带原厂驱动和预装软件以及一键恢复功能,自动重建COLORFUL RECOVERY功能,恢复到新机开箱状态。 【格式】:iso 【系统类型】:Windows11 原厂系统下载网址:http://www.bioxt.cn 注意:安装系统会…

Exception in thread “main“ java.lang.CloneNotSupportedException 解决方案

目录 前言: 解决方案 后言: 结言: 前言: 今天在学习设计模式的时候,犯的一个错误。很低级的错误,不过也记录一下(绝对不是想水文章)。 解决方案 在使用克隆方法时抛出这个异…

2024年第五届电力工程国际会议(ICPE 2024)将在上海召开!

为了总结交流我国电力研究技术的最新研究成果,促进国内外电力技术发展与交流,开拓电力技术应用领域,将于2024年12月13-15日在 中国上海举办第五届电力工程国际会议 (ICPE 2024) 。 本次会议由IEEE、PES、上海电力大学主办,电子科技…

猜想的反例:DFS中结点顺序与后代关系的分析

猜想的反例:DFS中结点顺序与后代关系的分析 猜想分析与反例构造反例描述伪代码与C代码实现反例验证在图论中,深度优先搜索(DFS)是一种重要的图遍历算法,它可以生成一棵深度优先森林(DFS Forest),揭示结点之间的祖先-后代关系。本文探讨一个特定猜想:如果有向图G包含一…

Linux服务器安装Anaconda环境

Linux浪潮云服务器安装Anaconda环境 读研之后在导师的帮助下,获得了浪潮的一台公共云服务器。以后做实验跑代码就可以使用云服务器上的虚拟环境了。减少了自己笔记本的压力。在创建并保存完成镜像环境之后。最重要的就是安装好深度学习需要的Anaconda环境&#xff0…

vue-i18n在使用$t时提示类型错误

1. 问题描述 Vue3项目中,使用vue-i18n,在模版中使用$t时,页面可以正常渲染,但是类型报错。 相关依赖版本如下: "dependencies": {"vue": "^3.4.29","vue-i18n": "^9.1…

MES管理系统的工单管理功能模块有什么用

在当今制造业的快速发展中,企业对于生产流程的高效管理与优化需求日益迫切。MES管理系统作为集成了生产计划、物料追踪、工艺执行、设备监控以及质量管理等核心功能的综合性软件平台,正逐步成为企业转型升级的关键驱动力。MES管理系统不仅通过实时数据洞…

鸿蒙界面开发——组件(9):进度条Progress 滑动条Slider

进度条 (Progress) Progress(options: {value: number, total?: number, type?: ProgressType})其中,value用于设置初始进度值,total用于设置进度总长度,type用于设置Progress样式。 Progress有5种可选类型,通过ProgressType可…

必应bing搜索广告如何开户?投放需要多少钱?

网络营销已成为企业增长不可或缺的一部分,为了帮助企业更高效地触达目标客户,云衔科技携手必应Bing搜索引擎,提供专业、便捷的广告开户与代运营服务。无论您是希望扩大品牌影响力,还是提升产品销量,选择云衔科技&#…

阻塞型IO与非阻塞型IO

阻塞IO与非阻塞IO 一.IO模型 IO的本质是基于操作系统接口来控制底层的硬件之间数据传输,并且在操作系统中实现了多种不同的IO方式(模型),比较常见的有下列三种 阻塞型IO模型非阻塞型IO模型多路复用IO模型(重点!重点!重点!) 二…

KVM 安装 Windows11

在 KVM 安装 Windows 比安装 Ubuntu 会复杂一些,去微软官网下载 Win11,同时要下载 Virtio 可以从 Fedora 下载 (https://fedorapeople.org/groups/virt/virtio-win/direct-downloads/)。 安装Window 命令行输入以下命令&#xf…

uniapp数据缓存

利用uniapp做开发时,缓存数据是及其重要的,下面是同步缓存和异步缓存的使用 同步缓存 在执行同步缓存时会阻塞其他代码的执行 ① uni.setStorageSync(key, data) 设置缓存,如: uni.setStorageSync(name, 张三) ② uni.getSt…

Python 课程21-Django

前言 在当今互联网时代,Web开发已成为一项必备技能。而Python作为一门简洁、高效的编程语言,其Web框架Django以其强大的功能和快速开发的特点,受到了广大开发者的青睐。如果你想深入学习Django,构建自己的Web应用,那么…