基于YOLOv4的仪表检测
- 前言
- YOLOv4源码下载
- 数据集处理与模型训练
- 模型性能测试
前言
本系列是想记录一下自己实现的一种用于仪表检测与读数的方法,首先方法仅针对于单指针仪表和单行显示的数字仪表进行了检测与读数方法的设计。方法的整体思路是:第一步对拍摄图像进行仪表检测,然后对仪表位置进行裁剪,并根据仪表类别对应进行指针仪表读数识别和数字仪表读数识别,最后在将读数结果对应输出。整体流程如下:
考虑到读数识别均需依赖于仪表检测结果,因此首先实现仪表检测,因为现在基于深度学习的目标检测算法有非常好的效果,同时考虑到模型复杂性等问题,我选择了单阶段的目标检测网络YOLOv4来进行仪表检测(当然也可根据自身任务的难易程度选择更合适的网络),然后为了能够简化后续的流程,我考虑到检测任务同时还能够对目标进行分类,因此我选择将仪表按照不同的量程来作为分类标准,这样虽然加大了检测的难度(如果仅区分指针仪表或数字仪表,感觉对于YOLOv4又有些大材小用了),但是能够有效地简化后续的读数流程。(当然这种分类方式弊端也很多,如表明脏污、相似仪表、按量程区分仪表种类众多等问题,但这也算是提供了一种思路吧。)
YOLOv4源码下载
对于YOLOv4的网络结构及原理的介绍在网上有许多文章,或者可以阅读YOLOv4的论文。然后主要就是复现YOLOv4网络的源码,这里我选择的是Bubbliiiing博主实现的YOLOv4网络(包括后文还有另外一个网络的源码同样也是该博主的,非常感谢B导的无私奉献),对应博主的源码以及文章的链接见下面的两个链接(如果想了解更多,可以去B站看博主录制的视频):
文章: link.
源码: Github.
数据集处理与模型训练
在前言中也有介绍,我对仪表数据的标注方式是按照不同量程进行区分标注,如0-160℃单指针温度表和0-200℃单指针温度表是被划分为两类的。因此按照这种标注方式,需要同类型的仪表数量居多,我在网络上没有收集到足够的仪表图像作为数据集,绝大部分图像是自行拍摄的图像。图像拍摄不做过多赘述。
数据标注采用labelimg软件标注,下载地址:
Labelimg: 下载地址
软件的具体使用方法可自行搜索,标注的形式如下图,每一种仪表都按照仪表量程进行了细分类,如160℃和200℃的温度表被分别标注为160和200:
在标注完成后,会在指定目录生成对应的xml标注文件。这之后为了能够观察各类别仪表的样本数(理论上各样本数尽量均衡),单独写了一个脚本进行统计。
# coding:utf-8
"""
description: 统计各类别真实框数量,并显示一个统计图(主要是看看类别均衡)
author: kuang
time: 2020.6.16 by home
"""
import xml.etree.ElementTree as ET
import os
import matplotlib.pyplot as plt
import matplotlib
# 指定默认字体
matplotlib.rcParams['font.sans-serif'] = ['FangSong']
matplotlib.rcParams['font.family'] = 'sans-serif'
# xml所在路径
path = "./train/Annotations/"
def count():
# 获取所有xml文件
xml_list = os.listdir(path)
# 创建对应字典,且初始化
class_count = {}
# 对每个xml进行统计
for file in xml_list:
with open(os.path.join(path, file), "r", encoding="utf-8") as in_file:
# 打开xml文件
tree = ET.parse(in_file)
# 获取根节点
root = tree.getroot()
# 迭代每个object
for obj in root.iter('object'):
# 获取当前object的class name
cls = obj.find('name').text
if cls not in class_count.keys():
class_count[cls] = 1
# 计数
class_count[cls] += 1
return class_count
def draw(class_count):
x = class_count.keys()
y = class_count.values()
# 柱状图的宽度
width = 0.4
# 绘制条形统计图,横向
plt.bar(range(len(x)), y, align='center', color='steelblue', width=width, alpha=0.8)
# 添加轴标签
plt.ylabel('类别数量')
# 添加标题
plt.title('各类别数量统计图')
# 添加刻度标签
plt.xticks(range(len(x)), x)
# 为每个条形图添加数值标签,横向条形统计图
for i, j in enumerate(y):
plt.text(i - 0.1, j + 8, '%s' % str(list(y)[i]), va='center')
# 显示图形
plt.show()
if __name__ == "__main__":
class_count = count()
draw(class_count)
统计结果示例如下图:
最后,在按照博主在GitHub中给出的训练自己的数据集的训练步骤进行模型训练即可,注意根据自己硬件条件修改部分train.py文件中的参数(如batchsize、save_period等)。
模型性能测试
训练完成后,首先修改yolo.py文件中的model_path、classes_path的路径以及input_shape的大小,然后就可以运行项目中的predict.py文件调用训练后的模型进行仪表检测了,检测结果示意如下图所示。
在完成检测后,可得到仪表的类别以及对应仪表的位置,通过该位置坐标即可对仪表进行裁剪(将yolo.py文件中的detect_image函数的crop修改为true即可),得到仅包含单个仪表的图像,从而进行下一步的分类仪表读数识别。裁剪后图像示例如下。
最后,如果需要统计模型的mAP值,则可以运行get_map.py文件得到。