【深度学习 | 计算机视觉】Focal Loss原理及其实践(含源代码)

news2024/11/18 15:37:02

参考文献:

https://www.jianshu.com/p/437ce8ed0413

文章目录

  • 一、导读
  • 二、Focal Loss 原理
  • 三、实验对比
    • 3.1 使用交叉熵损失函数
    • 3.2 使用Focal Loss 损失函数
    • 3.3 总结

一、导读

Focal Loss 是一个在交叉熵(CE)基础上改进的损失函数,来自ICCV2017的Best student paper—Focal Loss for Dense Object Detection。

Focal Loss的提出源自图像领域中目标检测任务中样本数量不平衡的问题,并且这里所谓的不平衡性跟平常理解的是有所区别的,它还强调样本的难易性。尽管Focal Loss始于目标检测场景,其实它可以应用到很多其他任务场景,只要符合它的问题背景,就可以试试,就有意想不到的效果。

二、Focal Loss 原理

在引入Focal Loss公式前,我们以源paper中目标检测的任务来说:目标检测器通常会产生高达100k的候选目标,只有极少数是正样本,正负样本数量非常不平衡。

在计算分类的时候常用的损失—交叉熵(CE)的公式如下:
C E ( p , y ) = { − log ⁡ ( p ) i f   y = 1 − log ⁡ ( 1 − p ) o t h e r w i s e CE(p,y)=\left\{ \begin{array}{rcl} -\log(p) & & {if \ y=1}\\ -\log(1-p) & & {otherwise} \end{array} \right. CE(p,y)={log(p)log(1p)if y=1otherwise

其中 y y y取值{1,-1},代表正负样本, p p p为模型预测的label概率,通常 p > 0.5 p>0.5 p>0.5就判断为正样本,否则为负样本。论文中为了方便展示,重新定义了 p t p_t pt
p t = { p i f   y = 1 1 − p o t h e r w i s e p_t=\left\{\begin{array}{rcl} p & & {if \ y=1}\\ 1-p & & {otherwise}\end{array} \right. pt={p1pif y=1otherwise
这样CE函数就可以表示为: C E ( p , y ) = C E ( p t ) = − log ⁡ ( p t ) CE(p,y)=CE(p_t)=-\log(p_t) CE(p,y)=CE(pt)=log(pt)

在CE的基础上,为了解决正负样本不平衡性,有人提出一种带权重的CE函数:
C E ( p t ) = − α t log ⁡ ( p t ) CE(p_t)=-\alpha_t\log(p_t) CE(pt)=αtlog(pt)
其中当: y = 1 , α t = α ; y = − 1 , α t = 1 − α y=1,\alpha_t=\alpha;y=-1,\alpha_t=1-\alpha y=1,αt=α;y=1,αt=1α,参数 α \alpha α为控制正负样本的权重,取值范围为[0, 1]。

尽管这是一种很简单的解决正负样本不平衡的方案,但它还没有真正达到paper中作者想解决的问题:因为正负样本中也有难易之分,认为模型应该更聚焦在难样本的学习上。如下图,按正负,难易,可将样本分为四个维度,其实上面带权重的CE函数,只是解决了正负问题,并没有解决难易问题。

img
那么怎么来衡量一个样本的难易程度,更何况真实数据也没有这个标记。其实,这里的样本难易是用模型来判断的,就正样本集合来说,如果一个样本预测的 p = 0.9 p=0.9 p=0.9,一个样本预测的 p = 0.6 p=0.6 p=0.6,明显前一个样本更容易学习,或者说特征更明显,是易样本。这样也就是说,预测的概率越接近于1或者0的样本,就越是容易学习的样本,相反,越是集中于0.5左右的样本,就是难样本。在sigmoid函数上,可以按下图的方式展示样本的难易之分。

img

怎么让模型对难易样本也有区分性的学习,也是说聚焦程度不同。模型应该花更多精力在难样本的学习上,而减少精力在易样本的学习,之前的CE函数,以及带权重的CE函数,都是将难样本、易样本等同看待的。这样就引出Focal Loss的表达形式:
F L ( p t ) = − ( 1 − p t ) γ log ⁡ ( p t ) FL(p_t)=-(1-p_t)^{\gamma}\log(p_t) FL(pt)=(1pt)γlog(pt)
其中 γ \gamma γ为调节因子,取值为[0, 5],当 γ = 0 \gamma=0 γ=0,就等同于CE函数; γ \gamma γ值越大,表示模型在难易样本上聚焦的更厉害。下图是不同参数下表现形式:

img

结合上图与公式,可以看出,当 p t p_t pt趋近于1时,权重 ( 1 − p t ) γ (1-p_t)^{\gamma} (1pt)γ趋近于0,对总损失贡献几乎没有影响,意味着模型较少对这类样本的学习;比如,在正样本集合中, γ = 2 \gamma=2 γ=2,当一样本 p t = 0.6 p_t=0.6 pt=0.6,当一样本 p t = 0.7 p_t=0.7 pt=0.7,二者相对来说,前者是难样本,后者是易样本,反映在Focal Loss上,前者的对总损失贡献权重为0.16,后者0.09,明显难样本贡献权重更大,模型也就会更将聚焦对其学习。同理,负样本中一样。

但是上面的Focal Loss公式只是体现了难易样本的区分,没有区分正负。这样就引出了完整版的Focal Loss表达形式:
F L ( p t ) = − α t ( 1 − p t ) γ log ⁡ ( p t ) FL(p_t)=-\alpha_t(1-p_t)^{\gamma}\log(p_t) FL(pt)=αt(1pt)γlog(pt)
这样Focal Loss既能调整正负样本的权重,又能控制难易分类样本的权重。paper中通过实验验证,默认 γ = 2 \gamma=2 γ=2 α t = 0.25 ( y = 1 ) \alpha_t=0.25(y=1) αt=0.25(y=1)。在这里 α t \alpha_t αt取值上可能会有疑问,理论上正样本权重更大些,取0.75,而paper实验结果给的是0.25。

自己的理解:主要原因是 γ = 2 \gamma=2 γ=2,而大部分负样本的 p < 0.1 p<0.1 p<0.1,导致负样本的贡献权重还小于正样本贡献的权重,本意是想调高正样本的贡献权重,但这样就有点调的过大了,所以 α t = 0.25 ( y = 1 ) \alpha_t=0.25(y=1) αt=0.25(y=1)就有点反过来提高下负样本的权重。所以在最终版中,不能理解 α t \alpha_t αt就是完全来调节正负样本的权重的,而是要结合 α t ( 1 − p t ) γ \alpha_t(1-p_t)^{\gamma} αt(1pt)γ一起来看。

三、实验对比

使用标准的二进制交叉熵损失函数和Focal Loss函数分别训练同一个模型的过程,主要分为以下几步:

  1. 加载MNIST数据集,仅保留数字2的样本作为正样本,其他作为负样本。
  2. 定义Focal Loss损失函数。其中alpha和gamma是超参数。
  3. 构建简单的全连接网络模型。
  4. 首先用二进制交叉熵损失函数编译模型,并训练。
  5. 然后用定义的Focal Loss函数编译模型,其他保持不变,并训练。
  6. 这样可以在同一个模型上分别观察不同损失函数的训练效果。

Focal Loss的目的是降低易分类样本的loss值,让模型更集中优化硬分类样本。

所以这段代码通过直接比较,演示了不同损失函数对模型训练效果的影响,是一个很好的实验示例。

通过输出的训练过程和结果,可以直观分析Focal Loss相比于标准交叉熵在这一任务上的效果。

3.1 使用交叉熵损失函数

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import tensorflow.keras.backend as K

# load dataset
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
X_train = X_train.reshape(60000, 784).astype('float32') / 255
X_test = X_test.reshape(10000, 784).astype('float32') / 255

y_train = np.array([1 if d == 2 else 0 for d in y_train])
y_test = np.array([1 if d == 2 else 0 for d in y_test])
#定义focal loss
def focal_loss(gamma = 2., alpha = .25):
    def focal_loss_fixed(y_true, y_pred):
        pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
        pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
        return -K.sum(alpha * K.pow(1. - pt_1, gamma) * K.log(K.epsilon() + pt_1))\
            - K.sum((1 - alpha) * K.pow( pt_0, gamma) * K.log(1. - pt_0 + K.epsilon()))
    return focal_loss_fixed

#build model
inputs = keras.Input(shape = (784,), name = 'mnist_input')
h1 = layers.Dense(64, activation = 'relu')(inputs)
outputs = layers.Dense(1, activation = 'sigmoid')(h1)
model = tf.keras.Model(inputs, outputs)

#以平方差损失函数来编译模型进行训练
model.compile(optimizer = keras.optimizers.RMSprop(),
             loss = keras.losses.BinaryCrossentropy(),
             metrics = ['accuracy'])

#training
history = model.fit(X_train, y_train, batch_size = 64, epochs = 10,
         validation_data = (X_test, y_test))

我们的训练过程为:

Epoch 1/10
938/938 [==============================] - 1s 1ms/step - loss: 0.0604 - accuracy: 0.9821 - val_loss: 0.0331 - val_accuracy: 0.9907
Epoch 2/10
938/938 [==============================] - 1s 997us/step - loss: 0.0274 - accuracy: 0.9916 - val_loss: 0.0241 - val_accuracy: 0.9935
Epoch 3/10
938/938 [==============================] - 1s 1ms/step - loss: 0.0203 - accuracy: 0.9940 - val_loss: 0.0225 - val_accuracy: 0.9937
Epoch 4/10
938/938 [==============================] - 1s 1ms/step - loss: 0.0163 - accuracy: 0.9948 - val_loss: 0.0208 - val_accuracy: 0.9942
Epoch 5/10
938/938 [==============================] - 1s 1ms/step - loss: 0.0137 - accuracy: 0.9958 - val_loss: 0.0181 - val_accuracy: 0.9949
Epoch 6/10
938/938 [==============================] - 1s 1ms/step - loss: 0.0114 - accuracy: 0.9967 - val_loss: 0.0168 - val_accuracy: 0.9949
Epoch 7/10
938/938 [==============================] - 1s 1ms/step - loss: 0.0099 - accuracy: 0.9973 - val_loss: 0.0192 - val_accuracy: 0.9942
Epoch 8/10
938/938 [==============================] - 1s 1ms/step - loss: 0.0085 - accuracy: 0.9974 - val_loss: 0.0180 - val_accuracy: 0.9945
Epoch 9/10
938/938 [==============================] - 1s 1ms/step - loss: 0.0074 - accuracy: 0.9981 - val_loss: 0.0180 - val_accuracy: 0.9942
Epoch 10/10
938/938 [==============================] - 1s 1ms/step - loss: 0.0061 - accuracy: 0.9982 - val_loss: 0.0150 - val_accuracy: 0.9954

3.2 使用Focal Loss 损失函数

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import tensorflow.keras.backend as K

# load dataset
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
X_train = X_train.reshape(60000, 784).astype('float32') / 255
X_test = X_test.reshape(10000, 784).astype('float32') / 255

y_train = np.array([1 if d == 2 else 0 for d in y_train])
y_test = np.array([1 if d == 2 else 0 for d in y_test])
#定义focal loss
def focal_loss(gamma = 2., alpha = .25):
    def focal_loss_fixed(y_true, y_pred):
        pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
        pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
        return -K.sum(alpha * K.pow(1. - pt_1, gamma) * K.log(K.epsilon() + pt_1))\
            - K.sum((1 - alpha) * K.pow( pt_0, gamma) * K.log(1. - pt_0 + K.epsilon()))
    return focal_loss_fixed

#build model
inputs = keras.Input(shape = (784,), name = 'mnist_input')
h1 = layers.Dense(64, activation = 'relu')(inputs)
outputs = layers.Dense(1, activation = 'sigmoid')(h1)
model = tf.keras.Model(inputs, outputs)

#以Focal Loss损失函数来编译模型进行训练
model.compile(optimizer = keras.optimizers.RMSprop(),
             loss = [focal_loss(alpha = .25, gamma = 2)],
             metrics = ['accuracy'])

#training
history = model.fit(X_train, y_train, batch_size = 64, epochs = 10,
         validation_data = (X_test, y_test))

我们的训练过程为:

Epoch 1/10
938/938 [==============================] - 1s 1ms/step - loss: 0.3401 - accuracy: 0.9829 - val_loss: 0.2087 - val_accuracy: 0.9884
Epoch 2/10
938/938 [==============================] - 1s 1ms/step - loss: 0.1733 - accuracy: 0.9920 - val_loss: 0.1711 - val_accuracy: 0.9925
Epoch 3/10
938/938 [==============================] - 1s 1ms/step - loss: 0.1355 - accuracy: 0.9942 - val_loss: 0.1642 - val_accuracy: 0.9941
Epoch 4/10
938/938 [==============================] - 1s 1ms/step - loss: 0.1128 - accuracy: 0.9952 - val_loss: 0.1766 - val_accuracy: 0.9904
Epoch 5/10
938/938 [==============================] - 1s 1ms/step - loss: 0.0967 - accuracy: 0.9959 - val_loss: 0.1334 - val_accuracy: 0.9948
Epoch 6/10
938/938 [==============================] - 1s 1ms/step - loss: 0.0827 - accuracy: 0.9968 - val_loss: 0.1906 - val_accuracy: 0.9952
Epoch 7/10
938/938 [==============================] - 1s 1ms/step - loss: 0.0738 - accuracy: 0.9970 - val_loss: 0.1455 - val_accuracy: 0.9949
Epoch 8/10
938/938 [==============================] - 1s 1ms/step - loss: 0.0652 - accuracy: 0.9975 - val_loss: 0.1504 - val_accuracy: 0.9946
Epoch 9/10
938/938 [==============================] - 1s 1ms/step - loss: 0.0562 - accuracy: 0.9979 - val_loss: 0.1327 - val_accuracy: 0.9943
Epoch 10/10
938/938 [==============================] - 1s 1ms/step - loss: 0.0514 - accuracy: 0.9980 - val_loss: 0.1649 - val_accuracy: 0.9943

3.3 总结

从结果可以看出,虽然在该数据集上二者提升效果并不大,但Focal Loss在每轮上都优于CE的训练效果,所以还是能体现Focal Loss的优势,如果在其他更不平衡的数据集上,应该效果更好。

不管在CV,还是NLP领域,该损失函数值得大家去尝试。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/756505.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java正则表达式MatchResult的接口、Pattern类、Matcher类

Java正则表达式MatchResult的接口 java.util.regex.MatchResult接口表示匹配操作的结果。 此接口包含用于确定与正则表达式匹配的结果的查询方法。可以看到匹配边界&#xff0c;组和组边界&#xff0c;但不能通过MatchResult进行修改。 接口声明 以下是java.util.regex.Matc…

spring复习:(34)配置文件的方式创建ProxyFactoryBean

一、配置文件 <?xml version"1.0" encoding"UTF-8"?> <beans xmlns"http://www.springframework.org/schema/beans"xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xmlns:c"http://www.springframework.org/s…

vscode 无法格式化python代码、无法格式化C++代码(vscode格式化失效)另一种解决办法:用外部工具yapf格式化(yapf工具)

文章目录 我真的解决方法&#xff1a;用yapfyapf工具使用方法示例格式化单个文件&#xff08;格式化前先用-d参数预先查看格式化更改内容&#xff0c;以决定是否要更改&#xff09;格式化某个目录递归格式化某个目录 我真的 神马情况&#xff0c;我的vscode死活不能格式化pyth…

路径规划算法:基于减法平均优化的路径规划算法- 附代码

路径规划算法&#xff1a;基于减法平均优化的路径规划算法- 附代码 文章目录 路径规划算法&#xff1a;基于减法平均优化的路径规划算法- 附代码1.算法原理1.1 环境设定1.2 约束条件1.3 适应度函数 2.算法结果3.MATLAB代码4.参考文献 摘要&#xff1a;本文主要介绍利用智能优化…

用Python自动化处理Excel表格详解

Excel表格基础知识 Excel表格可以帮助用户创建、编辑、格式化和计算数据&#xff0c;并生成各种图表和报表。Excel表格通常用于商业、金融、科学、教育等领域。 Excel表格的常用操作 Excel表格的常用操作包括插入、删除、移动、复制、粘贴、排序和筛选、图表等。这些操作可以…

node操作MySQL数据库

本文节选自我的博客&#xff1a;node 操作 MySQL 数据库 &#x1f496; 作者简介&#xff1a;大家好&#xff0c;我是MilesChen&#xff0c;偏前端的全栈开发者。&#x1f4dd; CSDN主页&#xff1a;爱吃糖的猫&#x1f525;&#x1f4e3; 我的博客&#xff1a;爱吃糖的猫&…

集群基础4——haproxy负载均衡mariadb

文章目录 一、环境说明二、安装配置mariadb三、安装配置haproxy四、验证 一、环境说明 使用haproxy对mysql多机单节点进行负载均衡。 主机IP角色安装服务192.168.161.131后端服务器1mariadb&#xff0c;3306端口192.168.161.132后端服务器2mariadb&#xff0c;3306端口192.168.…

【2023 年第二届钉钉杯大学生大数据挑战赛初赛】 初赛 A:智能手机用户监测数据分析 问题一Python代码分析

2023 年第二届钉钉杯大学生大数据挑战赛初赛 初赛 A&#xff1a;智能手机用户监测数据分析 问题一Python代码分析 1 题目 2023 年第二届钉钉杯大学生大数据挑战赛初赛题目 初赛 A&#xff1a;智能手机用户监测数据分析 一、问题背景 近年来&#xff0c;随着智能手机的产生&a…

STM32F10x外部中断/事件控制器(EXTI)应用

往期文章&#xff1a; STM32F1x固件库函数学习笔记&#xff08;一&#xff09; 文章目录 一、EXTI简介二、EXTI初始化结构体详解三、外部中断&#xff08;EXTI&#xff09;编程要点及例程参考文献 一、EXTI简介 外部中断/事件控制器&#xff0c;简称&#xff1a;EXTI&#x…

Jenkins打包、发布、部署

目录 前言 一、安装jdk 二、安装maven 三、安装git 四、安装jenkins 五、访问jenkins 六、创建用户 七、配置jenkins 八、执行 总结 前言 服务器&#xff1a;CentOS 7.9 64位 jdk&#xff1a;1.8 maven&#xff1a;3.9.1 git&#xff1a;git version 1.8.3.1 jenkins&a…

计算机中的数制与编码(二进制转换)

一、进制表示 1. 十进制表示 使用&#xff08;0&#xff0c;1&#xff0c;2&#xff0c;…&#xff0c;9&#xff09;十位数字表示&#xff0c;十进制运算时逢十进一。 2. 二进制表示 使用(0&#xff0c;1)两个数字表示&#xff0c;二进制运算时逢二进一。 3. 十六进制表示…

AIGC文生图:stable-diffusion-webui部署及使用

1 stable-diffusion-webui介绍 Stable Diffusion Web UI 是一个基于 Stable Diffusion 的基础应用&#xff0c;利用 gradio 模块搭建出交互程序&#xff0c;可以在低代码 GUI 中立即访问 Stable Diffusion Stable Diffusion 是一个画像生成 AI&#xff0c;能够模拟和重建几乎…

宝塔面板清理

查看磁盘使用情况时发现/dev/sda1满了&#xff0c;重启服务器也不行&#xff0c;瞎折腾了半天&#xff0c;才发现是宝塔的回收站占了较大的磁盘&#xff0c;于是按以下操作清理了下&#xff0c;就可以了 1、清除系统监控记录。打开宝塔面板后台&#xff0c;找到监控&#xff0c…

模拟面试2

1.说一说ArrayList的实现原理&#xff1f; ArrayList底层基于数组实现&#xff0c;内部封装了Object类型的数组&#xff0c;实现了list接口&#xff0c;通过默认构造器创建容器时&#xff0c;该数组被初始化为一个空数组&#xff0c;首次添加数据时再将其初始化为容量为10的数组…

变量生命符thread_local

thread_local是c11为线程安全引进的变量声明符。 thread_local是一个存储器指定符&#xff1a; 所谓存储器指定符&#xff0c;其作用类似命名空间&#xff0c;指定了变量名的存储期以及链接方式。同类型的关键字还有&#xff1a; static&#xff1a;静态或者线程存储期&…

2.我的第一个 JAVA 程序Helloword

对象&#xff1a;对象是类的一个实例&#xff0c;有状态和行为。例如&#xff0c;一条狗是一个对象&#xff0c;它的状态有&#xff1a;颜色、名字、品种&#xff1b;行为有&#xff1a;摇尾巴、叫、吃等。类&#xff1a;类是一个模板&#xff0c;它描述一类对象的行为和状态。…

前端videojs实现m3u8格式的直播

一、安装 npm install --save-dev video.js 二、引入 import videojs from "video.js"; import "video.js/dist/video-js.css"; 三、template 由于此处客户需要全屏至指定框大小&#xff0c;而不是全屏整个屏幕所以没用插件自带的全屏控件 隐藏自带全屏…

Unity 2DJoint 物理关节功能与总结

本文将以动图方式展示每个2D物理关节的效果&#xff0c;并解析部分重要参数的作用以及常见调配方式。 1.Distance Joint 2D&#xff08;距离关节&#xff09; 顾名思义是距离关节&#xff0c;以下为启用EnableCollision前后 关节使得两物体保持一定的距离&#xff0c;如果旋…

Apache (二十一)

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 目录 前言 一、概述 二、安装 1. yum安装 2. 编译安装 三、 目录结构 1. yum安装 2. 编译安装 四、虚拟主机头配置 1. 基本配置 2. 实现方式 五、配置文件语法检查 六、 …

MySQL [环境配置]

MySQL [环境配置] MySQL的下载sqlyog的下载 熟悉老陈的人, 都清楚我不喜欢写这些环境配置的博客 那为啥这次要写一下MySQL的环境配置呢? 因为我被这一个小小的环境配置困扰了很长时间, 淋过雨的人都想为别人撑一把伞, 我不希望我的铁汁们也被这个问题困扰 MySQL的下载 MySQL下…