OCR识别网络CRNN理解与Pytorch实现

news2024/9/22 15:35:21

CRNN是2015年的论文“An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition”提出的图像字符识别网络,也是目前工业界使用较为广泛的一个OCR网络。论文地址:https://arxiv.org/abs/1507.05717

1. 网络结构

CRNN是一个端到端可训练的网络,并且可处理任意长度的字符序列。CRNN得名于Convolutional Recurrent Neural Network,从名称即可看出,该网络包含了卷积网络和递归网络。实际上,CRNN由三部分组成,分别是卷积层部分(Convolutional layers)、递归层部分(Recurrent Layers)和转录层部分(Transcription Layers),如下图所示:

其中, 卷积层的作用是从输入图像中提取特征,递归层则对卷积层输出的feature maps进行预测,最后,转录层将递归层的预测结果翻译成文字标签序列。CNN和RNN可由同一个损失函数进行联合训练。

在图像输入CRNN之前,需要缩放到指定高度height,宽度无限制。卷积层输出的feature maps在送入RNN之前,从左到右生成一个feature vector序列,第i个feature vector为feature maps第i列的元素的级联。这样做的好处是,每个feature vector代表了原图像上一个矩形区域的特征(感受野),使得网络能够预测不同长度的字符序列。

RNN网络的优势在于,它能够有效利用序列的上下文信息进行预测,比分别预测单个字符有更好的精确度和稳定性,同时,它对输入序列的长度无限制,比单纯使用CNN网络更加灵活。

由于传统RNN存在梯度爆炸和梯度消失问题,因此,在这篇文章中,作者采用了LTSM(Long-Short Term Memory)来克服该问题。一个LSTM包含一个记忆单元(Memory Cell)和三个乘法门(Multiplicative gates),分别为输入门(input gate)、输出门(output gate)和遗忘门(forget gate),如下图所示:

由于基于图像的文字识别具有较强的前向和后向上下文信息,因此,使用双向LSTM(bidirectional LSTM)是一个合适的选择。 

转录层将RNN层的预测结果(用概率表示)映射到字符序列。 在实践中,存在两种转录模式,分别是基于词典的转录,和无词典转录。在基于词典的模式中,会选择词典中最高概率的标签进行预测;而无词典模式,预测则是在无任何词典的情况下进行的。

CRNN的具体网络结构及配置如下:

2. 代码实现 

网上找到一个CRNN的Pytorch实现,亲测好用,代码链接:CRNN Pytorch

网络定义:

import torch.nn as nn

class CRNN(nn.Module):

    def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):
        super(CRNN, self).__init__()
        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'

        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]

        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i),
                               nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        convRelu(0)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64
        convRelu(1)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
        convRelu(2, True)
        convRelu(3)
        cnn.add_module('pooling{0}'.format(2),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x16
        convRelu(4, True)
        convRelu(5)
        cnn.add_module('pooling{0}'.format(3),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x16
        convRelu(6, True)  # 512x1x16

        self.cnn = cnn
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass))

    def forward(self, input):
        # conv features
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)
        conv = conv.permute(2, 0, 1)  # [w, b, c]

        # rnn features
        output = self.rnn(conv)

        return output

 其中,Bidirectional LSTM的定义如下:

class BidirectionalLSTM(nn.Module):

    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()

        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)

        output = self.embedding(t_rec)  # [T * b, nOut]
        output = output.view(T, b, -1)

        return output

Demo是基于字典的转录方式,可以识别0~9的1-0个数字,以及a~z的26个字母。

import torch
from torch.autograd import Variable
import utils
import dataset
from PIL import Image

import models.crnn as crnn


model_path = './data/crnn.pth'  # 与训练模型路径
img_path = './data/demo.png'    # 测试图片路径
alphabet = '0123456789abcdefghijklmnopqrstuvwxyz'    # 字典

model = crnn.CRNN(32, 1, 37, 256)
if torch.cuda.is_available():
    model = model.cuda()
print('loading pretrained model from %s' % model_path)
model.load_state_dict(torch.load(model_path))

converter = utils.strLabelConverter(alphabet)    # 定义字典转录函数

transformer = dataset.resizeNormalize((100, 32))   # 图像预处理函数
image = Image.open(img_path).convert('L')
image = transformer(image)
if torch.cuda.is_available():
    image = image.cuda()
print('image size: ', image.shape)
image = image.view(1, *image.size())
image = Variable(image)

model.eval()
preds = model(image)    # CRNN预测

_, preds = preds.max(2)    # 找到最大概率所对应的index
preds = preds.transpose(1, 0).contiguous().view(-1)

preds_size = Variable(torch.IntTensor([preds.size(0)]))
raw_pred = converter.decode(preds.data, preds_size.data, raw=True)   # 逐一输出预测字符,如a-----v--a-i-l-a-bb-l-e---
sim_pred = converter.decode(preds.data, preds_size.data, raw=False)  # 输出最终预测结果,如available
print('%-20s => %-20s' % (raw_pred, sim_pred))

 demo执行结果:a-----v--a-i-l-a-bb-l-e--- => available  

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1398249.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SCI好看的配图-汇总

文章目录 图源:Sustainable Cities and Society【期刊】条形图2热力图-地图 图源:Sustainable Cities and Society【期刊】 引自:A machine learning-driven spatio-temporal vulnerability appraisal based on socio-economic data for COV…

【Vue】使用 Vuex 作为状态管理

【Vue】使用 Vuex 作为状态管理 Vuex 是一个专为 Vue.js 应用程序开发的状态管理模式和库。它使用单一状态树,这意味着这个对象包含了全部的应用层级状态,并且以一种相对集中的方式存在。这也意味着,通常单个项目中只有一个 Vuex store。Vue…

AI大模型开发架构设计(2)——AI绘画技术架构应用实践

文章目录 1 AI绘画整体流程2 AI绘画技术架构文生图核心算法原理文生图工程架构 3 AI绘画的应用实践 1 AI绘画整体流程 第一步:输入 Prompt 提示词:/mj 提示词第二步:文生图(Text-to-Image)构图第三步:图片渲染第四步:…

代码里下毒了,支付下单居然没加幂等

又是一个风和日丽没好的一天,小猫戴着耳机,安逸地听着音乐,撸着代码,这种没有会议的日子真的是巴适得板。 不料祸从天降,组长火急火燎地跑过来找到了小猫。“快排查一下,目前有A公司用户反馈积分被多扣了”…

【咕咕送书 | 第八期】羡慕同学进了大厂核心部门,看懂这本书你也能行!

🎬 鸽芷咕:个人主页 🔥 个人专栏:《linux深造日志》《粉丝福利》 ⛺️生活的理想,就是为了理想的生活! ⛳️ 写在前面参与规则 ✅参与方式:关注博主、点赞、收藏、评论,任意评论(每人最多评论…

100天精通鸿蒙从入门到跳槽——第8天:TypeScript 知识储备:泛型

博主猫头虎的技术世界 🌟 欢迎来到猫头虎的博客 — 探索技术的无限可能! 专栏链接: 🔗 精选专栏: 《面试题大全》 — 面试准备的宝典!《IDEA开发秘籍》 — 提升你的IDEA技能!《100天精通Golang》…

实战纪实 | 某配送平台zabbix 未授权访问 + 弱口令

本文由掌控安全学院 - 17828147368 投稿 找到一个某src的子站,通过信息收集插件wappalyzer,发现ZABBIX-监控系统: 使用谷歌搜索历史漏洞:zabbix漏洞 通过目录扫描扫描到后台,谷歌搜索一下有没有默认弱口令 成功进去了…

nginx配置内网代理,前端+后端分开配置

安装好后nginx,进入配置文件 我这块安装在了home里面,各位根据自身情况选择 打开nginx.conf文件 在底部查看是否包含这段信息:含义是配置文件包含该路径下的配置文件 include /home/nginx/conf/conf.d/*.conf; # 该路径根据自己的安装位置…

【从0到1学Python】第二讲:Python中的各种“量”(一)

也许你知道学习一门语言的第一件事就是在屏幕上输出"Hello world!"。 但是请别着急!在本系列文章中,我希望在讲如何输出之前,先谈谈Python中的各种量。因为,输出、输入语句也是基于各种“量”来完成的。我想&#xff0c…

基于springboot+vue的宠物领养系统(前后端分离)

博主主页:猫头鹰源码 博主简介:Java领域优质创作者、CSDN博客专家、公司架构师、全网粉丝5万、专注Java技术领域和毕业设计项目实战 主要内容:毕业设计(Javaweb项目|小程序等)、简历模板、学习资料、面试题库、技术咨询 文末联系获取 背景及意…

Three.JS教程1 环境搭建、场景与相机

Three.JS教程1 环境搭建、场景与相机 一、Three.JS简介二、环境搭建1. 开发准备2. 安装 three.js3. 新建文件index.htmlmain.js 4. 关于附加组件5. 启动 三、创建场景1. 场景的概念2. 相机的概念3. 相机的几个相关概念(1)视点(Position&#…

【机器学习】四大类监督学习_模型选择与模型原理和场景应用_第03课

监督学习中模型选择原理及场景应用 监督学习应用场景 文本分类场景: o 邮件过滤:训练模型识别垃圾邮件和非垃圾邮件。 o 情感分析:根据评论或社交媒体内容的情感倾向将其分类为正面、负面或中性评价。 o 新闻分类:将新闻文章自动…

第一篇【传奇开心果】Vant 开发移动应用:从helloworld开始

传奇开心果系列博文 博文系列目录Vant of Vue 开发移动应用示例博文目录一、从helloworld开始二、添加几个常用组件三、添加组件事件处理四、添加页面和跳转切换路由五、归纳总结知识点六、知识点示例代码 博文系列目录 Vant of Vue 开发移动应用示例 博文目录 一、从hellow…

二、简单控件

二、简单控件 #mermaid-svg-TR8KwIeb54zOjfmt {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-TR8KwIeb54zOjfmt .error-icon{fill:#552222;}#mermaid-svg-TR8KwIeb54zOjfmt .error-text{fill:#552222;stroke:#55222…

45 mount 文件系统

前言 在 linux 中常见的文件系统 有很多, 如下 基于磁盘的文件系统, ext2, ext3, ext4, xfs, btrfs, jfs, ntfs 内存文件系统, procfs, sysfs, tmpfs, squashfs, debugfs 闪存文件系统, ubifs, jffs2, yaffs 文件系统这一套体系在 linux 有一层 vfs 抽象, 用户程序不用…

1.php开发-个人博客项目文章功能显示数据库操作数据接收

(2022-day12) #知识点 1-php入门,语法,提交 2-mysql 3-HTMLcss ​ 演示案例 博客-文章阅读功能初步实现 实现功能: 前端文章导航,点入内容显示,更改ID显示不同内容 实现步骤&#xff1…

04 MyBatisPlus之逻辑删除+锁+防全表更新/删除+代码生成插件

1 逻辑删除 1. 1 什么是逻辑删除 , 以及逻辑删除和物理删除的区别? 逻辑删除,可以方便地实现对数据库记录的逻辑删除而不是物理删除。逻辑删除是指通过更改记录的状态或添加标记字段来模拟删除操作,从而保留了删除前的数据,便于后续的数据…

P1059 [NOIP2006 普及组] 明明的随机数————C++、Python

目录 [NOIP2006 普及组] 明明的随机数题目描述输入格式输出格式样例 #1样例输入 #1样例输出 #1 提示 解题思路Code——CCode——Python运行结果 [NOIP2006 普及组] 明明的随机数 题目描述 明明想在学校中请一些同学一起做一项问卷调查,为了实验的客观性&#xff0…

uniapp的IOS证书(.p12)和描述文件(.mobileprovision)申请 2024年最新教程

文章目录 准备环境登录 iOS Dev Center 下面我们从头开始学习一下如何申请开发证书、发布证书及相对应的描述文件。首先需要申请苹果 App ID (App的唯一标识)生成证书请求文件申请开发(Development)证书和描述文件申请开发(Development)证书添加调试设备…

免费200万Tokens 用科大讯飞API调用星火大模型服务

简介 自ChatGPT火了之后,国内的大模型发展如雨后春笋。其中的佼佼者之一就是科大讯飞研发的星火大模型,现在大模型已经更新到V3 版本,而且对开发者也是相当友好,注册就送200万tokens,讯飞1tokens 约等于 1.5 个中文汉字 或者 0.8 个英文单词…