paddle 进行数字识别 (使用ocr数据集)

news2024/11/15 19:30:33

要点:

  • 喵了个喵,没使用 OCR
  • 参考文档: PaddleOCR数字仪表识别——2.数据合成及数据集制作_数字仪表数据集
  • https://blog.csdn.net/castlehe/category_10459202.html?spm=1001.2014.3001.5482

  • 最佳参考: 基于PaddleOCR的数字显示器字符识别 - 飞桨AI Studio

参考文档: 通过OCR实现验证码识别


通过OCR实现验证码识别

本篇将介绍如何通过飞桨实现简单的CRNN+CTC自定义数据集OCR识别模型,数据集采用CaptchaDataset中OCR部分的9453张图像,其中前8453张图像在本案例中作为训练集,后1000张则作为测试集。
在更复杂的场景中推荐使用PaddleOCR产出工业级模型,模型轻量且精度大幅提升。
同样也可以在PaddleHub中快速使用PaddleOCR。

一、环境配置

本教程基于PaddlePaddle 2.3.0 编写,如果你的环境不是本版本,请先参考官网安装 PaddlePaddle 2.3.0 。

import paddle
print(paddle.__version__)

二、自定义数据集读取器

常见的开发任务中,我们并不一定会拿到标准的数据格式,好在我们可以通过自定义Reader的形式来随心所欲读取自己想要数据。

设计合理的Reader往往可以带来更好的性能,我们可以将读取标签文件列表、制作图像文件列表等必要操作在__init__特殊方法中实现。这样就可以在实例化Reader时装入内存,避免使用时频繁读取导致增加额外开销。同样我们可以在__getitem__特殊方法中实现如图像增强、归一化等个性操作,完成数据读取后即可释放该部分内存。
需要我们注意的是,如果不能保证自己数据十分纯净,可以通过tryexpect来捕获异常并指出该数据的位置。当然也可以制定一个策略,使其在发生数据读取异常后依旧可以正常进行训练。

2.1 数据展示

点此快速获取本节数据集,待数据集下载完毕后可使用!unzip OCR_Dataset.zip -d data/命令或熟悉的解压软件进行解压,待数据准备工作完成后修改本文“训练准备”中的DATA_PATH = 解压后数据集路径

# 下载数据集 
!wget -O OCR_Dataset.zip https://bj.bcebos.com/v1/ai-studio-online/c91f50ef72de43b090298a38281e9c59a2d741eadd334f1cba7c710c5496e342?responseContentDisposition=attachment%3B%20filename%3DOCR_Dataset.zip&authorization=bce-auth-v1%2F0ef6765c1e494918bc0d4c3ca3e5c6d1%2F2020-10-27T09%3A50%3A21Z%2F-1%2F%2Fddc4aebed803af6c57dac46abba42d207961b78e7bc81744e8388395979b66fa
# 解压数据集
!unzip OCR_Dataset.zip -d data/
import os

import PIL.Image as Image
import numpy as np
from paddle.io import Dataset

# 图片信息配置 - 通道数、高度、宽度
IMAGE_SHAPE_C = 3
IMAGE_SHAPE_H = 30
IMAGE_SHAPE_W = 70
# 数据集图片中标签长度最大值设置 - 因图片中均为4个字符,故该处填写为4即可
LABEL_MAX_LEN = 4


class Reader(Dataset):
    def __init__(self, data_path: str, is_val: bool = False):
        """
        数据读取Reader
        :param data_path: Dataset路径
        :param is_val: 是否为验证集
        """
        super().__init__()
        self.data_path = data_path
        # 读取Label字典
        with open(os.path.join(self.data_path, "label_dict.txt"), "r", encoding="utf-8") as f:
            self.info = eval(f.read())
        # 获取文件名列表
        self.img_paths = [img_name for img_name in self.info]
        # 将数据集后1024张图片设置为验证集,当is_val为真时img_path切换为后1024张
        self.img_paths = self.img_paths[-1024:] if is_val else self.img_paths[:-1024]

    def __getitem__(self, index):
        # 获取第index个文件的文件名以及其所在路径
        file_name = self.img_paths[index]
        file_path = os.path.join(self.data_path, file_name)
        # 捕获异常 - 在发生异常时终止训练
        try:
            # 使用Pillow来读取图像数据
            img = Image.open(file_path)
            # 转为Numpy的array格式并整体除以255进行归一化
            img = np.array(img, dtype="float32").reshape((IMAGE_SHAPE_C, IMAGE_SHAPE_H, IMAGE_SHAPE_W)) / 255
        except Exception as e:
            raise Exception(file_name + "\t文件打开失败,请检查路径是否准确以及图像文件完整性,报错信息如下:\n" + str(e))
        # 读取该图像文件对应的Label字符串,并进行处理
        label = self.info[file_name]
        label = list(label)
        # 将label转化为Numpy的array格式
        label = np.array(label, dtype="int32")

        return img, label

    def __len__(self):
        # 返回每个Epoch中图片数量
        return len(self.img_paths)

三、模型配置

3.1 定义模型结构以及模型输入

模型方面使用的简单的CRNN-CTC结构,输入形为CHW的图像在经过CNN->Flatten->Linear->RNN->Linear后输出图像中每个位置所对应的字符概率。考虑到CTC解码器在面对图像中元素数量不一、相邻元素重复时会存在无法正确对齐等情况,故额外添加一个类别代表“分隔符”进行改善。

网络部分,因本篇采用数据集较为简单且图像尺寸较小并不适合较深层次网络。若在对尺寸较大的图像进行模型构建,可以考虑使用更深层次网络/注意力机制来完成。当然也可以通过目标检测形式先检出文本位置,然后进行OCR部分模型构建。 

import paddle

# 分类数量设置 - 因数据集中共包含0~9共10种数字+分隔符,所以是11分类任务
CLASSIFY_NUM = 11

# 定义输入层,shape中第0维使用-1则可以在预测时自由调节batch size
input_define = paddle.static.InputSpec(shape=[-1, IMAGE_SHAPE_C, IMAGE_SHAPE_H, IMAGE_SHAPE_W],
                                   dtype="float32",
                                   name="img")

# 定义网络结构
class Net(paddle.nn.Layer):
    def __init__(self, is_infer: bool = False):
        super().__init__()
        self.is_infer = is_infer

        # 定义一层3x3卷积+BatchNorm
        self.conv1 = paddle.nn.Conv2D(in_channels=IMAGE_SHAPE_C,
                                  out_channels=32,
                                  kernel_size=3)
        self.bn1 = paddle.nn.BatchNorm2D(32)
        # 定义一层步长为2的3x3卷积进行下采样+BatchNorm
        self.conv2 = paddle.nn.Conv2D(in_channels=32,
                                  out_channels=64,
                                  kernel_size=3,
                                  stride=2)
        self.bn2 = paddle.nn.BatchNorm2D(64)
        # 定义一层1x1卷积压缩通道数,输出通道数设置为比LABEL_MAX_LEN稍大的定值可获取更优效果,当然也可设置为LABEL_MAX_LEN
        self.conv3 = paddle.nn.Conv2D(in_channels=64,
                                  out_channels=LABEL_MAX_LEN + 4,
                                  kernel_size=1)
        # 定义全连接层,压缩并提取特征(可选)
        self.linear = paddle.nn.Linear(in_features=429,
                                   out_features=128)
        # 定义RNN层来更好提取序列特征,此处为双向LSTM输出为2 x hidden_size,可尝试换成GRU等RNN结构
        self.lstm = paddle.nn.LSTM(input_size=128,
                               hidden_size=64,
                               direction="bidirectional")
        # 定义输出层,输出大小为分类数
        self.linear2 = paddle.nn.Linear(in_features=64 * 2,
                                    out_features=CLASSIFY_NUM)

    def forward(self, ipt):
        # 卷积 + ReLU + BN
        x = self.conv1(ipt)
        x = paddle.nn.functional.relu(x)
        x = self.bn1(x)
        # 卷积 + ReLU + BN
        x = self.conv2(x)
        x = paddle.nn.functional.relu(x)
        x = self.bn2(x)
        # 卷积 + ReLU
        x = self.conv3(x)
        x = paddle.nn.functional.relu(x)
        # 将3维特征转换为2维特征 - 此处可以使用reshape代替
        x = paddle.tensor.flatten(x, 2)
        # 全连接 + ReLU
        x = self.linear(x)
        x = paddle.nn.functional.relu(x)
        # 双向LSTM - [0]代表取双向结果,[1][0]代表forward结果,[1][1]代表backward结果,详细说明可在官方文档中搜索'LSTM'
        x = self.lstm(x)[0]
        # 输出层 - Shape = (Batch Size, Max label len, Signal) 
        x = self.linear2(x)

        # 在计算损失时ctc-loss会自动进行softmax,所以在预测模式中需额外做softmax获取标签概率
        if self.is_infer:
            # 输出层 - Shape = (Batch Size, Max label len, Prob) 
            x = paddle.nn.functional.softmax(x)
            # 转换为标签
            x = paddle.argmax(x, axis=-1)
        return x

四、训练准备

4.1 定义label输入以及超参数

监督训练需要定义label,预测则不需要该步骤。

# 数据集路径设置
DATA_PATH = "./data/OCR_Dataset"
# 训练轮数
EPOCH = 10
# 每批次数据大小
BATCH_SIZE = 16

label_define = paddle.static.InputSpec(shape=[-1, LABEL_MAX_LEN],
                                    dtype="int32",
                                    name="label")

4.2 定义CTC Loss

了解CTC解码器效果后,我们需要在训练中让模型尽可能接近这种类型输出形式,那么我们需要定义一个CTC Loss来计算模型损失。不必担心,在飞桨框架中内置了多种Loss,无需手动复现即可完成损失计算。

使用文档:CTCLoss

class CTCLoss(paddle.nn.Layer):
    def __init__(self):
        """
        定义CTCLoss
        """
        super().__init__()

    def forward(self, ipt, label):
        input_lengths = paddle.full(shape=[BATCH_SIZE],fill_value=LABEL_MAX_LEN + 4,dtype= "int64")
        label_lengths = paddle.full(shape=[BATCH_SIZE],fill_value=LABEL_MAX_LEN,dtype= "int64")
        # 按文档要求进行转换dim顺序
        ipt = paddle.tensor.transpose(ipt, [1, 0, 2])
        # 计算loss
        loss = paddle.nn.functional.ctc_loss(ipt, label, input_lengths, label_lengths, blank=10)
        return loss

4.3 实例化模型并配置优化策略

# 实例化模型
model = paddle.Model(Net(), inputs=input_define, labels=label_define)
# 定义优化器
optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters())

# 为模型配置运行环境并设置该优化策略
model.prepare(optimizer=optimizer,
                loss=CTCLoss())

五、开始训练

# 执行训练
model.fit(train_data=Reader(DATA_PATH),
            eval_data=Reader(DATA_PATH, is_val=True),
            batch_size=BATCH_SIZE,
            epochs=EPOCH,
            save_dir="output/",
            save_freq=1,
            verbose=1,
            drop_last=True)

六、预测前准备

6.1 像定义训练Reader一样定义预测Reader

# 与训练近似,但不包含Label
class InferReader(Dataset):
    def __init__(self, dir_path=None, img_path=None):
        """
        数据读取Reader(预测)
        :param dir_path: 预测对应文件夹(二选一)
        :param img_path: 预测单张图片(二选一)
        """
        super().__init__()
        if dir_path:
            # 获取文件夹中所有图片路径
            self.img_names = [i for i in os.listdir(dir_path) if os.path.splitext(i)[1] == ".jpg"]
            self.img_paths = [os.path.join(dir_path, i) for i in self.img_names]
        elif img_path:
            self.img_names = [os.path.split(img_path)[1]]
            self.img_paths = [img_path]
        else:
            raise Exception("请指定需要预测的文件夹或对应图片路径")

    def get_names(self):
        """
        获取预测文件名顺序 
        """
        return self.img_names

    def __getitem__(self, index):
        # 获取图像路径
        file_path = self.img_paths[index]
        # 使用Pillow来读取图像数据并转成Numpy格式
        img = Image.open(file_path)
        img = np.array(img, dtype="float32").reshape((IMAGE_SHAPE_C, IMAGE_SHAPE_H, IMAGE_SHAPE_W)) / 255
        return img

    def __len__(self):
        return len(self.img_paths)

6.2 参数设置

# 待预测目录 - 可在测试数据集中挑出\b3张图像放在该目录中进行推理
INFER_DATA_PATH = "./sample_img"
# 训练后存档点路径 - final 代表最终训练所得模型
CHECKPOINT_PATH = "./output/final.pdparams"
# 每批次处理数量
BATCH_SIZE = 32

6.3 展示待预测数据

import matplotlib.pyplot as plt
plt.figure(figsize=(10, 10))
sample_idxs = np.random.choice(50000, size=25, replace=False)

for img_id, img_name in enumerate(os.listdir(INFER_DATA_PATH)):
    plt.subplot(1, 3, img_id + 1)
    plt.xticks([])
    plt.yticks([])
    im = Image.open(os.path.join(INFER_DATA_PATH, img_name))
    plt.imshow(im, cmap=plt.cm.binary)
    plt.xlabel("Img name: " + img_name)
plt.show()

七、开始预测

飞桨2.3 CTC Decoder 相关API正在迁移中,本节暂时使用简易版解码器。

# 编写简易版解码器
def ctc_decode(text, blank=10):
    """
    简易CTC解码器
    :param text: 待解码数据
    :param blank: 分隔符索引值
    :return: 解码后数据
    """
    result = []
    cache_idx = -1
    for char in text:
        if char != blank and char != cache_idx:
            result.append(char)
        cache_idx = char
    return result


# 实例化推理模型
model = paddle.Model(Net(is_infer=True), inputs=input_define)
# 加载训练好的参数模型
model.load(CHECKPOINT_PATH)
# 设置运行环境
model.prepare()

# 加载预测Reader
infer_reader = InferReader(INFER_DATA_PATH)
img_names = infer_reader.get_names()
results = model.predict(infer_reader, batch_size=BATCH_SIZE)
index = 0
for text_batch in results[0]:
    for prob in text_batch:
        out = ctc_decode(prob, blank=10)
        print(f"文件名:{img_names[index]},推理结果为:{out}")
        index += 1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/420325.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SpringBoot ElasticSearch 【SpringBoot系列16】

SpringCloud 大型系列课程正在制作中,欢迎大家关注与提意见。 程序员每天的CV 与 板砖,也要知其所以然,本系列课程可以帮助初学者学习 SpringBooot 项目开发 与 SpringCloud 微服务系列项目开发 elasticsearch是一款非常强大的开源搜索引擎&a…

Logstash:部署和扩展 Logstash

Elastic Stack 用于大量用例,从操作日志和指标分析到企业和应用程序搜索。 确保你的数据可扩展、持久且安全地传输到 Elasticsearch 非常重要,尤其是对于任务关键型环境。 本文档的目的是强调 Logstash 最常见的架构模式以及如何随着部署的增长而有效扩…

c++学习之c++对c的扩展2

目录 1.c/c中的const 1 const概述 2 c/c中const的区别 c中的: c中的const: c/c中的const异同 c中const修饰的变量,分配内存情况 尽量以const替换define 2.引用 函数的引用: 引用的本质 指针的引用 5 常量引用 内联函数 内联函数…

(排序7)归并排序(递归)

归并排序 归并排序采用的是两个有序数组的归并。比如说现在想让一个数组有序。之前我们讲过,如果说你现在有两个有序数组的话,那么我们就可以把这两个有序数组给他合并成一个有序数组。两个有序区间归并的思路其实很简单(这个也是归并的单趟…

Android 自定义View 之 计时文字

计时文字前言正文一、XML样式二、构造方法三、API方法四、使用五、源码前言 在Android开发中,常常会有计时的一些操作,例如收验证码的时候倒计时,秒表的计时等等,于是我就有了一个写自定义View的想法,本文效果图。 正文…

Vue2-黑马(八)

目录: (1)router-动态路由 (2)router-重置路由 (3)router-页面刷新 (1)router-动态路由 我们有这样一个需求,不同的用户根据自己的身份不一样,…

Seaborn 数据可视化基础

目录 介绍 知识点 Seaborn 介绍 快速优化图形 Seaborn 绘图 API 一、散点图: 参数hue hue hue_order 参数style 二 、线形图 三、类别图 绘制箱线图 绘制小提琴图 绘制增强箱线图 绘制点线图 绘制条形图 绘制计数条形图 四、分布图 五、回归图 …

nginx配置

单线程应用 稳定性高 系统资源消耗低 线程切换消耗小 对HTTP并发连接处理能力高 单台服务器可支持2w个并发请求 nginx与apache区别 Nginx相对于Apache的优点: 轻量级,同样是 web 服务,比Apache 占用更少的内存及资源,高并发&#xff0…

攻防世界-file_include(convert.iconv的使用)

代码审计,存在文件包含,直接上伪协议 发现不行,应该是存在字符过滤 知识盲区: 1.file://协议,需要填写绝对路径,只能读取txt文件,后面直接跟绝对路径。 file:///etc/passwd 2.php://filter …

深入浅出 Golang 内存管理

了解内存管理~ 前言: 本节课主要介绍了内存管理知识与自动内存管理机制,并对目前 Go 内存管理过程中存在的问题提出了解决方案,同时结合了上次课程学习的《Go 语言性能优化》相关知识,提供可行性的优化建议 … 自动内存管理 Go…

spring-boot怎么扫描不在启动类所在包路径下的bean

前言: 项目中有多个模块,其中有些模块的包路径不在启动类的子路径下,此时我们怎么处理才能加载到这些类; 1 使用SpringBootApplication 中的scanBasePackages 属性; SpringBootApplication(scanBasePackages {"com.xxx.xx…

C++linux高并发服务器项目实践 day5

Clinux高并发服务器项目实践 day5程序和进程单道、多道程序设计时间片并行和并发进程控制块(PCB)进程状态转换进程的状态进程相关命令进程号和相关函数进程创建父子进程的关系GDB多进程调试程序和进程 程序是包含一系列信息的文件,这些信息描…

你知道怎么实现定时任务吗?

诸位读者都知道笔者写东西都是用到才写,笔者的学习足迹自从参加工作之后就是 非系统 学习了,公司里源代码只要有笔者不知道的技术细节,笔者就会仔细的研究清楚,笔者是不喜欢给自己留下问题的那种学习习惯。 为何要写 笔者最近负…

如何使用Thymeleaf给web项目中的网页渲染显示动态数据?

编译软件:IntelliJ IDEA 2019.2.4 x64 操作系统:win10 x64 位 家庭版 服务器软件:apache-tomcat-8.5.27 目录一. 什么是Thymeleaf?二. MVC2.1 为什么需要MVC?2.2 MVC是什么?2.3 MVC和三层架构之间的关系及工…

AI绘图体验:想象力无限,创作无穷!(文生图)

基础模型:3D二次元 PIXEL ART (1)16-bit pixel art, outside of caf on rainy day, light coming from windows, cinematic still(电影剧照), hdr (2) 16-bit pixel art, island in the clouds, by studio ghibli(吉卜力工作室…

配置基于WSL2的Docker环境并支持CUDA

导言 Content 正如前文windows 10 开启WSL2介绍的,我们可以在windows10中使用linux子系统。今天本文介绍如何在此基础上安装Docker并支持在wsl中使用GPU。 准备工作 加入windows insider preview。建议选Dev通道,不要选Beta。 安装Nvidia WSL2-compa…

【数据结构】-计数排序

🎇作者:小树苗渴望变成参天大树 🎉 作者宣言:认真写好每一篇博客 🎊作者gitee:link 如 果 你 喜 欢 作 者 的 文 章 ,就 给 作 者 点 点 关 注 吧! 文章目录前言一、计数排序二、排序算法复杂度…

Nginx网站服务配置

一、Nginx概述 1.1 Nginx概述 Nginx: Nginx 是开源、高性能、高可靠的 Web 和反向代理服务器,而且支持热部署,几乎可以做到 7 * 24 小时不间断运行,即使运行几个月也不需要重新启动,还能在不间断服务的情况下对软件…

分布式计算技术(上):经典计算框架MapReduce、Spark 解析

当一个计算任务过于复杂不能被一台服务器独立完成的时候,我们就需要分布式计算。分布式计算技术将一个大型任务切分为多个更小的任务,用多台计算机通过网络组装起来后,将每个小任务交给一些服务器来独立完成,最终完成这个复杂的计…

07 -全局状态管理

全局状态管理 7-1:开篇 在上一章中我们完成了 “一半” 的文章搜索功能,并且留下了一些问题。那么这些历史残留的问题,我们将会在本章节中通过 全局状态管理工具 进行处理。 那么究竟什么是 全局状态管理工具,如何在 uniapp 中…