基于YOLO的车牌检测识别(YOLO+Transformer)

news2024/11/13 12:14:57

概述:
基于深度学习的车牌识别,其中,车辆检测网络直接使用YOLO侦测。而后,才是使用网络侦测车牌与识别车牌号。

车牌的侦测网络,采用的是resnet18,网络输出检测边框的仿射变换矩阵,可检测任意形状的四边形。

车牌号序列模型,采用Resnet18+transformer模型,直接输出车牌号序列。

数据集上,车牌检测使用CCPD 2019数据集,在训练检测模型的时候,会使用程序生成虚假的车牌,覆盖于数据集图片上,来加强检测的能力。

车牌号的序列识别,直接使用程序生成的车牌图片训练,并佐以适当的图像增强手段。模型的训练直接采用端到端的训练方式,输入图片,直接输出车牌号序列,损失采用CTCLoss。

一、网络模型
1、车牌的侦测网络模型:

网络代码定义如下:

class WpodNet(nn.Module):

    def __init__(self):
        """
        车牌侦测网络,直接使用Resnet18,仅改变输出层。
        """
        super(WpodNet, self).__init__()
        resnet = resnet18(True)
        backbone = list(resnet.children())
        self.backbone = nn.Sequential(
            nn.BatchNorm2d(3),
            *backbone[:3],
            *backbone[4:8],
        )
        self.detection = nn.Conv2d(512, 8, 3, 1, 1)

    def forward(self, x):
        features = self.backbone(x)
        out = self.detection(features)
        out = rearrange(out, 'n c h w -> n h w c') # 变换形状
        return out

该网络,相当于直接对图片划分cell,即在16X16的格子中,侦测车牌,输出的为该车牌边框的反射变换矩阵。

2、车牌号的序列识别网络:
车牌号序列识别的主干网络:采用的是ResNet18+transformer,其中有ResNet18完成对图片的编码工作,再由transformer解码为对应的字符。

网络代码定义如下:

from torch import nn
from torchvision.models import resnet18
import torch
from einops import rearrange


class OcrNet(nn.Module):

    def __init__(self,num_class):
        super(OcrNet, self).__init__()
        resnet = resnet18(True)
        backbone = list(resnet.children())
        self.backbone = nn.Sequential(
            nn.BatchNorm2d(3),
            *backbone[:3],
            *backbone[4:8],
        )  # 创建ResNet18
        self.decoder = nn.Sequential(
            Block(512, 8, False),
            Block(512, 8, False),
        )  # 由Transformer 构成的解码器
        self.out_layer = nn.Linear(512, num_class)  # 线性输出层
        self.abs_pos_emb = AbsPosEmb((3, 9), 512)  # 绝对位置编码

    def forward(self,x):
        x = self.backbone(x)
        x = rearrange(x,'n c h w -> n (w h) c')
        x = x + self.abs_pos_emb()
        x = self.decoder(x)
        x = rearrange(x, 'n s v -> s n v')
        return self.out_layer(x)

其中的Block类的代码如下:

class Block(nn.Module):
    r"""

    Args:
        embed_dim: 词向量的特征数。
        num_head: 多头注意力的头数。
        is_mask: 是否添加掩码。是,则网络只能看到每个词前的内容,而无法看到后面的内容。

    Shape:
        - Input: N,S,V (批次,序列数,词向量特征数)
        - Output:same shape as the input

    Examples::
        # >>> m = Block(720, 12)
        # >>> x = torch.randn(4, 13, 720)
        # >>> output = m(x)
        # >>> print(output.shape)
        # torch.Size([4, 13, 720])
    """

    def __init__(self, embed_dim, num_head, is_mask):
        super(Block, self).__init__()
        self.ln_1 = nn.LayerNorm(embed_dim)
        self.attention = SelfAttention(embed_dim, num_head, is_mask)
        self.ln_2 = nn.LayerNorm(embed_dim)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 6),
            nn.ReLU(),
            nn.Linear(embed_dim * 6, embed_dim)
        )

    def forward(self, x):
        '''计算多头自注意力'''
        attention = self.attention(self.ln_1(x))
        '''残差'''
        x = attention + x
        x = self.ln_2(x)
        '''计算feed forward部分'''
        h = self.feed_forward(x)
        x = h + x  # 增加残差
        return x

位置编码的代码如下:

class AbsPosEmb(nn.Module):
    def __init__(
        self,
        fmap_size,
        dim_head
    ):
        super().__init__()
        height, width = fmap_size
        scale = dim_head ** -0.5
        self.height = nn.Parameter(torch.randn(height, dim_head) * scale)
        self.width = nn.Parameter(torch.randn(width, dim_head) * scale)

    def forward(self):
        emb = rearrange(self.height, 'h d -> h () d') + rearrange(self.width, 'w d -> () w d')
        emb = rearrange(emb, ' h w d -> (w h) d')
        return emb

Block类使用的自注意力代码如下:

class SelfAttention(nn.Module):
    r"""多头自注意力

    Args:
        embed_dim: 词向量的特征数。
        num_head: 多头注意力的头数。
        is_mask: 是否添加掩码。是,则网络只能看到每个词前的内容,而无法看到后面的内容。

    Shape:
        - Input: N,S,V (批次,序列数,词向量特征数)
        - Output:same shape as the input

    Examples::
        # >>> m = SelfAttention(720, 12)
        # >>> x = torch.randn(4, 13, 720)
        # >>> output = m(x)
        # >>> print(output.shape)
        # torch.Size([4, 13, 720])
    """

    def __init__(self, embed_dim, num_head, is_mask=True):
        super(SelfAttention, self).__init__()
        assert embed_dim % num_head == 0
        self.num_head = num_head
        self.is_mask = is_mask
        self.linear1 = nn.Linear(embed_dim, 3 * embed_dim)
        self.linear2 = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        '''x 形状 N,S,V'''
        x = self.linear1(x)  # 形状变换为N,S,3V
        n, s, v = x.shape
        """分出头来,形状变换为 N,S,H,V"""
        x = x.reshape(n, s, self.num_head, -1)
        """换轴,形状变换至 N,H,S,V"""
        x = torch.transpose(x, 1, 2)
        '''分出Q,K,V'''
        query, key, value = torch.chunk(x, 3, -1)
        dk = value.shape[-1] ** 0.5
        '''计算自注意力'''
        w = torch.matmul(query, key.transpose(-1, -2)) / dk  # w 形状 N,H,S,S
        if self.is_mask:
            """生成掩码"""
            mask = torch.tril(torch.ones(w.shape[-1], w.shape[-1])).to(w.device)
            w = w * mask - 1e10 * (1 - mask)
        w = torch.softmax(w, dim=-1)  # softmax归一化
        attention = torch.matmul(w, value)  # 各个向量根据得分合并合并, 形状 N,H,S,V
        '''换轴至 N,S,H,V'''
        attention = attention.permute(0, 2, 1, 3)
        n, s, h, v = attention.shape
        '''合并H,V,相当于吧每个头的结果cat在一起。形状至N,S,V'''
        attention = attention.reshape(n, s, h * v)
        return self.linear2(attention)  # 经过线性层后输出

二、数据加载
1、车牌号的数据加载
通过程序生成一组车牌号:
在这里插入图片描述

再通过数据增强,主要包括:
随机污损:
在这里插入图片描述
高斯模糊:
在这里插入图片描述
仿射变换,粘贴于一张大图中:
在这里插入图片描述
四边形的四个角的位置随机偏移些许后扣出:
在这里插入图片描述

然后直接训练车牌号的序列识别网络,

loss_func = nn.CTCLoss(blank=0, zero_infinity=True)
optimizer = torch.optim.Adam(self.net.parameters(), lr=0.00001)

优化器直接使用Adam,损失函数为CTCLoss。

2、车牌检测的数据加载
数据使用的是CCPD数据集,在这过程中,会随机的使用生成车牌,覆盖原始图片的车牌位置,来训练网络对车牌的检测能力。

if random.random() < 0.5:
    plate, _ = self.draw()
    plate = cv2.cvtColor(plate, cv2.COLOR_RGB2BGR)
    plate = self.smudge(plate)  # 随机污损
    image = enhance.apply_plate(image, points, plate)  # 粘贴车牌图片于数据图中
[x1, y1, x2, y2, x4, y4, x3, y3] = points
points = [x1, x2, x3, x4, y1, y2, y3, y4]
image, pts = enhance.augment_detect(image, points, 208)

三、训练
分别训练即可
其中,侦测网络的损失计算,如下:

def count_loss(self, predict, target):
    condition_positive = target[:, :, :, 0] == 1  # 筛选标签
    condition_negative = target[:, :, :, 0] == 0

    predict_positive = predict[condition_positive]
    predict_negative = predict[condition_negative]

    target_positive = target[condition_positive]
    target_negative = target[condition_negative]
    n, v = predict_positive.shape
    if n > 0:
        loss_c_positive = self.c_loss(predict_positive[:, 0:2], target_positive[:, 0].long())
    else:
        loss_c_positive = 0
    loss_c_nagative = self.c_loss(predict_negative[:, 0:2], target_negative[:, 0].long())
    loss_c = loss_c_nagative + loss_c_positive

    if n > 0:
        affine = torch.cat(
            (
                predict_positive[:, 2:3],
                predict_positive[:,3:4],
                predict_positive[:,4:5],
                predict_positive[:,5:6],
                predict_positive[:,6:7],
                predict_positive[:,7:8]
            ),
            dim=1
        )
        # print(affine.shape)
        # exit()
        trans_m = affine.reshape(-1, 2, 3)
        unit = torch.tensor([[-0.5, -0.5, 1], [0.5, -0.5, 1], [0.5, 0.5, 1], [-0.5, 0.5, 1]]).transpose(0, 1).to(
            trans_m.device).float()
        # print(unit)
        point_pred = torch.einsum('n j k, k d -> n j d', trans_m, unit)
        point_pred = rearrange(point_pred, 'n j k -> n (j k)')
        loss_p = self.l1_loss(point_pred, target_positive[:, 1:])
    else:
        loss_p = 0
    # exit()
    return loss_c, loss_p

侦测网络输出的反射变换矩阵,但对车牌位置的标签给的是四个角点的位置,所以需要响应转换后,做损失。其中,该cell是否有目标,使用CrossEntropyLoss,而对车牌位置损失,采用的则是L1Loss。

四、推理

根目录下运行,

python kenshutsu.py

记得修改py文件中的模型权重路径位置。

在这里插入图片描述

推理解析:
1、侦测网络的推理
按照一般侦测网络,推理即可。只是,多了一步将反射变换矩阵转换为边框位置的计算。
另外,在YOLO侦测到得测量图片传入该级进行车牌检测的时候,会做一步操作。代码见下,将车辆检测框的图片扣出,然后resize到长宽均为16的整数倍。

h, w, c = image.shape
f = min(288 * max(h, w) / min(h, w), 608) / min(h, w)
_w = int(w * f) + (0 if w % 16 == 0 else 16 - w % 16)
_h = int(h * f) + (0 if h % 16 == 0 else 16 - h % 16)
image = cv2.resize(image, (_w, _h), interpolation=cv2.INTER_AREA)

在这里插入图片描述

2、序列检测网络的推理
对网络输出的序列,进行去重操作即可,如间隔标识符为“*”时:

def deduplication(self, c):
    '''符号去重'''
    temp = ''
    new = ''
    for i in c:
        if i == temp:
            continue
        else:
            if i == '*':
                temp = i
                continue
            new += i
            temp = i
    return new

五、完整代码

https://github.com/HibikiJie/LicensePlate

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2091924.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

同城小程序怎么做 同城小程序系统开发制作方案

很多同城创业的老板们想要做一个同城小程序但是不知道怎么做&#xff0c;本次瀚林就为大家详细介绍一下做同城小程序系统开发制作方法&#xff0c;给大家做个参考。 目前同城类型的小程序系统市面上比较常见的有&#xff1a;同城配送、鲜花订花、同城上门服务、同城跑腿、同城便…

中仕公考怎么样?事业编考试怎么备考?

事业编考试备考可以大致分为三个阶段&#xff0c;按照不同阶段根据自身的学习情况制定不同的学习计划即可。 ①基础阶段 有备考经验的考生可以忽略这一步&#xff0c;刚开始先打好基础很重要&#xff0c;根据课程和教材理解知识点&#xff0c;按照模块学习&#xff0c;对考试…

cnocr 安装

打开终端 如果不会打开终端 -> 终端打开输入 pip install cnocr 执行中途可能报错 去这里下载工具&#xff1a;c构建工具下载完打开&#xff0c;勾选这个 然后点安装安装完回到第2步重新执行

docker镜像所使用到的COW写时复制技术是什么

copy on write 简单来说&#xff0c;所有的读操作都是指向一份内存地址&#xff0c;共享这些数据&#xff0c;节省内存空间。 如果有进程要对数据进行写操作&#xff0c;系统会检测到这个行为&#xff0c;将数据复制一份出来&#xff0c;给这个进程进行写操作。其他进程继续…

5.3二叉树——二叉树链式结构实现

本篇博客梳理二叉树链式结构 明确&#xff1a;二叉树是递归定义的 递归的本质&#xff1a;当前问题子问题&#xff0c;返回条件是最小规模的子问题 一、二叉树的遍历 1&#xff0e;前序、中序与后序遍历 &#xff08;1&#xff09;前序&#xff1a;根->左子树->右子树…

全球知名度最高的华人颜廷利:世界公认十大思想家哲学家

全球知名度最高的华人颜廷利&#xff1a;世界公认十大思想家哲学家 在汉语这一中国优秀传统文化的瑰宝中&#xff0c;“色”与“舍”这两个字的发音分别被解读为“思恶”和“识恶”&#xff0c;揭示了一种深奥的文化现象。这种现象的根源&#xff0c;实则来自于我们的感官——眼…

linux上查找某应用所在的绝对路径

linux上查找某应用所在的绝对路径 1、已知应用名称 找到应用的进程号 例&#xff1a;查找nginx的进程号 ps -ef | grep nginx 或者 ps -aux | grep nginx 2、通过端口号找进程号 lsof -i:80 3、通过进程号找到所在目录&#xff0c;Linux在启动一个进程时,系统会在/proc目…

力扣刷题(3)

整数反转 整数反转-力扣 思路&#xff1a; 利用%和/不断循环取待反转整数的最后一位&#xff0c;注意判断是否超出范围。 int reverse(int x){int y0;while(x){if(y > INT_MAX/10 || y < INT_MIN/10)return 0;int tmpx%10;yy*10tmp;x/10;}return y; }字符串转换整数 …

多线程篇(基本认识 - 锁优化)(持续更新迭代)

目录 一、前言 二、阿里开发手册 三、synchronized 锁优化的背景 四、Synchronized的性能变化 1. Java5之前&#xff1a;用户态和内核态之间的切换 2. java6开始&#xff1a;优化Synchronized 五、锁升级 1. 无锁 2. 偏向锁 2.1. 前言 2.2. 什么是偏向锁 2.3. 偏向…

知识产权案件中的消费者问卷调查证据

在知识产权案件中&#xff0c;消费者问卷调查可以作为一种重要的证据形式。通过调查消费者的认知、态度、行为和观点&#xff0c;消费者问卷调查可以提供以下方面的证据支持&#xff1a; 1、商标或产品混淆&#xff1a;消费者问卷调查可以确定消费者对于涉及知识产权的商标或产…

《python语言程序设计》第8章第9题将二进制数作为字符串转换十六进制print和return的区别

在这里我发现了return和print的区别 def binary_to_hex(binary_value):len_text len(binary_value)for i in range(0, len_text, 4):#能把二进制分成四组进行打印print(binary_value[0 i:4 i])#只能运行将前4个数分成一组return binary_value[0 i:4 i]a binary_to_hex(&q…

HarmonyOS--AGC(认证服务/云函数/云存储/云数据库)

HarmonyOS–AGC(认证服务/云函数/云存储/云数据库) 文章目录 一、注册华为账号开通认证服务二、添加项目&#xff1a;*包名要与项目的包名保持一致三、获取需要的文件四、创建项目&#xff1a;*包名要与项目的包名保持一致五、添加json文件六、加入请求权限七、加入依赖八、修改…

Openai api via azure error: NotFoundError: 404 Resource not found

题意&#xff1a;"OpenAI API通过Azure出错&#xff1a;NotFoundError: 404 找不到资源" 问题背景&#xff1a; thanks to the university account my team and I were able to get openai credits through microsoft azure. The problem is that now, trying to us…

VS2022搭建QT及OpenCV环境

1.背景 由于之前VS2022和QT已经安装好了&#xff0c;所以本次的任务主要是下载OpenCV以及在VS2022上集成QT和OpenCV。关于VS2022和QT的安装大家可以参考别的博客。QT选择的版本是6.2.4&#xff0c;OpenCV版本为3.4.5&#xff0c;Windows版本为Win11。 2.OpenCV下载 OpenCV官…

《黑神话:深度探索与攻略指南》——虎先锋隐藏门在哪里

在《黑神话悟空》这款扣人心弦的游戏中&#xff0c;探索隐藏区域和发现秘密宝箱是许多玩家的乐趣所在。特别是戌狗地窖中那个神秘的宝箱&#xff0c;它不仅藏有泡酒物虎舍利等珍贵道具&#xff0c;更是对玩家探索能力的一次考验。然而&#xff0c;不少玩家在寻找虎先锋隐藏门时…

raksmart机云大宽带服务器托管服务内容

RakSmart是一家提供全球数据中心服务的公司&#xff0c;其业务范围涵盖了服务器托管、专用服务器租赁、云服务等多个领域。其中&#xff0c;机柜大带宽服务器托管服务是其特色之一&#xff0c;特别适合那些需要大量带宽资源的企业级客户。下面我们将详细介绍RakSmart的机柜大带…

windows系统安装配置Apache Maven

Date: 2024.07.17 09:45:10 author: lijianzhan 电脑环境: win10系统 Java开发环境: JDK21 Mvn : apache-maven-3.9.9 maven下载地址: https://maven.apache.org/download.cgi 点击链接进入Apache Maven官网&#xff0c;选择apache-maven-3.9.9-bin.zip进行下载。 下载maven…

ecmascript和javascript的区别详细讲解

​ 大家好&#xff0c;我是程序员小羊&#xff01; 前言&#xff1a; ECMAScript 和 JavaScript 是密切相关的两个概念&#xff0c;但它们在本质上有所区别。以下是对它们的详细介绍和区别分析。 一、概念定义 1. JavaScript 的定义 JavaScript 是一种基于原型的动态脚本语…

【hot100篇-python刷题记录】【最长连续序列】

R7-哈希篇 思路&#xff1a;sort一下先 然后使用双指针遍历计数&#xff0c;同时从数组头开始。 按照题目的意思&#xff0c;相同元素只能取一个&#xff0c;所以可以用set class Solution:def longestConsecutive(self, nums: List[int]) -> int:if len(nums)<1:retu…

开放式耳机的优缺点?五款绝佳王炸顶级力荐!

在运动过程中&#xff0c;我们往往会选择听音乐来提升运动表现&#xff0c;但是在专注音乐和运动时&#xff0c;又很容易忽略了周围的环境&#xff0c;导致运动意外的发生。所以开放式蓝牙耳机的兴起&#xff0c;深受大家的喜爱&#xff0c;尤其是运动爱好者和音乐爱好者&#…