深度学习基础入门篇[五]:交叉熵损失函数、MSE、CTC损失适用于字识别语音等序列问题、Balanced L1 Loss适用于目标检测

news2025/1/16 6:51:02

1.交叉熵损失函数

在物理学中,“熵”被用来表示热力学系统所呈现的无序程度。香农将这一概念引入信息论领域,提出了“信息熵”概念,通过对数函数来测量信息的不确定性。交叉熵(cross entropy)是信息论中的重要概念,主要用来度量两个概率分布间的差异。假定 p和 q是数据 x的两个概率分布,通过 q来表示 p的交叉熵可如下计算:

H ( p , q ) = − ∑ x p ( x ) log ⁡ q ( x ) H\left(p,q\right)=-\sum\limits_{x}p\left(x\right)\log q\left(x\right) H(p,q)=xp(x)logq(x)

交叉熵刻画了两个概率分布之间的距离,旨在描绘通过概率分布 q来表达概率分布 p的困难程度。根据公式不难理解,交叉熵越小,两个概率分布 p和 q越接近。

这里仍然以三类分类问题为例,假设数据 x属于类别 1。记数据x的类别分布概率为 y,显然 y=(1,0,0)代表数据 x的实际类别分布概率。记 y ^ \hat{y} y^代表模型预测所得类别分布概率。那么对于数据 x而言,其实际类别分布概率 y和模型预测类别分布概率 y ^ \hat{y} y^的交叉熵损失函数定义为:

c r o s s e n t r y y = − y × log ⁡ ( y ^ ) cross entryy=-y\times\log(\hat{y}) crossentryy=y×log(y^)

很显然,一个良好的神经网络要尽量保证对于每一个输入数据,神经网络所预测类别分布概率与实际类别分布概率之间的差距越小越好,即交叉熵越小越好。于是,可将交叉熵作为损失函数来训练神经网络。

图1 三类分类问题中输入x的交叉熵损失示意图(x 属于第一类)

在上面的例子中,假设所预测中间值 (z1,z2,z3)经过 Softmax映射后所得结果为 (0.34,0.46,0.20)。由于已知输入数据 x属于第一类,显然这个输出不理想而需要对模型参数进行优化。如果选择交叉熵损失函数来优化模型,则 (z1,z2,z3)这一层的偏导值为 (0.34−1,0.46,0.20)=(−0.66,0.46,0.20)。

可以看出, S o f t m a x Softmax Softmax和交叉熵损失函数相互结合,为偏导计算带来了极大便利。偏导计算使得损失误差从输出端向输入端传递,来对模型参数进行优化。在这里,交叉熵与Softmax函数结合在一起,因此也叫 S o f t m a x Softmax Softmax损失(Softmax with cross-entropy loss)。

2.均方差损失(Mean Square Error,MSE)

均方误差损失又称为二次损失、L2损失,常用于回归预测任务中。均方误差函数通过计算预测值和实际值之间距离(即误差)的平方来衡量模型优劣。即预测值和真实值越接近,两者的均方差就越小。

计算方式:假设有 n个训练数据 x i x_i xi,每个训练数据 x i x_i xi 的真实输出为 y i y_i yi,模型对 x i x_i xi的预测值为 y i ^ \hat{y_i} yi^。该模型在 n 个训练数据下所产生的均方误差损失可定义如下:

M S E = 1 n ∑ i = 1 n ( y i − y ^ i ) 2 MSE=\dfrac{1}{n}\sum\limits_{i=1}^n\left(y_i-\hat{y}_i\right)^2 MSE=n1i=1n(yiy^i)2

假设真实目标值为100,预测值在-10000到10000之间,我们绘制MSE函数曲线如 图1 所示。可以看到,当预测值越接近100时,MSE损失值越小。MSE损失的范围为0到∞。

3.CTC损失

3.1 CTC算法算法背景-----文字识别语音等序列问题

CTC 算法主要用来解决神经网络中标签和预测值无法对齐的情况通常用于文字识别以及语音等序列学习领域。举例来说,在语音识别任务中,我们希望语音片段可以与对应的文本内容一一对应,这样才能方便我们后续的模型训练。但是对齐音频与文本是一件很困难的事,如 图1 所示,每个人的语速都不同,有人说话快,有人说话慢,我们很难按照时序信息将语音序列切分成一个个的字符片段。而手动对齐音频与字符又是一件非常耗时耗力的任务

图1 语音识别任务中音频与文本无法对齐

在文本识别领域,由于字符间隔、图像变形等问题,相同的字符也会得到不同的预测结果,所以同样会会遇到标签和预测值无法对齐的情况。如 图2 所示。

图2 不同表现形式的相同字符示意图

总结来说,假设我们有个输入(如字幅图片或音频信号)X ,对应的输出是 Y,在序列学习领域,通常会碰到如下难点:

  • X和 Y都是变长的;

  • X和 Y的长度比也是变化的;

  • X和 Y相应的元素之间无法严格对齐。

3.2 算法概述

引入CTC主要就是要解决上述问题。这里以文本识别算法CRNN为例,分析CTC的计算方式及作用。CRNN中,整体流程如 图3 所示。

图3 CRNN整体流程

CRNN中,首先使用CNN提取图片特征,特征图的维度为 m × T m×T m×T,特征图 x可以定义为:

x = ( x 1 , x 2 , . . . , x T ) x=(x^1,x^2,...,x^T)\quad\text{} x=(x1,x2,...,xT)

然后,将特征图的每一列作为一个时间片送入LSTM中。令 t为代表时间维度的值,且满足 1 < t < T 1<t<T 1<t<T,则每个时间片可以表示为:

x t = ( x 1 t , x 2 t , … , x m t ) x^t=(x_1^t,x_2^t,\ldots,x_m^t) xt=(x1t,x2t,,xmt)

经过LSTM的计算后,使用softmax获取概率矩阵 y,定义为:

y = ( y 1 , y 2 , … , y T ) y=(y^1,y^2,\ldots,y^T) y=(y1,y2,,yT)

经过LSTM的计算后,使用softmax获取概率矩阵 y t y^t yt,定义为:

y t = ( y 1 t , y 2 t , … , y n t ) y^t=(y_1^t,y_2^t,\ldots,y_n^t) yt=(y1t,y2t,,ynt)

n为字符字典的长度,由于 y i t y_i^t yit是概率,所以 Σ i y i t = 1 \Sigma_i y_i^t=1 Σiyit=1 。对每一列 y t y^t yt求 argmax(),就可以获取每个类别的概率。

考虑到文本区域中字符之间存在间隔,也就是有的位置是没有字符的,所以这里定义分隔符 −来表示当前列的对应位置在图像中没有出现字符。用 L L L代表原始的字符字典,则此时新的字符字典 L ′ L′ L为:

L ′ = L ∪ { − } L'=L\cup\{-\} L=L{}

此时,就回到了我们上文提到的问题上了,由于字符间隔、图像变形等问题,相同的字符可能会得到不同的预测结果。在CTC算法中,定义了 B变换来解决这个问题。 B变换简单来说就是将模型的预测结果去掉分割符以及重复字符(如果同个字符连续出现,则表示只有1个字符,如果中间有分割符,则表示该字符出现多次),使得不同表现形式的相同字符得到统一的结果。如 图4 所示。

这里举几个简单的例子便于理解,这里令T为10:

B ( − s − t − a a t i v e ) = s t a t e B ( s s − t − a − t − e ) = s t a t e B ( s s t t − a a t − e ) = s t a t e \begin{array}{c}B(-s-t-aative)=state\\ \\ B(ss-t-a-t-e)=state\\ \\ B(sstt-aat-e)=state\end{array} B(staative)=stateB(sstate)=stateB(ssttaate)=state

对于字符中间有分隔符的重复字符则不进行合并:

B ( − s − t − t s t a t e ) = s t a t e B(-s-t-t state)=state B(sttstate)=state

当获得LSTM输出后,进行 B变换就可以得到最终结果。由于 B变换并不是一对一的映射,例如上边的3个不同的字符都可以变换为state,所以在LSTM的输入为 x的前提下,CTC的输出为 l的概率应该为:

p ( l ∣ x ) = Σ π ∈ B − 1 ( l ) p ( π ∣ x ) p(l|x)=\Sigma_{\pi\in B^{-1}(l)}p(\pi|x) p(lx)=ΣπB1(l)p(πx)

其中, p i pi pi为LSTM的输出向量, π ∈ B − 1 ( l ) \pi\in B^{-1}(l) πB1(l)代表所有能通过 B变换得到 l的 p i pi pi的集合。

而对于任意一个 π,又有:

p ( π ∣ x ) = Π t = 1 T y π t t p(\pi|x)=\Pi_{t=1}^T y_{\pi_t}^t p(πx)=Πt=1Tyπtt

其中, y π t t y_{\pi_t}^t yπtt代表 t时刻 π为对应值的概率,这里举一个例子进行说明:

π = − s − t − a a t t t e y π t t = y − 1 ∗ y s 2 ∗ y − 3 ∗ y t 4 ∗ y − 5 ∗ y a 6 ∗ y a 7 ∗ y t 8 ∗ y t 9 ∗ y e 1 0 \begin{array}{c}\pi=-s-t-aattte\\ y_{\pi_t}^t=y_-^1*y_s^2*y_-^3*y_t^4*y_-^5*y_a^6*y_a^7*y_t^8*y_t^9*y_e^10\\ \end{array} π=staattteyπtt=y1ys2y3yt4y5ya6ya7yt8yt9ye10

不难理解,使用CTC进行模型训练,本质上就是希望调整参数,使得 p ( π ∣ x ) p(\pi\text{}|x) p(πx) 取最大。

具体的参数调整方法,可以阅读以下论文进行了解:Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks

4.平衡 L1损失(Balanced L1 Loss)—目标检测

目标检测(object detection)的损失函数可以看做是一个多任务的损失函数,分为分类损失和检测框回归损失:

L p , u , t u , v = L c l s ( p , u ) + λ [ u ≥ 1 ] L l o c ( t u , v ) L_{p,u,tu,v}=L_{cls}(p,u)+\lambda[u\geq1]L_{loc}(t^u,v) Lp,u,tu,v=Lcls(p,u)+λ[u1]Lloc(tu,v)

L c l s L_cls Lcls表示分类损失函数、 L l o c L_loc Lloc表示检测框回归损失函数。在分类损失函数中,p表示预测值,u表示真实值。 t u t_u tu表示类别u的位置回归结果,v是位置回归目标。λ用于调整多任务损失权重。定义损失大于等于1.0的样本为outliers(困难样本,hard samples),剩余样本为inliers(简单样本,easy sample)。

平衡上述损失的一个常用方法就是调整两个任务损失的权重,然而,回归目标是没有边界的,直接增加检测框回归损失的权重将使得模型对outliers更加敏感,这些hard samples产生过大的梯度,不利于训练。inliers相比outliers对整体的梯度贡献度较低,相比hard sample,平均每个easy sample对梯度的贡献为hard sample的30%,基于上述分析,提出了balanced L1 Loss(Lb)。

Balanced L1 Loss受Smooth L1损失的启发,Smooth L1损失通过设置一个拐点来分类inliers与outliers,并对outliers通过一个 m a x ( p , 1.0 ) max(p,1.0) max(p,1.0)进行梯度截断。相比smooth l1 loss,Balanced l1 loss能显著提升inliers点的梯度,进而使这些准确的点能够在训练中扮演更重要的角色。设置一个拐点区分outliers和inliers,对于那些outliers,将梯度固定为1,如下图所示:

Balanced L1 Loss的核心思想是提升关键的回归梯度(来自inliers准确样本的梯度),进而平衡包含的样本及任务。从而可以在分类、整体定位及精确定位中实现更平衡的训练,Balanced L1 Loss的检测框回归损失如下:

L l o c = ∑ i ∈ x , y , w , h L b ( t i u − v i ) L_{loc}=\sum\limits_{i\in x,y,w,h}L_b(t_i^u-v_i) Lloc=ix,y,w,hLb(tiuvi)

其相应的梯度公示如下:

∂ L l o c ∂ w ∝ ∂ L b ∂ t i u ∝ ∂ L b ∂ x \dfrac{\partial L_{loc}}{\partial w}\propto\dfrac{\partial L_b}{\partial t_i^u}\propto\dfrac{\partial L_b}{\partial x} wLloctiuLbxLb

基于上述公式,设计了一种推广的梯度公式为:

∂ L b ∂ x = { α l n ( b ∣ x ∣ + 1 ) , i f ∣ x ∣ < 1 γ , o t h e r w i s e \dfrac{\partial L_b}{\partial x}=\begin{cases}\alpha ln(b|x|+1),if|x|<1\\ \gamma,otherwise\end{cases} xLb={αln(bx+1),ifx<1γ,otherwise

其中, α α α控制着inliers梯度的提升;一个较小的α会提升inliers的梯度同时不影响outliers的值。 γ γ γ来调整回归误差的上界,能够使得不同任务间更加平衡。α,γ从样本和任务层面控制平衡,通过调整这两个参数,从而达到更加平衡的训练。Balanced L1 Loss公式如下:

L b ( x ) = { a b ( b ∣ x ∣ + 1 ) l n ( b ∣ x ∣ + 1 ) − α ∣ x ∣ , i f ∣ x ∣ < 1 γ ∣ x ∣ + C , o t h e r w i s e L_b(x)=\begin{cases}\frac ab(b|x|+1)ln(b|x|+1)-\alpha|x|,if|x|<1\\ \gamma|x|+C,otherwise\end{cases} Lb(x)={ba(bx+1)ln(bx+1)αx,ifx<1γx+C,otherwise

其中参数满足下述条件:

α l n ( b ∣ x ∣ + 1 ) = γ \alpha ln(b|x|+1)=\gamma\quad\text{} αln(bx+1)=γ

默认参数设置:α = 0.5,γ=1.5

Libra R-CNN: Towards Balanced Learning for Object Detection

|x|<1\ \gamma|x|+C,otherwise\end{cases}$

其中参数满足下述条件:

α l n ( b ∣ x ∣ + 1 ) = γ \alpha ln(b|x|+1)=\gamma\quad\text{} αln(bx+1)=γ

默认参数设置:α = 0.5,γ=1.5

Libra R-CNN: Towards Balanced Learning for Object Detection

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/428533.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ITIL社群的内容及作用

官方网站 www.itilzj.com 文档资料: wenku.itilzj.com ITIL是全球范围内最为流行的IT服务管理框架之一&#xff0c;它能够帮助企业提高IT服务质量&#xff0c;提升业务价值。无论你是IT行业的从业者还是对ITIL感兴趣的人士&#xff0c;ITIL之家社群都将为你提供有价值的知识和经…

非关系型数据库---Redis安装与基本使用

一、数据库类型 关系数据库管理系统(RDBMS)非关系数据库管理系统(NoSQL) 按照预先设置的组织机构&#xff0c;将数据存储在物理介质上(即&#xff1a;硬盘上) 数据之间可以做无关联操作 (例如: 多表查询&#xff0c;嵌套查询&#xff0c;外键等) 主流的RDBMS软件&#xff1a;My…

java io流 概念 详解

IO流 当需要把内存中的数据存储到持久化设备上这个动作称为输出&#xff08;写&#xff09;Output操作。 当把持久设备上的数据读取到内存中的这个动作称为输入&#xff08;读&#xff09;Input操作。 因此我们把这种输入和输出动作称为IO操作。 学习目标: 一、文件类&#x…

张程伟:从开源项目到企业级数据库,云和恩墨 MogDB Uqbar 的技术探索与实践...

导语4月8日下午&#xff0c;为期两天的第十二届数据技术嘉年华&#xff08;DTC 2023&#xff09;在北京新云南皇冠假日酒店圆满落下帷幕。大会以“开源融合数字化——引领数据技术发展&#xff0c;释放数据要素价值”为主题&#xff0c;汇聚产学研各界精英到场交流。作为大会的…

电蚊拍欧盟CE认证EMC+LVD测试

电蚊拍&#xff08;Mosquito&#xff09;&#xff0c;主要由高频振荡电路、三倍压整流电路和高压电击网DW三部分组成。工作中&#xff0c;经升压电路在双层电网间产生1850V直流左右的高压电&#xff0c;两电网间的静电场有较强的吸附力&#xff0c;当蚊蝇等害虫接近电网时&…

系统集成路由器OSPF动态、综合路由配置

实验任务&#xff1a;动态路由协议RIP、OSPF协议的内容和特点动态路由RIP、OSPF实验&#xff0c;建立拓扑pc1>>R1>>R2>>R3>>pc2&#xff0c;使pc1与pc2能相互通信&#xff0c;并配置PC端静默接口。熟悉配置vlan间路由技术&#xff1a;多层交换机虚拟接…

落地“旅游+”数字赋能:实现智慧旅游协同创新发展

经济的蓬勃发展&#xff0c;与之带来的是消费水平的不断提升&#xff0c;旅行已经成为我们日常生活中不可缺少的一项。在过去三年间&#xff0c;我们由于或这或那的原因&#xff0c;并无法真正实现一场说走就走的旅程。大家在过去的三年算是憋狠了&#xff0c;所以在今年&#…

计算专题(小计算题)

考点&#xff1a; 1.沟通渠道的总量为 n*(n-1)/2&#xff0c;其中 n 代表干系人的数量。 2.决策树计算/自制和外购决策-----EMV。 3.盈亏平衡计算。&#xff08;刚好不亏也不赚&#xff09; 【案例】假设某IT服务企业&#xff0c;其固定成本为30万元&#xff0c;每项服务的变…

用于测试FDIA在现实约束下可行性的FDIA建模框架(Matlab代码实现)

目录 &#x1f4a5;1 概述 &#x1f4da;2 运行结果 &#x1f389;3 参考文献 &#x1f468;‍&#x1f4bb;4 Matlab代码 &#x1f4a5;1 概述 信息通信技术的发展和智能设备的引入使电力系统逐渐演变为电力信息物理系统&#xff0c;而信息层与物理层之间的深度耦合也加剧…

HashMap死循环详解

目录 一、数据插入原理 二、导致死循环的原因 三、解决方案 一、数据插入原理 由于JDK1.7中&#xff0c;HashMap的底层存储结构采用的是数组链表的方法 插入数据时候采用的是头插法 二、导致死循环的原因 此时线程T1&#xff0c;T2节点同时指向A节点&#xff0c;同时线程T1…

Spring Boot 整合 Swagger 教程详解

✅作者简介&#xff1a;2022年博客新星 第八。热爱国学的Java后端开发者&#xff0c;修心和技术同步精进。 &#x1f34e;个人主页&#xff1a;Java Fans的博客 &#x1f34a;个人信条&#xff1a;不迁怒&#xff0c;不贰过。小知识&#xff0c;大智慧。 &#x1f49e;当前专栏…

小白必看,吐血整理Facebook新手指南(二)

上篇文章咱们讲了关于FB广告的类型&#xff0c;今天咱们再来详细讲下如何设置FB广告、注意事项以及如何借助强大的工具&#xff08;SaleSmartly、ss客服&#xff09;监控广告效果、承接广告流量。话不多说&#xff0c;直接上干货选择你的目标 首先&#xff0c;前往您的广告管理…

虚拟化服务器和普通服务器的区别

随着云计算技术的快速普及&#xff0c;虚拟化技术作为其中的一项核心技术&#xff0c;也越来越受到了企业和个人用户的关注。虚拟化服务器相较于传统的物理服务器&#xff0c;具备更高的灵活性和可扩展性&#xff0c;但同时也存在一些不足之处。那么虚拟化服务器的优缺点有哪些…

[STM32F103C8T6]基于stm32的循迹,跟随,避障智能小车

目录 1.小车驱动主要是通过L9110S模块来驱动电机 motor.c 2.我们可以加入串口控制电机驱动(重写串口接收回调函数&#xff0c;和重定向printf) Uart.c main.c 3.点动功能 uart.c main.c 为什么使用的是HAL_Delay()要设置滴答定时器的中断优先级呢&#xff1f; 4.小车…

如何在 Mac上运行 Windows程序?

在Mac 上运行 Windows的工具 在 Mac 上运行 Windows-无需重启即可在您的 Intel 或 Apple M 系列 Mac 上运行 Windows的工具来了,非常强悍和使用,有需要的朋友可以参考一下。 主要功能 运行快速、操作简单、功能强大的应用程序,无需重启即可在您的 Intel 或 Apple M 系列 M…

基于 VITA57.1 的 2 路 125MSPS AD 采集、2 路 250MSPS DA 回放 FMC 子卡模块

板卡概述 FMC150_V30 是一款基于 VITA57.1 规范的 2 路 125MSPS 采样率 16 位分辨率 AD 采集、2 路 250MSPS 采样率 16 位分辨率 DA 回放 FMC 子卡模块。该模块遵循 VITA57.1 规范&#xff0c;可直接与符合 VITA57.1 规范的 FPGA 载卡配合使用&#xff0c;板卡 ADC 器件采用 AD…

接口自动化两大神器:正则提取器和jsonpath提取器

一、前言 在开展接口测试的过程中&#xff0c;我们会发现很多接口需要依赖前面的接口&#xff0c;需要我们动态从前面的接口返回中提取数据&#xff0c;也就是我们通常说的关联。 关联通俗来讲就是把上一次请求的返回内容中的部分截取出来保存为参数&#xff0c;用来传递给下…

迅为龙芯2K0500全国产开发板

目录 龙芯2K0500处理器 动态电源管理 低功耗技术 产品开发更快捷 全国产设计方案 2K0500核心板 邮票孔连接 丰富接口 高扩展性 系统全开源 品质保障 行业应用 龙芯2K0500处理器 迅为iTOP-LS2K0500开发采用龙芯LS2K0500处理器&#xff0c;基于龙芯自主指令系统&#x…

托福听力专项 // Unit1 Listening for Main Ideas //共5篇conversations

目录 I a history class II a student & a librarian III a student & a professor IV a student & a bookstore clerk I a history class its definition II a student & a librarian (1) The librarian was happy to help and explained to the studen…

软件工程part02-软件需求与需求规约

文章目录课程简介考试大纲软件需求与需求规约2.0 可行性分析2.1 需求概述需求分类2.2 需求工程步骤2.3 需求获取2.4 需求规约2.4.1 逻辑模型和物理模型2.4.2 需求分析过程示意2.4.3 结构化分析模型2.4.4 E-R图是数据建模的基础2.4.5 数据流图2.4.5.3 数据流命名规则2.4.5.6 DFD…