【图像分割】【深度学习】PFNet官方Pytorch代码-PFNet网络损失函数模块解析

news2024/9/25 14:40:09

【图像分割】【深度学习】PFNet官方Pytorch代码-PFNet网络损失函数模块解析

文章目录

  • 【图像分割】【深度学习】PFNet官方Pytorch代码-PFNet网络损失函数模块解析
  • 前言
  • PM定位模块损失函数
  • FM聚焦模块损失函数
  • 总结


前言

在详细解析PFNet代码之前,首要任务是成功运行PFNet代码【win10下参考教程】,后续学习才有意义。本博客讲解PFNet神经网络模块的损失函数模块代码,不涉及其他功能模块代码。

PFNet中有四个输出预测,一个来自定位模块(PM),三个来自聚焦模块(FM),整体的损失函数为:
ℓ o v e r a l l = ℓ p m + ∑ i = 1 3 2 ( 3 − i ) ℓ f m i {\ell _{overall}}{\rm{ }} = {\rm{ }}{\ell _{pm}} + \sum\limits_{i = 1}^3 {{2^{(3 - i)}}} \ell _{fm}^i overall=pm+i=132(3i)fmi
其中 ℓ f m i \ell _{fm}^i fmi表示在PFNet网络中至上往下第 i i i个的聚焦模块的预测的损失。

博主将各功能模块的代码在不同的博文中进行了详细的解析,点击【win10下参考教程】,博文的目录链接放在前言部分。


PM定位模块损失函数

对于PM模块,使用二值交叉熵损失(Binary CrossEntropy Loss,BCE)损失 ℓ b c e \ell _{{\rm{bce}}} bce和IoU损失 ℓ i o u \ell _{{\rm{iou}}} iou的输出,即 ℓ p m = ℓ b c e + ℓ i o u {\ell _{{\rm{pm}}}} = {\ell _{{\rm{bce}}}} + {\ell _{{\rm{iou}}}} pm=bce+iou,以引导PM探索目标对象的初始位置。
二值交叉熵损失 ℓ i o u \ell _{{\rm{iou}}} iou是常见用法,因此不再具体讲解,本小节主要介绍 ℓ i o u \ell _{{\rm{iou}}} iou,因为它不同于目标检测中用于衡量预测边界框与真实边界框之间的重叠程度,而在论文中对此并没有详细解释,因此博主根据论文源码绘制以下示意图具体讲解 ℓ i o u \ell _{{\rm{iou}}} iou的作用:

ℓ i o u = 1 − i o u {\ell _{{\rm{iou}}}} = 1 - iou iou=1iou i o u iou iou重合度越高, ℓ i o u \ell _{{\rm{iou}}} iou损失越小, i o u = i n t e r u n i o n − i n t e r iou = \frac{{{\rm{inter}}}}{{{\rm{union - inter}}}} iou=unioninterinter。那么 i n t e r inter inter u n i o n − i n t e r union - inter unioninter分别表示什么含义呢?博主将根据所绘制的示意图详细说明其中的含义,如上图所示, m a s k mask mask只有前景为1背景为0俩种值, p r e d pred pred的取值范围则在(0~1)之间,为了方便理解博主也是暴力的拆解成前景为0.8背景为0.2俩种值。

  1. i n t e r inter inter表示真实标签 m a s k mask mask和预测标签 p r e d pred pred对应像素相乘后再对像素值求和的值,如上图的inter所示(只表示到对应元素相乘), i n t e r inter inter的含义可以理解成真实标签的前景部分在预测标签上的预测结果,简单来说就是只考虑预测标签针对真实前景的预测效果,默认背景部分完全预测正确,屏蔽了背景不作考虑,因此 i n t e r = T b + P f inter=T_b+P_f inter=Tb+Pf
  2. u n i o n union union表示真实标签 m a s k mask mask和预测标签 p r e d pred pred对应像素相加后再对像素值求和的值,如上图的union所示(只表示到对应元素相加),那么 u n i o n − i n t e r union-inter unioninter的含义可以理解成真实标签的背景部分在预测标签上的预测结果,如上图的union-inter所示,简单来说就是只考虑预测标签针对真实背景的预测效果,默认前景部分完全预测正确,屏蔽了前景不作考虑,因此 u n i o n − i n t e r = T f + P b union-inter=T_f+P_b unioninter=Tf+Pb

T b T_b Tb表示背景位置真实像素求和值(也就是0), P f P_f Pf表示前景位置预测像素求和值, T f T_f Tf表示前景位置真实像素求和值, P b P_b Pb表示背景位置预测像素求和值。
注意!!!!区分背景位置预测像素和预测背景像素俩个概念!!!前者是真实背景像素位置可能真确预测为背景,也可能错误预测成前景;后者则是对预测一个像素位置为背景。

解释了 i n t e r inter inter u n i o n − i n t e r union - inter unioninter的含义, i o u iou iou也可以表示成 i o u = T b + P f T f + P p iou = \frac{{{T_b} + {P_{\rm{f}}}}}{{{T_f} + {P_p}}} iou=Tf+PpTb+Pf T b T_b Tb T f T_f Tf是固定不变的,那么 ℓ i o u \ell _{{\rm{iou}}} iou的优化目标就是 P f P_f Pf越来越大且 P b P_b Pb越来越小。
代码位置:train.py

# PM loss function
bce_loss = nn.BCEWithLogitsLoss().cuda(device_ids[0])
iou_loss = loss.IOU().cuda(device_ids[0])
def bce_iou_loss(pred, target):
    bce_out = bce_loss(pred, target)
    iou_out = iou_loss(pred, target)
    loss = bce_out + iou_out
    return loss

代码位置:loss.py

博主为了方便大家理解,小改了下源码,但是没有丝毫影响源码的原始目的。

class IOU(torch.nn.Module):
    def __init__(self):
        super(IOU, self).__init__()
    def _iou(self, pred, target):
        pred = torch.sigmoid(pred)
        # 交集区域
        inter = (pred * target).sum(dim=(2, 3))
        # 并集区域
        union = (pred + target).sum(dim=(2, 3))
        # iou损失
        iou = 1 - (inter / (union- inter))
        return iou.mean()
    def forward(self, pred, target):
        return self._iou(pred, target)

FM聚焦模块损失函数

对于FM模块,希望更多地关注对象的边界、细长区域或孔处等分散注意力区域。因此,使用加权二值交叉熵损失(Binary CrossEntropy Loss,BCE)损失 ℓ w b c e \ell _{{\rm{wbce}}} wbce和加权IoU损失 ℓ w i o u \ell _{{\rm{wiou}}} wiou的输出,即 ℓ f m = ℓ w b c e + ℓ w i o u {\ell _{{\rm{fm}}}} = {\ell _{{\rm{wbce}}}} + {\ell _{{\rm{wiou}}}} fm=wbce+wiou,以迫使FM更加关注可能的分散注意力区域。
ℓ i o u \ell _{{\rm{iou}}} iou在上个章节就进行了说明, ℓ w i o u \ell _{{\rm{wiou}}} wiou大同小异,因此不再具体讲解,本小节主要介绍 ℓ w b c e \ell _{{\rm{wbce}}} wbce ℓ w i o u \ell _{{\rm{wiou}}} wiou中的 w w w权重的产生,在论文中对此并没有详细解释,因此博主根据论文源码绘制以下示意图具体讲解 w w w的作用:

w w w权重是通过对标签 m a s k mask mask进行平均池化操作,再减去 m a s k mask mask,最后取绝对值:
w = 1 + 5 × ∣ A v g P o o l ( m a s k ) − m a s k ∣ w = 1 + 5 \times \left| {\left. {AvgPool(mask) - mask} \right|} \right. w=1+5×AvgPool(mask)mask
为什么这么简单的操作就能让 w w w更加关注可能的分散注意力区域?博主分以下几种情况讨论:

  • 第一种情况:如上图1所示位置,该前景像素位于前景目标的内部,因此不是对象的边界、细长区域或孔处等分散注意力区域,其 w w w权重计算为1,不需要对其做额外加强;
  • 第二种情况:如上图2所示位置,该前景像素是对象的边界,属于分散注意力区域,其 w w w权重计算为4.9,可谓是剧烈加强;
  • 第三种情况:如上图3所示位置,该背景像素是模糊边界,也属于分散注意力区域,其 w w w权重计算为4.3,也是剧烈加强;
  • 第四种情况:如上图4所示位置,该像素是背景,其 w w w权重计算为1,不需要对其做额外加强;

博主绘制的示意图只是为了方便理解,真实的池化核大小不可能只有3×3那么小,源码中使用的池化核大小是31×31。
代码位置:train.py

# FM loss function
structure_loss = loss.structure_loss().cuda(device_ids[0])

代码位置:loss.py

class structure_loss(torch.nn.Module):
    def __init__(self):
        super(structure_loss, self).__init__()

    def _structure_loss(self, pred, mask):
        print(pred.shape)
        # 根据mask标签生成关于mask的权重
        # 根据公式可以知道,越是靠近前景目标边缘的像素,权重可能就越高,而越靠近前景目标的中心的像素权重越低,最低为1
        weit = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
        # 因为预测标签还要进行加权,暂时需要保留结构,所以损失在每个元素上计算,reduce选择none
        wbce = F.binary_cross_entropy_with_logits(pred, mask, reduce='none')
        # 加权的bce
        wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))
        pred = torch.sigmoid(pred)
        # 交集区域
        inter = ((pred * mask) * weit).sum(dim=(2, 3))
        # 并集区域
        union = ((pred + mask) * weit).sum(dim=(2, 3))
        # 加权的iou损失
        wiou = 1 - (inter) / (union - inter)
        return (wbce + wiou).mean()
    def forward(self, pred, mask):
        return self._structure_loss(pred, mask)

总结

尽可能简单、详细的介绍PFNet网络中的损失函数模块的结构和代码。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1263150.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

面试篇spark(spark core,spark sql,spark 优化)

一:为什么学习spark? 相比较map-reduce框架,spark的框架执行效率更加高效。 mapreduce的执行框架示意图。 spark执行框架示意图 spark的执行中间结果是存储在内存当中的,而hdfs的执行中间结果是存储在hdfs中的。所以在运算的时…

深入理解强化学习——马尔可夫决策过程:贝尔曼期望方程-[举例与代码实现]

分类目录:《深入理解强化学习》总目录 在文章《深入理解强化学习——马尔可夫决策过程:贝尔曼期望方程-[基础知识]》中我们讲到了贝尔曼期望方程,本文就举一个贝尔曼期望方程的具体例子,并给出相应代码实现。 下图是一个马尔可夫…

Harmony OS4开发入门

代码地址: https://gitee.com/BruceLeeAdmin/harmonyos/tree/master 项目目录介绍 ArkTS介绍 简单案例: State times: number 0/*数据类型:stringnumberany: 不确定类型,可以是任意类型*/State msg: string "hello"…

正点原子linux应用编程——提高篇1

在之前的入门篇学习中,都是直接在Ubuntu中进行验证的,对于嵌入式Linux系统来说,也是可以直接移植的,只需要使用嵌入式硬件平台对应的交叉编译工具编译应用程序即可运行。 在嵌入式Linux系统中,编写的应用程序通常需要…

uniapp使用vue3和ts开发小程序获取用户城市定位

这个组件的功能:可以重新定位获取到用户的具体位置,这个是通过getLocation这个api和高德地图的api获取到的,getLocation这个api需要在微信公众平台后台>开发管理> 接口管理里面申请才能使用的,不然无法使用哦,这…

数学老师怎么和家长沟通

作为一名数学老师,与家长建立良好的沟通关系是非常重要的。以下是我个人认为可以帮助与家长有效沟通的建议: 建立良好的第一印象 第一次与家长接触时,要尽可能展现出你的专业素养和热情。在交流中,要表达出你对孩子的关心和重视&…

Android 编译的配置文件:android.mk 和android.bp

Android.bp文件首先是Android系统的一种编译配置文件,是用来代替原来的Android.mk文件的。在Android7.0以前,Android都是使用make来组织各模块的编译,对应的编译配置文件就是Android.mk。在Android7.0开始,Google引入了ninja和kat…

接口文档自动生成工具:详细教程和实用技巧

本篇文章详细教你如何使用 Apifox 的 IDEA 插件实现自动生成接口代码。好处简单总结有以下几点: 自动生成接口文档: 不用手写,一键点击就可以自动 生成文档,当有更新时,点击一下就可以自动同步接口文档;代…

小程序静默授权获取unionid

文章目录 导文文章重点 导文 小程序静默授权获取unionid 文章重点 用wx.login(Object object)放到app.js里面 wx.login({success (res) {console.log(123);if (res.code) {//发起网络请求// wx.request({// url: https://example.com/onLogin,// data: {// code: res.…

ERRO报错

无法下载nginx 如下解决: 查看是否有epel 源 安装epel源 安装第三方 yum -y install epel-release.noarch NGINX端口被占用 解决: 编译安装的NGINX配置文件在/usr/local/ngin/conf 修改端口

AI - Navmesh 寻路

用cocos2dx引擎简单实现了一下navmesh的多边形划分,然后基于划分多边形的a*寻路。以及路径拐点优化算法 用cocos主要是方便使用一些渲染接口和定时器。重点是实现的原理。 首先画了一个带有孔洞的多边形 //多边形的顶点数据Vec2(100, 100),Vec2(300, 200),Vec2(50…

程序员的软件开发帮手,低代码当仁不让

目录 一、低代码是什么? 二、低代码的能力表现 1.提供可视化开发 2.预构建的组件和模板 3.集成的开发和测试工具 4.跨平台兼容性 5.可伸缩性和可扩展性: 跟随互联网信息技术快速发展的脚步,各行各业都在积极拥抱数字化转型。在这个过程中&…

详解STL库—map和set

目录 一、关联式容器 二、键值对 SGI-STL中关于键值对的定义: 三、set 3.1 set的介绍 3.2 set的使用 1.set的模板参数列表​编辑 2. set的构造 3. set的迭代器 4. set的容量 5. set修改操作 6. set的使用举例 四、map 4.1map的介绍 4.2 map的使用 1…

国产操作系统-银河麒麟V10

一、介绍 银河麒麟操作系统隶属于麒麟软件,麒麟软件是专业从事国产操作系统研发和产业化的企业,面向通用和专用领域打造安全创新的国产操作系统产品和相应解决方案,旗下拥有银河麒麟、中标麒麟、星光麒麟三大产品品牌。 麒麟软件官方网站地…

【攻防世界-misc】glance-50

1.得到一个动图 2.使用GIF动态图片分解,多帧动态图分解成多张静态图片_图片工具网页版,将图片定格组合, 由此得到flag值,拼写提交。

卡码网语言基础课 | 15. 链表的基础操作Ⅲ

目录 一、 插入链表的过程 二、 删除链表的过程 三、 打印链表 3.1 判断节点是否处于链尾 3.2 打印链表 3.3 循环体结束,遍历打印 题目: 请编写一个程序,实现以下链表操作:构建一个单向链表,链表中包含一组整数…

c++没有返回值的返回值

上面的函数search没有返回值,因为a不等于1,但是输出的时候会输出6.这恰巧是x的值,如果我们希望a不等于1时返回x,那么这种结果反而是正确的.有时候这种错误的代码可能产生正确的结果反而会加大debug难度 int search(int n) { 00007FF66DB723E0 mov dword ptr [rsp8],e…

【Linux系统编程】进程概念详解(什么是进程?如何查看进程?)

目录 一、前言 二、 什么是进程? 💦引出进程 💦进程的基本概念 💦理解进程 ⭐描述进程--PCB(进程控制块) ⭐组织进程 三、查看进程 💦 通过 ps 命令查看进程 💦 通过 l…

事件代理?

1.什么是事件代理? 事件代理也叫事件委托,只指定一个事件处理程序,就可以管理某一类型得事件。 可以简单理解为,事件代理就是将本应该绑定子元素事件绑定给父元素代理。它的优点就是:减少事件得执行,减少浏…

2023/11/28JAVAweb学习

查找哪个进程占用了该端口号 跳过某一个阶段