【OCR】CTC loss原理

news2024/11/22 18:23:06

1 CTC loss出现的背景

在图像文本识别、语言识别的应用中,所面临的一个问题是神经网络输出与ground truth的长度不一致,这样一来,loss就会很难计算,举个例子来讲,如果网络的输出是”-sst-aa-tt-e’', 而其ground truth为“state”,那么像之前经常用的损失函数如cross entropy便都不能使用了,因为这些损失函数都是在网络输出与ground truth的长度一致情况下使用的。除了长度不一致的情况之外,还有一个比较难的点在于有多种情况的输出都对应着ground truth,根据解码规则(相邻的重复字符合并,去掉blank), path1: “-ss-t-a-t-e-” 和path2: "–stt-a-tt-e"都可以解码成“state”,与ground truth对应, 也就是many-to-one。为了解决以上问题,CTC loss就产生啦~
2 CTC loss原理

2.1 前序

在说明原理之前,首先要说明一下CTC计算的对象:softmax矩阵,通常我们在RNN后面会加一个softmax层,得到softmax矩阵,softmax矩阵大小是timestep*num_classes, timestep表示的是时间序列的维度,num_class表示类别的维度。

import numpy as np
ts = 12
num_classes = 26+1 #26 for the number of english character, 1 for blank
rnn_output = np.random.random((ts, 16))#16 for hidden node number
w = np.random.random((16,num_classes))
logits = np.matmul(rnn_output,w)#logits: ts*num_classes=[12,27]
#calculate softmax matrix
maxvalue = np.max(logits, axis=1, keepdims=True)
exp = np.exp(logits-maxvalue) #minus maxvalue for avoiding overflow
exp_sum = np.sum(exp, axis=1, keepdims=True)
y = softmax = exp/exp_sum #softmax:ts*num_classes=[12,27]

2.2 forward-backward计算

其实呢,整体过程可以看做是对输入的y也就是softmax做了相应的映射得到解码结果,在希望解码结果尽量正确的情况下(使用概率来衡量),对网络的参数进行梯度下降。
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

只有在timestep=7时为a的路径才会使用​

进行路径的分数计算,所以求偏导的时候只对这部分路径求取就可以啦

path1:“-ss-t-a-t-e-” 第7个timestep为a, path2: "–stt-a-tt-e"第7个timestep也为a, 以a为中点,将这两条路径分别分成两段。

path1_forward: “-ss-t-” path1_backward: “-t-e-”

path2_forward: “–stt-” path2_backward: “-tt-e”

​你也会发现 path1_forward+“a”+path2_backward也能够解码成正确的”state", 我们使用path3来表示该路径 , 同样的, path2_forward+“a”+path1_backward也可以解码成正确的“state",我们使用path4表示该路径

在下式中我们考虑​中仅仅包含path1,path2, path3, path4

在这里插入图片描述

在这里插入图片描述

其中​表示的是解码后​的长度。先看forward部分。

2.2.1 forward部分


在这里插入图片描述

这个公式计算的是所有能够解码成​的概率,

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

公式可能一下子不能理解透,举个例子好啦,先看上面的那种情况,也就是特殊情况下的递推公式:

在这里插入图片描述

def forward(y, labels):
    T,C = y.shape #T: timestep  
    L = len(labels)
    alpha=np.zeros([T,L])
    alpha[0,0]=y[0,labels[0]]
    alpha[0,1]=y[0,labels[1]]
    for t in range(1,T):
        for i in range(L):
            s= labels[i]
            a = alpha[t-1,i]
            if i-1>=0:
                a += alpha[t-1,i-1]
            if i-2>=0 and s!=0 and s!=labels[i-2]:
                a +=alpha[t-1,i-2]
            alpha[t,i]=a*y[t,s]
    return alpha
  
labels = [0, 19, 0, 20, 0, 1, 0, 20, 0, 5, 0]
alpha = forward(y,labels)

​ 就像刚刚所说,末尾带有blank和不带有blank都是正确的,“-s-t-a-t-e-“和”-s-t-a-t-e"都可以正确解码,所以

p = alpha[-1,lables[-1]]+alpha[-1,lables[-2]]

2.2.2 backward部分

forward讲清楚之后, backward快速的过一遍就好啦
在这里插入图片描述

这个公式计算的是所有能够解码成​的概率,


在这里插入图片描述

上面三个式子是说第T个timestep的解码成”blank“的概率是在这里插入图片描述

​, 解码成​中第一个字符的概率是在这里插入图片描述

​, 其他的字符的概率为0, 可以这样理解,如果路径能够解码成正确的”state", 那么第T个timestep的肯定是blank或者"e", 只有这样才能解码正确。 我们可以得到与forward相似的递推式:

在这里插入图片描述

套用上面forward的方式去理解,应该不难的~

def backward(y, labels):
    T,C = y.shape #T: timestep  
    L = len(labels)
    beta=np.zeros([T,L])
    beta[-1,-1]=y[-1,labels[-1]]
    beta[-1,-2]=y[-1,labels[-2]]
    for t in range(T-2-1,-1):
        for i in range(L):
            s= labels[i]
            b = beta[t+1,i]
            if i+1<L:
                b += beta[t+1,i+1]
            if i+2<L and s!=0 and s!=labels[i+2]:
                b +=beta[t+1,i+2]
            beta[t,i]=b*y[t,s]
    return beta
  
labels = [0, 19, 0, 20, 0, 1, 0]
beta = backward(y,labels)

2.3 梯度

求了上面的forward和backward之后,就可以求解梯度啦

根据在这里插入图片描述

可以得到
在这里插入图片描述

因为
在这里插入图片描述

所以对​求导的话, 仅有当​为类别k的那一项不为0, 其余项的偏导都为0
在这里插入图片描述

def gradient(y,labels):
    T,C = y.shape
    L = len(labels)
    alpha = forward(y,labels)
    beta = backward(y,labels)
    p = alpha[-1,-1]+alpha[-1,-2]
    gradient = np.zeros([T,V])
    for t in range(T):
        for c in range(C):
            lab = [idx for idx, item in enumerate(labels) if item == c]
            for i in lab:
                gradient[t, s] += alpha[t, i] * beta[t, i]
            gradient[t,c]/=-(y[t,c]**2)
    return gradient3

3 CTC loss优缺点

优点:在文本识别和语言识别领域中,能够比较灵活地计算损失,进行梯度下降

缺点:存在假设前提即每个lable相互独立, 因此可以计算路径的概率,才有了接下来的推导过程,但是在很多情况下上下文的label是有关联的,CTC loss很难考虑这一点,不过这些可以通过引入语言模型解码来解决啦~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/425760.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JVM:并发的可达性分析

当前主流编程语言的垃圾收集器基本上都是依靠可达性分析算法来判定对象是否存活的&#xff0c;可达性分析算法理论上要求全过程都基于一个能保障一致性的快照中才能够进行分析&#xff0c;这意味着必须全程冻结用户线程的运行。 在根节点枚举这个步骤中&#xff0c;由于 GC Ro…

0303Kruskal算法和小结-最小生成树-图-数据结构和算法(Java)

1 算法概述 定义。按照边的权重顺序&#xff08;从小到大&#xff09;&#xff0c;将边加入最小生成树中。加入的边不会与已经加入的边构成环&#xff0c;知道树中含有V-1条边为主。这些黑色的边逐渐由一片森林合并为一棵树&#xff0c;也就是最小生成树。这种计算方法被称为Kr…

Redis Lua沙盒绕过命令执行(CVE-2022-0543)

一、描述 影响范围&#xff1a;Debian系得linux发行版本Ubuntu Debian系得linux发行版本 其并非Redis本身漏洞&#xff0c;形成原因在于系统补丁加载了一些redis源码注释了的代码 揭露时间&#xff1a;2022.3.8 二、原理 redis在用户连接后可以通过eval命令执行Lua脚本&#x…

GAN网络系列博客(三):不受坐标限制的GAN(StyleGAN3)

目录 1. 概述 2. 连续信号分析 2.1 等变网络层 3. 具体实现 3.1 傅里叶特征和基础简化 3.2 根据连续插值进行的步骤重建 4.实验 总结 Reference 关于StyleGAN系列博客的收尾之作&#xff0c;StyleGAN3已经鸽了太久了。说实话&#xff0c;前两篇StyleGAN说的什么内容&#xff0…

Web 攻防之业务安全:账号安全案例总结.

Web 攻防之业务安全&#xff1a;账号安全案例总结 业务安全是指保护业务系统免受安全威胁的措施或手段。广义的业务安全应包括业务运行的软硬件平台&#xff08;操作系统、数据库&#xff0c;中间件等&#xff09;、业务系统自身&#xff08;软件或设备&#xff09;、业务所提供…

[ 应急响应篇基础 ] 日志分析工具Log Parser配合login工具使用详解(附安装教程)

&#x1f36c; 博主介绍 &#x1f468;‍&#x1f393; 博主介绍&#xff1a;大家好&#xff0c;我是 _PowerShell &#xff0c;很高兴认识大家~ ✨主攻领域&#xff1a;【渗透领域】【数据通信】 【通讯安全】 【web安全】【面试分析】 &#x1f389;点赞➕评论➕收藏 养成习…

HCIP-6.5BGP路由聚合原理及配置

HCIP-6.5BGP路由聚合原理及配置1、BGP路由聚合原理1.1、自动汇总1.2、手动聚合1.2.1、利用静态路由进行聚合1.2.2、BGP的专有工具Aggregate1、BGP路由聚合原理 在实际运行的网络中&#xff0c;BGP路由表十分庞大&#xff0c;对于设备而言是很大的负担&#xff0c;同时使发生路…

Flutter Flex(Row Column,Expanded, Stack) 组件

前言 这个Flex 继承自 MultiChildRenderObjectWidget&#xff0c;所以是多子布局组件 class Flex extends MultiChildRenderObjectWidget {} Flex 的子组件就是Row 和 Column , 之间的区别就是Flex 的 direction 设置不同。 它有两个轴&#xff0c;一个是MainAxis 还有一个是交…

ubuntu下Thrift安装

thrift是一种常用rpc框架&#xff0c;工作中经常会用到&#xff0c;本文记录一下其安装过程。 目录 1.下载软件包 1.1thrift下载 1.2libevent下载 1.3boost下载 2.安装&#xff08;注意步骤&#xff09; 2.1安装libevent 2.2安装boost 2.3安装与Python2.7版本对应的py…

Vue+element Upload利用http-request实现第三方地址图片上传

vue第三方地址图片上传后端图片上传接口开发postman测试图片上传 Vue element (el-upload)中的:http-request图片上传java后端上传接口&#xff0c;利用OSS存储图片postman测试图片上传功能及方法 Vueelement Upload利用http-request实现第三方地址图片上传vue第三方地址图片上…

《花雕学AI》ChatGPT跟人类的思考方式有什么不同?

一、ChatGPT是一个基于GPT-3.5的对话语言模型&#xff0c;它可以根据用户的输入生成多轮对话&#xff0c;也可以生成文本、代码、音乐等内容。ChatGPT的思考方式是利用大量的数据和强大的算力来学习语言的联合概率分布&#xff0c;从而能够根据上下文和目标生成合理和有趣的回复…

python开发岗位需求分析,来看看它是什么一个情况吧

前言 嗨喽~大家好呀&#xff0c;这里是魔王呐 ❤ ~! 1.导入模块 import pandas as pd from pyecharts.charts import * from pyecharts import options as opts import matplotlib.pyplot as plt plt.rcParams[font.sans-serif][SimHei] plt.rcParams[axes.unicode_minus]Fal…

WRF模式与Python融合技术在多领域中的应用及精美绘图

当今从事气象及其周边相关领域的人员&#xff0c;常会涉及气象数值模式及其数据处理&#xff0c;无论是作为业务预报的手段、还是作为科研工具&#xff0c;掌握气象数值模式与高效前后处理语言是一件非常重要的技能。WRF作为中尺度气象数值模式的佼佼者&#xff0c;模式功能齐全…

目标检测算法——YOLOv5/v7/v8改进结合即插即用的动态卷积ODConv(小目标涨点神器)

&#x1f496;&#x1f496;>>>加勒比海带&#xff0c;QQ2479200884<<<&#x1f496;&#x1f496; &#x1f340;&#x1f340;>>>【YOLO魔法搭配&论文投稿咨询】<<<&#x1f340;&#x1f340; ✨✨>>>学习交流 | 温澜潮…

JS遍历数组里数组下的对象,根据数组中对象的某些值,组合成新的数组对象

前言: 大部分后端返回给前端的数据,是json形式的。里边包含了响应码,响应信息,有些还会返回数组对象等。现在有一个业务场景,我调用明细查询接口,返回的数据是一个对象数组的形式,但是我只需要对象中的某些属性值。这个时候我就需要想办法提取我所需要的值,然后组合成一…

吐血奉献精心整理的一大波数据集

80开源数据集资源汇总&#xff08;包含目标检测、医学影像、关键点检测、工业检测等方向&#xff09; 数据集下载汇总链接&#xff1a;https://www.cvmart.net/dataSets数据集将会不断更新&#xff0c;欢迎大家持续关注&#xff01; 小目标检测 AI-TOD航空图像数据集 数据集…

一例H-worm vbs脚本分析

样本的基本信息 MD5: c750a5bb8d9aa5a58c27956b897cf102 SHA1: e14994b9e32a3e267947cac36fb3036d9d22be21 SHA256: 1eb712d4976babf7c8cd0b34015c2701bc5040420e688911e6614b278cc82a42 SHA512: b1921c1dc7e8cf075882755be48c0d9636858cd7913188cb6c9ca6e7e741f049c143a79c683…

有哪些比较不错的AI绘画网站?

AI绘画最近非常流行。目前&#xff0c;互联网上有许多主流的人工智能绘画网站&#xff0c;但你都知道吗&#xff1f; 如果您正在寻找一个基于最先进智能技术的人工智能绘画网站&#xff0c;并介绍五个智能、现实和高效于一体的人工智能绘画网站。 1.即时AI 即时AI绘画是指通…

《RockectMQ实战与原理解析》Chapter4-分布式消息队列的协调者

4.1 NameServer 的功能 NameServer 是整个消息队列中的状态服务器&#xff0c;集群的各个组件通过它来了解全局的信息 。 同时&#xff0c;各个角色的机器都要定期向 NameServer 上报自己的状态&#xff0c;超时不上报的话&#xff0c; NameServer 会认为某个机器出故障不可用了…

基于OBS超低延迟直播实测(400毫秒左右)超多组图

阿酷TONY&#xff0c;原创文章&#xff0c;长沙。 文章简述&#xff1a;本文介绍使用OBS无延迟直播插件在第三方云平台&#xff0c;如何实现超低延时直播的完整教程&#xff08;延迟约为400毫秒左右&#xff0c;通常延迟是3-15秒&#xff09;。 OBS简要介绍 OBS&#xff08;O…