图像分割项目中损失函数的选择

news2025/1/18 13:49:56

文章目录

  • 前言
  • 场景:实际项目中,通常会有一个常见的问题:样本不均衡
  • 一、focal loss
    • 思考
  • 二、Dice loss
  • 三、二分类
  • 未完待续

前言

在图像分割领域,最基础、最常见的损失当然是交叉熵损失 —— Cross entropy。随着不断的研究,涌现出了许多优于交叉熵损失的,并且在实际场景中,也往往不会在单单使用交叉熵损失了。

场景:实际项目中,通常会有一个常见的问题:样本不均衡

一、focal loss

focal loss从样本难易分类角度出发,解决样本非平衡带来的模型训练问题。
  通常情况下,样本不均衡所带来的问题是少样本难以区分(当然也会存在一些本身就很难区分或分割的样本),因此focal loss聚焦于难分样本,在梯度求导时,让难分类样本占主导,因此训练学习过程更加聚焦在难分样本。

思考

   focal loss在训练过程中本身是一个动态选择,并不稳定,这也是为什么有些情形下使用focal loss还不如原本的CE loss。通常来说,为了防止难易样本的频繁变化,应当选取小的学习率

代码如下(示例):

class FocalLoss(nn.Module):
    """
    copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py
    This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in
    'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)'
        Focal_Loss= -1*alpha*(1-pt)*log(pt)
    :param num_class:
    :param alpha: (tensor) 3D or 4D the scalar factor for this criterion
    :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more
                    focus on hard misclassified example
    :param smooth: (float,double) smooth value when cross entropy
    :param balance_index: (int) balance class index, should be specific when alpha is float
    :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
    """

    def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-1, size_average=True):
        super(FocalLoss, self).__init__()
        self.apply_nonlin = apply_nonlin
        self.alpha = alpha
        self.gamma = gamma
        self.balance_index = balance_index
        self.smooth = smooth
        self.size_average = size_average

        if self.smooth is not None:
            if self.smooth < 0 or self.smooth > 1.0:
                raise ValueError('smooth value should be in [0,1]')

    def forward(self, logit, target):
        N=logit.shape[1]
        self.alpha = enet_weighing(target, N).cuda()

        logit = F.softmax(logit, dim=1)
        if self.apply_nonlin is not None:
            logit = self.apply_nonlin(logit)
        num_class = logit.shape[1]
        if logit.dim() > 2:
            # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
            logit = logit.view(logit.size(0), logit.size(1), -1)
            logit = logit.permute(0, 2, 1).contiguous()
            logit = logit.view(-1, logit.size(-1))
        target = torch.squeeze(target, 1)
        target = target.view(-1, 1)
        # print(logit.shape, target.shape)
        #
        alpha = self.alpha

        if alpha is None:
            alpha = torch.ones(num_class, 1)
        elif isinstance(alpha, (list, np.ndarray)):
            assert len(alpha) == num_class
            alpha = torch.FloatTensor(alpha).view(num_class, 1)
            alpha = alpha / alpha.sum()
        elif isinstance(alpha, float):
            alpha = torch.ones(num_class, 1)
            alpha = alpha * (1 - self.alpha)
            alpha[self.balance_index] = self.alpha

        # else:
        #     raise TypeError('Not support alpha type')

        if alpha.device != logit.device:
            alpha = alpha.to(logit.device)

        idx = target.cpu().long()

        one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()
        one_hot_key = one_hot_key.scatter_(1, idx, 1)
        if one_hot_key.device != logit.device:
            one_hot_key = one_hot_key.to(logit.device)

        if self.smooth:
            one_hot_key = torch.clamp(
                one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth)
        pt = (one_hot_key * logit).sum(1) + self.smooth
        logpt = pt.log()

        gamma = self.gamma

        alpha = alpha[idx]
        alpha = torch.squeeze(alpha)
        loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt

        if self.size_average:
            loss = loss.mean()
        else:
            loss = loss.sum()
        return loss

# 训练过程
focal = FocalLoss()
FocalLoss1 = focal(out, label) # out:模型输出  label:标签

二、Dice loss

Dice loss适用于样本极度不平衡的情况,一般情况下使用Dice Loss会对反向传播不利,使得训练不稳定。因为,通常是将Dice loss作为辅助损失函数来和主损失函数一起训练,如Dice loss+CE loss 或 Dice loss + Focal loss

代码如下(示例):

import torch
from torch import Tensor
import torch.nn.functional as F

def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
    # Average of Dice coefficient for all batches, or for a single mask
    assert input.size() == target.size()
    if input.dim() == 2 and reduce_batch_first:
        raise ValueError(f'Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})')

    if input.dim() == 2 or reduce_batch_first:
        inter = torch.dot(input.reshape(-1), target.reshape(-1))
        sets_sum = torch.sum(input) + torch.sum(target)
        if sets_sum.item() == 0:
            sets_sum = 2 * inter

        return (2 * inter + epsilon) / (sets_sum + epsilon)
    else:
        # compute and average metric for each batch element
        dice = 0
        for i in range(input.shape[0]):
            dice += dice_coeff(input[i, ...], target[i, ...])
        return dice / input.shape[0]


def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
    # Average of Dice coefficient for all classes
    assert input.size() == target.size()
    dice = 0
    for channel in range(input.shape[1]):
        dice += dice_coeff(input[:, channel, ...], target[:, channel, ...], reduce_batch_first, epsilon)

    return dice / input.shape[1]


def dice_loss(input: Tensor, target: Tensor, multiclass: bool = True):
    # Dice loss (objective to minimize) between 0 and 1
    assert input.size() == target.size()
    fn = multiclass_dice_coeff if multiclass else dice_coeff
    return 1 - fn(input, target, reduce_batch_first=True)

# 训练过程
lossp = dice_loss(F.softmax(out, dim=1).float(),
                 F.one_hot(lb, n_classes).permute(0, 3,1,2).contiguous().float(),  multiclass=True)

三、二分类

图像分割二分类任务一般有两种方式:
(1)和多分类任务一样,只是最后的输出通道num_class设置为2,所以输出的是一个二通道图。二分类标签label是一个单通道图,数值只有0和1两者。为了让模型的输出图不断逼近于abel,会让输出图先经过一个softmax函数,使其数值归一化到(0,1)之间,即让同一位置上两个通道的值加起来等于1。而对于label,会使用onehot编码,转换成了 num_class=2 个通道的图像。然后就可以让输出图和label进行对应的损失计算了。大致流程如下图所示:
在这里插入图片描述
注:

1)二分类任务,经过softmax后,是同一位置的两个通道值之和为1,若是多分类任务,也就是多个通道之和为1。

2)二分类label经过one-hot编码,0变为[0,1],1变为[1,0];若是多分类任务,假设为4分类,那label图里就是 [0,1,2,3] 这四个像素值。则one-hot编码如下:
0 —— 【0,0,0,1】
1 —— 【0,0,1,0】
2 —— 【0,1,0,0】
3 —— 【1,0,0,0】

3)对于CrossEntropyLoss和FocalLoss,其函数内部自带有处理方式,所以无需改动,直接将输出图和label传进去即可,如上面代码:

focal = FocalLoss()
FocalLoss1 = focal(out, label) # out:模型输出  label:标签

loss = torch.nn.CrossEntropyLoss()
loss = loss(out, label)

对于Dice loss,需要自己改动输入方式,如上面代码:

lossp = dice_loss(F.softmax(out, dim=1).float(),
 F.one_hot(lb, n_classes).permute(0, 3, 1, 2).contiguous().float(), multiclass=True)

(2)第二种方式,是显著性目标检测任务中常用的,只输出单通道,即num_class=1。这时是使用sigmoid函数来对输出图进行归一化到(0,1)之间,由于输出图和label都是单通道图,所以可以直接计算损失。可以参考显著性目标检测论文中常用的损失函数:BCE + IOU (BCE关注像素,IOU关注整体结构,两者一起用其实相当于 CE+Dice)

注:使用torch.nn.BCELoss(),需要自己对输出图使用sigmoid处理;若使用BCEWithLogitsLoss(),其函数内部有sigmoid处理,就不需要自己加了。

未完待续

持续记录以后项目中用到的损失函数

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1187357.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

一个java文件的JVM之旅

准备 我是小C同学编写得一个java文件&#xff0c;如何实现我的功能呢&#xff1f;需要去JVM(Java Virtual Machine)这个地方旅行。 变身 我高高兴兴的来到JVM&#xff0c;想要开始JVM之旅&#xff0c;它确说&#xff1a;“现在的我还不能进去&#xff0c;需要做一次转换&#x…

OceanBase 如何通过日志观测冻结转储流程?

本文旨在通过日志解析 OceanBase 的冻结转储流程&#xff0c;以其冻结检查线程为切入点&#xff0c;以租户&#xff08;1002&#xff09;的线程名为例。 作者&#xff1a;陈慧明&#xff0c;爱可生测试工程师&#xff0c;主要参与 DMP 和 DBLE 自动化测试项目。 爱可生开源社区…

2258. 逃离火灾 : 详解如何从「二分」到「分类讨论」(图解过程)

题目描述 这是 LeetCode 上的 「2258. 逃离火灾」 &#xff0c;难度为 「困难」。 Tag : 「多源 BFS」、「二分」、「预处理」 给你一个下标从 0 开始大小为 m x n 的二维整数数组 grid&#xff0c;它表示一个网格图。 每个格子为下面 个值之一&#xff1a; 0 表示草地。 1 表…

【狂神说Java】SpringSecurity+shiro

✅作者简介&#xff1a;CSDN内容合伙人、信息安全专业在校大学生&#x1f3c6; &#x1f525;系列专栏 &#xff1a;【狂神说Java】 &#x1f4c3;新人博主 &#xff1a;欢迎点赞收藏关注&#xff0c;会回访&#xff01; &#x1f4ac;舞台再大&#xff0c;你不上台&#xff0c…

ShuffleNet系列 网络结构

文章目录 ShuffleNet V1Channel Shuffle&#xff1a;通道打散SuffleNet UnitModel Architecture实验结果 ShuffleNet V2Guideline 1Guideline 2Guideline 3Guideline 4模型结构代码 论文&#xff1a;ShuffleNet: ShuffleNet: An Extremely Efficient Convolutional Neural Netw…

2023年的低代码:数字化、人工智能、趋势及未来展望

本文由葡萄城技术团队发布。转载请注明出处&#xff1a;葡萄城官网&#xff0c;葡萄城为开发者提供专业的开发工具、解决方案和服务&#xff0c;赋能开发者。 前言 正如许多专家预测的那样&#xff0c;低代码平台在2023年将展现更加强劲的势头。越来越多的企业正在纷纷转向低代…

ArcGIS 气象风场等示例 数据制作、服务发布及前端加载

1. 原始数据为多维数据 以nc数据为例。 首先在pro中需要以多维数据的方式去添加多维数据&#xff0c;这里的数据包含uv方向&#xff1a; 加载进pro的效果&#xff1a; 这里注意 数据属性需要为矢量uv&#xff1a; 如果要发布为服务&#xff0c;需要导出存储为tif格式&…

spring 中 @Validated/@Valid

超级好的链接 添加链接描述

Vue实现面经基础版案例(路由+组件缓存)

一、面经基础版-案例效果分析 1.面经效果演示 2.功能分析 通过演示效果发现&#xff0c;主要的功能页面有两个&#xff0c;一个是列表页&#xff0c;一个是详情页&#xff0c;并且在列表页点击时可以跳转到详情页底部导航可以来回切换&#xff0c;并且切换时&#xff0c;只有…

AI:69-基于深度学习的音乐推荐

🚀 本文选自专栏:AI领域专栏 从基础到实践,深入了解算法、案例和最新趋势。无论你是初学者还是经验丰富的数据科学家,通过案例和项目实践,掌握核心概念和实用技能。每篇案例都包含代码实例,详细讲解供大家学习。 📌📌📌在这个漫长的过程,中途遇到了不少问题,但是…

JavaScript脚本操作CSS

脚本化CSS就是使用JavaScript脚本操作CSS&#xff0c;配合HTML5、Ajax、jQuery等技术&#xff0c;可以设计出细腻、逼真的页面特效和交互行为&#xff0c;提升用户体验&#xff0c;如网页对象的显示/隐藏、定位、变形、运动等动态样式。 1、CSS脚本化基础 CSS样式有两种形式&…

OpenCV 在ImShow窗体上选择感兴趣的区域

窗体上选择感兴趣ROI区域 在计算机视觉处理中, 通常是针对图像中的一个特定区域进行处理, 有时候这个特定区域需要人来选择, OpenCV 也提供了窗口选择ROI机制. 窗体支持两种选择ROI区域的方法, 一个是单选, 一个是多选, 操作方法如下: 单选: 通过鼠标在屏幕上选择区域, 然后通过…

【Linux系统编程十六】:(基础IO3)--用户级缓冲区

【Linux系统编程十六】&#xff1a;基础IO3--用户级缓冲区 一.用户级缓冲区二.缓冲区刷新策略1.验证&#xff1a; 三.缓冲区意义 一.用户级缓冲区 我们首先理解上面的代码&#xff0c;分别使用printf和fprintf&#xff0c;fwrite往1号文件描述符里输出&#xff0c;也就是往显示…

论文阅读——InternImage(cvpr2023)

arxiv&#xff1a;https://arxiv.org/abs/2211.05778 github&#xff1a;https://github.com/OpenGVLab/InternImage 一、介绍 大部分大模型都是基于transformer的&#xff0c;本文是一个基于CNN的视觉基础模型。使用可变性卷积deformable convolution作为核心操作&…

docker复制镜像文件

一、复制镜像 #1. 查找本机已有的镜像docker images |grep xxxx#2. 将镜像复制出来指向到xxxx.tar的文件中 docker save 343cca04e31d > xxxx.tareg: 二、加载镜像 直接将拷贝好的镜像包直接加载即可 docker load < myimage.tar

【C++】一文简练总结【多态】及其底层原理&具体应用(21)

前言 大家好吖&#xff0c;欢迎来到 YY 滴C系列 &#xff0c;热烈欢迎&#xff01; 本章主要内容面向接触过C的老铁 主要内容含&#xff1a; 欢迎订阅 YY滴C专栏&#xff01;更多干货持续更新&#xff01;以下是传送门&#xff01; 目录 一.多态的概念二.多态的实现1&#xff…

Codeforces Round 908 (Div. 2)视频详解

Educational Codeforces Round 157 &#xff08;A--D&#xff09;视频详解 视频链接A题代码B题代码C题代码D题代码 视频链接 Codeforces Round 908 (Div. 2)视频详解 A题代码 #include<bits/stdc.h> #define endl \n #define deb(x) cout << #x << "…

mac M2 anaconda 解决装不了python3.7

今天发现一个很奇怪的问题 但是我一换成 conda create -n DCA python3.8.12就是成功的 这个就很奇怪, 解决如下 https://towardsdatascience.com/how-to-manage-conda-environments-on-an-apple-silicon-m1-mac-1e29cb3bad12 998 conda search pythonconda search python …

C++之函数中实现类与调用总结(二百五十四)

简介&#xff1a; CSDN博客专家&#xff0c;专注Android/Linux系统&#xff0c;分享多mic语音方案、音视频、编解码等技术&#xff0c;与大家一起成长&#xff01; 优质专栏&#xff1a;Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 人生格言&#xff1a; 人生…