视觉 → 检测提取

news2025/1/12 23:35:11

目标检测任务非常有趣且具有挑战性。有些任务非常复杂,需要更多数据才能有所产出。但在这篇文章中,我将展示一个符号检测的小任务,它可以用更少的数据完成。该项目的目的是使用计算机视觉技术从一组给定的图像中提取文本并检测各种符号。

在这个任务中我们需要解决:

  1. 用于训练和推理的 torch (SSD MobileNet) 模型。

  2. 用于可视化的 cv2 和 matplotlib。

  3. 用于数据处理和存储的 NumPy 和 pandas。

  4. NVIDIA Tesla T4 或 NVIDIA Tesla P100 GPU 配置。

使用 SSD MobileNet 进行符号检测

数据

数据非常少,我们总共有 7 个图像和 9 个符号(类),即我们的任务是在符号检测的基础上进行多分类。这将是一项具有挑战性的任务,因为与处理大数据集不同,我们正在用更少的数据解决问题,其中很少有如下:

a615ac9881f36d81768eb5386bf96461.png

我们从中提取文本和所有符号的示例图像

797d650b29ae6fc5a6b356ff72f6f421.png

示例符号

首先为我们的项目制定一个计划:

  1. 由于我们的数据较少,我们将使用一些图像增强技术并尝试生成数据。

  2. 我们将使用 SSD MobileNet 模型来训练我们的数据集。

  3. 我们将在训练中使用所有这些合成数据,并将使用其余六张原始图像作为验证集进行验证。

  4. 在我们的最后一步中,我们将使用 Paddle-OCR 从图像中提取文本。

  5. 文本详细信息包括设备名称、参考号、批号、数量等,这些详细信息将通过 excel 文件生成。

什么是 SSD MobileNet

在深入研究 SSD MobileNet 之前,让我们先了解什么是 SSD。

  • SSD 全称 Single Shot Detection,用于实时检测物体。

  • SSD 的两个主要组件是 Backbone model 和 SSD Head。

  • 其中 Backbone model 通常是一个预训练的图像分类网络,充当整个模型的特征提取器。

  • SSD Head 只是添加到该 Backbone 的一个或多个卷积层,它将通过边界框坐标为我们提供所需对象的位置和类别。

  • 对于我们的模型,我们将使用 MobileNet 作为我们的特征提取器。这就是它被称为 SSD MobileNet 模型的原因。

43bdbb7d592ce72be9a500f9a798fb94.png

这张图有两部分,第一部分是白色方框,代表Mobile Net架构的网络,第二部分是蓝色方框,代表SSD head

生成合成数据:

我们将使用简单的数学方法进行数据合成。

  • 首先,我们将从七张原始图像中获取一张图像,然后仅在这幅选定的图像上尝试我们的增强。

  • 然后,我们将通过为特定符号位置提供随机符号来创建 7k 合成图像,从而使模型能够学习如何识别比正常情况更小的物体。

  • 下面的代码将帮助我们生成 7k 图像并将它们保存在数据框中,其中包含符号详细信息(类别编号)及其相应的边界框。

import pandas as pd
import random


columns = ["filename","xmin","ymin","xmax","ymax","class"]
df = pd.DataFrame(columns=['filename', 'xmin', 'ymin', 'xmax', 'ymax', 'class'])
symbol_height = 70
symbol_width = 160


for count in range(7000):
    synth_img = cv2.imread('/content/result_Page_7.jpg')
    top, left  = 435 , random.randint(-17,-15)+150
    bottom , right = top + 70 , left +160
    synth_img[(425):600, 100:750] = (255,255,255)


    for i in range(5):
        image = random.choice(glob('/content/symbols/*'))
        img = cv2.imread(image)
        img = cv2.cvtColor(img,cv2.COLOR_RGBA2BGR).astype(np.float32)
        img = cv2.resize(img,(symbol_width,symbol_height))
        
        synth_img[int(top):int(bottom), int(left):int(right)] = img
        classes = (image.split(".")[-2].split('/')[-1])
        row = ['augmented_98/'+ str(count)+'.jpg', left,top,right,bottom,classes]
        df.loc[len(df.index)] = row
        row = ['augmented_98/'+ str(count)+'.jpg', 1049,126,1244,227,4]
        df.loc[len(df.index)] = row


        left = right + random.randint(30,40)
        right = left + symbol_width


        if i == 3 :
          top, left  = 435 + 72 , symbol_width-10
          bottom , right = top + symbol_height , (2*symbol_width)-10
      
    cv2.imwrite('augmented_98/'+ str(count) +'.jpg',synth_img)
  • 具有六列的数据框,其中 xmin、ymin、xmax 和 ymax 是边界框坐标,而类(0 到 9)表示符号的编号。

3b3a2f05ed877d5a1f1ce36a37f43668.png

  • 我们可以通过使用 matplotlib 库可视化图像来验证我们的数据,以在以下代码片段的帮助下检查我们的合成数据边界框是否正确。

def show_output_with_bbox(filename, bboxes, labels, transform):


    image = cv2.imread(filename) # cv2.IMREAD_COLOR
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
    image = image/255.0
    aug_pipeline = get_aug_pipeline()
    transformed = aug_pipeline(image=image, bboxes=bboxes, labels=labels)


    img = transformed['image']
    bboxes = torch.as_tensor(transformed['bboxes'])
    bboxes = bboxes.detach().numpy()
    labels = transformed['labels']
    
    img_height = img.shape[1]
    img_width = img.shape[2]  
    fig, ax = plt.subplots(figsize=(5,5))
    ax.imshow(img.permute(1,2,0).numpy())  
    for bbox, class_name in zip(bboxes, labels):
        xmin = bbox[0]
        ymin = bbox [1]
        width = bbox[2] - xmin
        height = bbox[3] - ymin
        rect = patches.Rectangle((xmin, ymin), width, height, linewidth=2, edgecolor='r', facecolor='none')
        ax.add_patch(rect)
        ax.text(xmin, ymin, class_name, color='w')  
    
    plt.show()
    
################################################################################


trainFiles = df['filename'].unique().tolist()


for filename in trainFiles[:3]:
    records = df[df['filename']==filename]
    bboxes = records[['xmin', 'ymin', 'xmax', 'ymax']].values
    labels = records[['class']].values
    labels = [x[0] for x in labels]
    area = (bboxes[:,2]-bboxes[:,0]) * (bboxes[:,3]-bboxes[:,1])
    show_output_with_bbox(filename, bboxes, labels, get_aug_pipeline())

f13227625f7be0684b93c36cb2e4567d.png

验证边界框是否正确

训练模型:

  • 使用下面的代码片段,我们将能够从 torchvision 下载 ssdlite320_mobilenet_v3_large 模型,我们可以自定义模型以根据我们的要求进行分类,即总共 10 个类,包括一个背景类。

  • 正如我们之前计划的那样,我们不会在训练时进行验证,我们将利用训练本身的所有数据。我们将 7k 图像设置 batch size 为 5 进行 25 个 epoch 的训练。

  • 我们将学习率设置为 0.01,momentum 和 weight decay 分别设置为 0.9 和 0.0005。

SSD_MODEL = torchvision.models.detection.ssdlite320_mobilenet_v3_large(pretrained=False, num_classes=10)


torch.cuda.empty_cache()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


if torch.cuda.is_available():
    SSD_MODEL.cuda()


params = [p for p in SSD_MODEL.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.01, momentum=0.9, weight_decay=0.0005)
EPOCHS = 25


loss_stats = {
    'train': [],
    "val": []
}
# TRAINING
print("Begin training.")
for e in tqdm(range(EPOCHS)):
    epoch_loss = []
    SSD_MODEL.train()
    for i, data in enumerate(train_data_loader):
                                 
        images, targets = data
        optimizer.zero_grad()
  
        inputs = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]


        # pass input to model
        loss_dict_train = SSD_MODEL(inputs, targets)


        # loss
        losses_train = sum(loss for loss in loss_dict_train.values())
        epoch_loss.append(losses_train.item())


        # backprop
        losses_train.backward()


        # update weights
        optimizer.step()
        #------====------#
    
    # Epoch end - Training loss
    train_loss_epoch = np.mean(epoch_loss)
    loss_stats['train'].append(train_loss_epoch)
   
    print(f'Epoch {e+0:03}: | Train Loss: {train_loss_epoch :.5f}')
    torch.save({
    'epoch': e,
    'model_state_dict': SSD_MODEL.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    }, f"/content/drive/MyDrive/Glove_6B/augmented_98.pth")

857abfda91882f21cae04db13d97020f.png

最后五个 epoch 的训练损失

让我们验证模型输出:

  • 下面对原始图像的预测显示,模型在 20 个 epoch 后表现良好。

  • 边界框与类名是准确的。

  • 已捕获置信度得分(每个符号在 0 到 1 的范围内)。

3b695df8644cff870ea97a63345ab42e.png

元组的第一个元素是类别编号,第二个元素是置信度得分

使用 Paddle -OCR 进行文本提取:

什么是 Paddle-OCR:

  • 在深入了解 Paddle-OCR 之前,让我们先了解什么是 OCR。

  • OCR(光学字符识别)用于将基于文本的文档检测为数字文档。它主要分为通用 OCR 和特定领域 OCR。

  • Paddle -OCR 是一款开源的通用 OCR 工具,使用“Paddle”算法(水平识别)检测文本,占用内存极少,支持多语言。

  • 该模型可以识别静态和移动图片中的文本,而不管它们的方向和语言如何。识别的响应时间要少得多(ms)。

paddle -OCR 的用例:

  • 在金融业务中,可以用于支票簿、发票、个人报表、收据等业务单据的信息提取。

  • 工厂自动化用于冲压和读取带有序列号的雕刻零件,以避免生产线出错。

  • 在政府业务中,在机场,它用于护照识别和信息提取。它也可以用于交通标志识别。

让我们进行编程并提取一些细节:

Python 通过在一行代码中实现模型使我们的生活变得轻松。下面的代码片段展示了我们如何从 Paddle -OCR 模型创建一个对象并提取所需的信息,然后我们可以将其保存为我们想要的格式。

# READING TEXT FROM IMAGES USING PADDLE-OCR


    result = ocr.ocr(image, cls=True)
    for i in (result[0]):
        if i[1][0].startswith("Device Name:"):
            before_keyword1, keyword1, Device_name = i[1][0].partition('Device Name:')
            device_Name_list.append(Device_name)
        elif i[1][0].startswith("REF"):
            before_keyword2, keyword2, REF = i[1][0].partition('REF')
            REF_list.append(REF)
        elif i[1][0].startswith("LOT:"):
            before_keyword3, keyword3, LOT = i[1][0].partition('LOT:') 
            LOT_list.append(LOT)
        elif i[1][0].startswith("Qty:"):
            before_keyword4, keyword4, Qty = i[1][0].partition('Qty:')
            Qty_list.append(Qty)


#CREATING DATAFRAME TO SAVE IN EXCEL FILE


data = pd.DataFrame(columns = ['Device Name','REF','LOT','Qty','Symbols'])
data['Device Name'] = device_Name_list
data['REF'] = REF_list
data['LOT'] = LOT_list
data['Qty'] = Qty_list
data['Symbols'] = symbol_list
data.to_excel("output.xlsx",index=False)

431081a696998639ed9575b656fd4aeb.png

原始图像所需的详细信息

结论 

  • 在这个数字信息时代,如果需要,我们可以尝试使用人工智能解决每一个业务问题。

  • 即使可用数据较少,我们也可以使用深度学习/机器学习技术来解决我们的问题。

·  END  ·

HAPPY LIFE

a31b5c7fe2d40e5d129536ee84953210.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/188078.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

UniApp已经接了手机数据线,但运行工具警告 “没有检查到设备“ (华为手机为例 进行解决)

大部分第一次使用uni进行手机调试都会遇到这个问题 首先 将手机的数据线插入电脑的usb接口是必备前提 然后 就是手机的权限拦截了设备扫描 这就是uni工具找不到设备的原因 接入手机线后 数据会弹出一个USB的提示 点进去之后 我们要设置 允许传输文件 千万别仅充电 接下来的…

Java 以数据流的形式发送数据request Java 数据封装到request中

Java 以数据流的形式发送数据request Java 数据封装到request中 一、描述 1、在做微信支付结果通知的时候,看到一个描述:微信会把相关支付结果及用户信息通过数据流的形式发送给商户 ,那么java如何通过数据流的形式发送数据呢? 二…

idea中的Debug工具的使用介绍

文章目录1、设置断点给断点添加条件2、打开DebugDebu启动方式3、Debug功能介绍左侧功能区顶部功能区使用Debug工具时要先进行打断点的操作1、设置断点 断点就是程序运行暂停的位置,在这个位置以后可以根据自己的操作一步一步的执行程序。 idea中设置断点&#xff1…

FreeMarker基础知识

1、总览 官网:http://freemarker.foofun.cn/ 视频地址:https://www.bilibili.com/video/BV1zZ4y1u7iA 2、FreeMarker概述 2.1 FreeMarker概念 FreeMarker 是⼀款 模板引擎: 即⼀种基于模板和要改变的数据, 并⽤来⽣成输出⽂本(…

动态化护眼全新体验,被誉为“护眼神器”的南卡护眼台灯Pro评测出炉

自从家中的孩子上小学后,随着课后作业的逐渐增加,在书房学习时间更长了,由于平时关注到孩子用眼习惯,眼睛有些轻度近视。作为年轻一代的家长,对孩子的用眼健康方面一定要重视,在照明方面,护眼台…

Redis基础篇:Redis简介和安装

第一章:Redis简介 一:简介 Redis诞生于2009年,基于内存的键值型NoSQL数据库。 二:特征 1:键值型:value支持多种不同的数据结构,功能丰富。 2:单线程:单线程执行命令&…

Kubernetes介绍

1 什么是Kubernetes? Kubernetes是容器集群管理系统,是一个开源的平台,可以实现容器集群的自动化部署、自动扩缩容、维护等功能。 使用Kubernetes可以: ● 自动化容器的部署和复制 ● 随时扩展或收缩容器规模 ● 将容器组织成组&…

第四章.神经网络—单层感知器

第四章.神经网络 4.1 单层感知器 1.单层感知器示意图 1).第一种表示方法: 举例说明: 2).第二种表示方法: 公式推导: 举例说明: 预测值(y)和标签值(t)相同,停止迭代循环. 2.学习率η 1).η取值说明&…

Python流程控制语句之跳转语句

上一篇:Python流程控制语句之循环语句 文章目录前言一、break 语句二、continue 语句三、pass 空语句总结前言 上一篇博客我们讲解了Python中的循环语句,知道循环条件一直满足时,代码将会一直执行下去,就像一辆迷路的车&#xff…

《满江红》《流浪地球2》孰能胜出,元宇宙电影能否成为票房黑马?

截止1月28日12时,2023年春节档期总票房达67.57亿元。其中,《满江红》以26.05亿元票房居2023年春节档票房榜榜首;《流浪地球2》位居第二,票房成绩为21.63亿元。摆在未来人类面前就两条路,一条向外星辰大海,一条向内元宇宙。《流浪地…

微信小程序017音乐播放器系统 php java

小程序前端框架:uniapp 小程序运行软件:微信开发者 后端技术:javaSsm(SpringSpringMVCMyBatis)vue.js 后端开发环境:idea/eclipse 数据库:mysql 基于音乐播放器小程序的设计基于现有的手机,可以实现首页、个人中心、用户管理,音乐…

拉伯证券|开盘暴跌20%,三文鱼第一股业绩变脸!

超900家公司成绩预亏,多家公司发布成绩预告后大跌。 佳沃食品今天开盘20%跌停,这是该股史上开盘最大跌幅。早盘该股成交额显着扩展,半日成交额超越3.5亿元,收盘跌18.04%。 资料显现,佳沃食品是优质蛋白食品领域的大消…

python入门教程(非常详细),python贪吃蛇最简单代码

大家好,小编来为大家解答以下问题,python编程代码大全设计入门,python入门教程(非常详细),现在让我们一起来看看吧! 1、python编程例子有哪些? python编程经典例子: 1、画爱心表白、图形都是由…

除了Navicat破解版、DBeaver,免费还好用的数据库管理工具/SQL工具还有推荐吗?

很多国内SQL学习者和开发者对Navicat、DBeaver等国外数据库管理工具已经很熟悉了。但是,有没有比他们更适合SQL开发者的数据库管理/SQL工具呢?这里,笔者结合自己的调研来聊一下。 笔者做过一些用户调研。 Navicat虽然功能强大,但…

win10安装opencv

第一步:会有skbuild,cmake等依赖库报错,先安装依赖pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple scikit-buildpip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple cmake第二步:pip3 install opencv-python若…

Python数据可视化之折线图

Python数据可视化之折线图 提示:前言 Python数据可视化之折线图 提示:写完文章后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录Python数据可视化之折线图前言一、导入包二、选择数据集三、折线图四、图形的大小和图表…

什么游戏视频录制软件比较好?10 款的游戏录屏软件你值得收藏

市面上有各种各样的游戏捕捉软件,当然,它们都声称是有史以来最好的游戏软件。但有些比其他的更好,最适合您的游戏记录器在很大程度上取决于您要玩的游戏以及您运行的 PC 类型。 目前最好的游戏屏幕录像机 让我们来探索自称是最佳游戏屏幕录…

NetLogo 语法总结

NetLogo 语法总结NetLogo语法的怪异。。。。。。NetLogo语法关键在于你要把它当成一个软件使用,而不是一个通用的编程语言。首先,上网搜搜setup go是怎么用的,或者买本书,本文不再赘述NetLogo世界turtlespatcheslinksobserver(上帝…

np.savetxt()存储数据

前言 使用np.savetxt()方法可以将数据保存为txt文件或者是csv文件。 1 np.savetxt()存储txt文件 1-1 基础参数 numpy.savetxt(fname,arrry,fmt%.18e,delimiter ,newline\n,header,footer,comments# ,encodingNone,) 1-2 参数详解 fname:要存入的文件、文件名、或生成器。 ar…

令人窒息的百度面试题(正值换工作季,还不收藏???)

最近去网上找了一些百度的面经,冥冥之中在众多的面试题中打开了下边两个面试题: 2021百度前端社招面经 百度前端面试题分享,带答案 看完之后我直呼“哇哦~”,全部在我的射程范围之内。我该不会如此幸运到问的全会吧。 是的&am…