使用ctcloss训练矩阵生成目标字符串

news2025/1/15 12:51:54

首先我们需要明确 c t c l o s s ctcloss ctcloss是用来做什么的。比如说我们要生成的目标字符串长度为 l l l,而这个字符串包含 k k k个字符,字符串允许的最大长度为 L L L,这里我们认为一个位置是一个时间步,就是一拍,记为 T T T
对于这个允许最大长度,需要做出一些解释,我们需要定义一个生成字符串的规则,因为训练的时候,这个标签的长度是不一样的,所以我们需要引入空格来生成字符串,那么相应的,关于空格定义以下两条规则:

  1. 空格与空格之间的字符串是可以去掉重复字母的
  2. 使用空格间隔的两个部分串不能去重,比如说这个串长成:cc cc,在运用上述两条规则之后应该变成c c

举例来说,对于目标生成串如果是 C A T CAT CAT的话,那么在时间拍为 5 5 5拍的情况下,他有这以下 28 28 28条路径可以生成 C A T CAT CAT
在这里插入图片描述
注:上述图片引自这里,博主对这篇文章加以致谢,还好有这个文章让我对 c t c l o s s ctcloss ctcloss有了初步的认识。

因此我们首先拿在手里的是一个随机矩阵为 y y y,这个矩阵的形状是 [ k , T ] [k,T] [k,T],其中 y [ i , j ] y[i,j] y[i,j]表示的是在第 j j j个时间步,该字符为 i i i的概率,而我们需要做的是训练这个 y y y矩阵,让他最终产生指定字符串的概率 p p p最大,所以我们设置 − l n ( p ) -ln(p) ln(p)为损失函数,我们的目标就是让这个损失函数最小。
那么我们应该怎么做呢?你可以枚举出这28条路径的全部概率, l i k e   t h i s like\space this like this,然后把他们相加之后求损失和。

在这里插入图片描述
但是我知道你一定不想这么做。
所以呢,我们需要使用一种更加简洁的方法来求这个概率,咋做呢?这里就要放上耳熟能详的图了?
![在这里插入图片描述](https://img-blog.csdnimg.cn/b8d163baf95f40ea8f4ad8cd78c383fd.png

通过这个图,你可以看出来,我们对字符串进行了插入空格的操作,没错,因为我一开始的时候不给他插入空格行不行,但是你必须得考虑每一步能不能取到空格,并且状态转移的时候,还要考虑从第几个空格转移过来,非常麻烦,不如直接插入空格。
没错就是动态规划,我们来分析以下状态转移方程:
首先在 t t t时刻,能取到 s s s字符,那么这个字符可以由 t − 1 t-1 t1时刻的 s s s字符转移过来,两个一样的字符消去就行了呗。与此同时,该字符是不是还可以由 s − 1 s-1 s1字符转移过来 ? ? ?就是 _ C _ A _ T _ \_C\_A\_T\_ _C_A_T_ T T T接在 _ \_ _后面的情况。我们再考虑从第 s − 2 s-2 s2个字符转移过来,这个时候第 s s s个字符和第 s − 2 s-2 s2个字符必须不能相同,否则的话就是 _ C _ A _ A _ T _ \_C\_A\_A\_T\_ _C_A_A_T_的两个 A A A越过中间的空格连在一起,这不铁定消去了。
所以状态转移使用如下的方法实现:

alpha[s, t] = alpha[s, t - 1]
if s - 1 >= 0:
   alpha[s, t] += alpha[s - 1, t - 1]
if s - 2 >= 0 and blank_label[s] != '0' and blank_label[s] != blank_label[s - 2]:
   alpha[s, t] += alpha[s - 2, t - 1]
alpha[s, t] *= y[map_dict[blank_label[s]], t]

但是你一定会问弄这个动态规划矩阵有个锤子用,我们来看看:
对于 a l p h a alpha alpha矩阵当中的任意一个元素来说,我们可以得到以下的表达式,其中 l l l是任意一个可能产生字符串的路径, π \pi π是全部路径, l t l_t lt代表这个路径上在第 t t t拍上的字符, P P P为概率。
a l p h a [ i , t ] = P ( l t ) ∑ l ∈ π ∏ t ′ = 1 t − 1 P ( l t ′ ) alpha[i,t] = P(l_t)\sum_{l\in \pi} \prod_{t'=1}^{t-1} P(l_{t'}) alpha[i,t]=P(lt)lπt=1t1P(lt)
那么我们如果拿到一个后向传播的矩阵 b e t a beta beta,是不是就能得到一个:
b e t a [ i , t ] = P ( l t ) ∑ l ∈ π ∏ t ′ = t + 1 T P ( l t ′ ) beta[i,t] = P(l_t)\sum_{l\in \pi} \prod_{t'=t+1}^{T} P(l_{t'}) beta[i,t]=P(lt)lπt=t+1TP(lt)
然后就有:
a l p h a [ i , t ] ∗ b e t a [ i , t ] = P ( l t ) ∑ l ∈ π ∏ t ′ = 1 T P ( l t ′ ) alpha[i,t]*beta[i,t] = P(l_t)\sum_{l\in \pi} \prod_{t'=1}^{T} P(l_{t'}) alpha[i,t]beta[i,t]=P(lt)lπt=1TP(lt)
所以我们非常想求的总概率 p = ∑ l ∈ π ∏ t ′ = 1 T P ( l t ′ ) p=\sum_{l\in \pi} \prod_{t'=1}^{T} P(l_{t'}) p=lπt=1TP(lt)就可以使用 a l p h a [ i , t ] ∗ b e t a [ i , t ] P ( l t ) \frac{alpha[i,t]*beta[i,t]}{P(l_t)} P(lt)alpha[i,t]beta[i,t]来表示。
公式推导鸣谢:这里
注:这只是我大概的理解,不能十分完备的使用原文章中的符号
接下来就到了非常鸡冻人心的训练过程,差点没给我训练死了。因为改了一天这个梯度公式,并且我体会了什么是梯度消失, s o f t m a x softmax softmax的作用,接下来将详细记录我训练的这个过程,应该只是我记得的了,其中非常感谢这几篇文章的帮助,尤其是在晚上八点还是没有结果的时候看到的这篇文章,但是当时通过死亡调试梯度矩阵已经反应过来是梯度问题了hh。
首先先声明一下:我不能完全保证我的梯度求解没有问题,但是的确训练出了结果,并且参考多篇博客,梯度的结果全都不一样,因此我只能找到一个我认为最合理的梯度来进行梯度下降。
那么我们再来捋一下思路:首先我们想的是要最小化 − l n ( p ) -ln(p) ln(p),而前面我们又求出 p = a l p h a [ i , t ] ∗ b e t a [ i , t ] P ( l t ) p=\frac{alpha[i,t]*beta[i,t]}{P(l_t)} p=P(lt)alpha[i,t]beta[i,t],这里我们将 P ( l t ) P(l_t) P(lt)换成 y k t y_{k}^{t} ykt来表示,就是在 t t t时刻的第 k k k个字符的概率。我们这里是想对 y y y求偏导,但是这里有一个 b u g bug bug就是这个 y y y他不一定是每一列的和都是 1 1 1,所以我们需要对他进行 s o f t m a x softmax softmax操作,因此,我们用 x k t x_{k}^{t} xkt代表经过 s o f t m a x softmax softmax之后的。困了,明天再写

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/914995.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

PyTorch三种主流模型构建方式:nn.Sequential、nn.Module子类、nn.Module容器开发实践,以真实烟雾识别场景数据为例

Keras和PyTorch是两个常用的深度学习框架,它们都提供了用于构建和训练神经网络的高级API。 Keras: Keras是一个高级神经网络API,可以在多个底层深度学习框架上运行,如TensorFlow和CNTK。以下是Keras的特点和优点: 优点&#xff…

解决git上传远程仓库时的最大文件大小限制

git默认限制最大的单文件100M,当某个文件到达50M时会给你提示。解决办法如下 首先,打开终端,进入项目所在的文件夹; 输入命令:git config http.postBuffer 524288000 执行完上面的语句后输入:git config…

Stable Diffusion 系列教程 | 图生图基础

前段时间有一个风靡全网的真人转漫画风格,受到了大家的喜欢 而在SD里,就可以通过图生图来实现类似的效果 当然图生图还有更好玩的应用,我们一点一点来探索 首先我们来简单进行一下图生图的这一个实践---真人转动漫 1. 图生图基本界面 和…

代码之美:探索可维护性的核心与实践

为什么可维护性如此重要 项目的长期健康 在软件开发的早期阶段,团队可能会对代码的可维护性不太重视,因为他们更关心的是功能的快速交付。但随着时间的推移,随着代码库的增长和复杂性的增加,不重视代码的可维护性可能会导致严重的…

docker使用安装教程

docker使用安装教程 一、docker安装及下载二、使用教程2.1 镜像2.2 容器2.3 docker安装Redis 一、docker安装及下载 一、安装 安装执行命令:curl -fsSL https://get.docker.com | bash -s docker --mirror Aliyun 二、启停常用命令 启动docker,执行命令&#xf…

分支和循环语句-C语言(初阶)

目录 一、什么是语句 二、分支语句 2.1 if语句 2.2 switch语句 三、循环语句 3.1 while循环 3.2 for循环 3.3 do...while循环 一、什么是语句 C语言语句有五类:表达式语句、函数调用语句、控制语句、复合语句、空语句。 控制语句用于控制程序的执行流程&#xff0…

在vue3+ts+vite中使用svg图片

目录 前言 步骤 1.安装svg-sprite-loader,这里使用的是6.0.11版本 2.项目的svg图片存放在src/icons下,我们在这里创建两个文件index.ts和index.vue(在哪创建和文件名字并没有任何要求) 3.在index.ts中加入下列代码(如果报错找不到fs模块请…

Redis的基本操作

文章目录 1.Redis简介2.Redis的常用数据类型3.Redis的常用命令1.字符串操作命令2.哈希操作命令3.列表操作命令4.集合操作命令5.有序集合操作命令6.通用操作命令 4.Springboot配置Redis1.导入SpringDataRedis的Maven坐标2.配置Redis的数据源3.编写配置类,创还能Redis…

ubuntu修改默认文件权限umask

最近在使用ubuntu的过程中发现一个问题: 环境是AWS EC2,登录用户ubuntu,系统默认的umask是027,修改/etc/profile文件中umask 027为022后,发现从ubuntu用户sudo su过去root用户登录查询到的umask还是027,而…

2023-8-22 单调栈

题目链接&#xff1a;单调栈 #include <iostream>using namespace std;const int N 100010;int n; int stk[N], tt;int main() {cin >> n;for(int i 0; i < n; i ){int x;cin >> x;while(tt && stk[tt] > x) tt--;if(tt) cout << st…

第十章,搜索模块

10.1添加搜索框 <template><div class="navbar-form navbar-left hidden-sm"><div class="form-group"><inputv-model.trim="value"type="text"class="form-control search-input mac-style"placeho…

数据传输过程

2 数据传输过程 了解网络中常用的分层模型后&#xff0c;现在来学习一下数据在各层之间是如何传输的。 2.1数据封装与解封装过程(一) 下面我们将以TCP/IP五层结构为基础来学习数据在网络中传输的“真相”。由于这个过程比较 抽象&#xff0c;我们可以类比给远在美国的朋友邮寄…

人工智能深度估计技术

人工智障&#xff08;能&#xff09;走起&#xff01;&#xff01;&#xff01; 下面是基本操作&#xff1a; 在Hugging Face网页中找到Depth Estimation的model&#xff0c;如下图&#xff1a; Hugging Face – The AI community building the future. &#xff08;上Huggin…

从自动驾驶到智能助理:AI和ML技术的革命性应用与前景

人工智能&#xff08;AI&#xff09;和机器学习&#xff08;ML&#xff09;的快速发展正在改变我们的世界。它们以惊人的速度渗透到各个领域&#xff0c;从自动驾驶汽车到智能助理、语音识别和自然语言处理等。AI和ML技术的应用范围和影响力越来越广泛&#xff0c;为我们的日常…

SpringMVC拦截器学习笔记

SpringMVC拦截器 拦截器知识 拦截器(Interceptor)用于对URL请求进行前置/后置过滤 Interceptor与Filter用途相似但实现方式不同 Interceptor底层就是基于Spring AOP面向切面编程实现 拦截器开发流程 Maven添加依赖包servlet-api <dependency><groupId>javax.se…

【Rust】Rust学习 第十八章模式用来匹配值的结构

模式是 Rust 中特殊的语法&#xff0c;它用来匹配类型中的结构&#xff0c;无论类型是简单还是复杂。结合使用模式和 match 表达式以及其他结构可以提供更多对程序控制流的支配权。模式由如下一些内容组合而成&#xff1a; 字面值解构的数组、枚举、结构体或者元组变量通配符占…

CSS笔记

介绍 CSS导入方式 三种方法都将文字设置成了红色 CSS选择器 元素选择器 id选择器 图中div将颜色控制为红色&#xff0c;#name将颜色控制为蓝色&#xff0c;谁控制的范围最小&#xff0c;谁就生效&#xff0c;所以第二个div是蓝色的。id属性值要唯一&#xff0c;否则报错。 clas…

【STM32RT-Thread零基础入门】 6. 线程创建应用(线程挂起与恢复)

硬件&#xff1a;STM32F103ZET6、ST-LINK、usb转串口工具、4个LED灯、1个蜂鸣器、4个1k电阻、2个按键、面包板、杜邦线 文章目录 前言一、RT-Thread相关接口函数1. 挂起线程2. 恢复线程 二、程序设计1. car_led.c2.car_led.h3. main.c 三、程序测试总结 前言 在上一个任务中&a…

Mysql group by使用示例

文章目录 1. groupby时不能查询*2. 查询出的列必须在group by的条件列中3. group by多个字段&#xff0c;这些字段都有索引也会索引失效&#xff0c;只有group by单个字段索引才能起作用4. having条件必须跟group by相关联5. 用group by做去重6. 使用聚合函数做数量统计7. havi…

ShardingSphere02-MySQL主从同步配置

1、MySQL主从同步原理 基本原理&#xff1a; slave会从master读取binlog来进行数据同步 具体步骤&#xff1a; step1&#xff1a;master将数据改变记录到二进制日志&#xff08;binary log&#xff09;中。step2&#xff1a; 当slave上执行 start slave 命令之后&#xff0c…