使用场景
- 检测数据集标注是否有误：训练目标检测算法时需要标注自己的数据集，为了更加方便地检查数据集标注是否有误，可以使用该工具将标注结果绘制在图像中并查看。
- 美化识别结果中的检测框：在一些目标检测场景中，YOLO 检测算法原始的检测框绘制方式会导致框重叠、颜色冲突、字体过大等问题，可以使用该工具进行修改。
代码
import os
import cv2
class check_label:
    """Draw YOLO-style boxes from label files onto their images.

    Each non-blank line of a label file describes one object::

        class_id x_center y_center width height [confidence]

    The four box values are fractions of the image width/height (they are
    multiplied by the image size when drawn). ``process()`` pairs each image
    in ``img_path`` with the label file in ``label_path`` that has the same
    base name, draws the boxes, and writes the annotated image under the same
    file name into ``result_path``.

    The public attributes set in ``__init__`` are drawing options; adjust them
    after construction and before calling ``process()``.
    """

    def __init__(self, classes: list, label_path: str, img_path: str, result_path: str):
        """
        Args:
            classes: class names, indexed by the integer class id in the first
                column of a label line.
            label_path: directory containing the label files.
            img_path: directory containing the images.
            result_path: directory the annotated images are written to
                (``process()`` creates it if it does not exist).
        """
        self.classes = classes
        self.line_width = 5                 # box outline thickness in pixels
        self.rec_color = (0, 0, 255)        # box / text-background colour (OpenCV BGR: red)
        self.font_color = (255, 255, 255)   # text colour (OpenCV BGR: white)
        self.font = cv2.FONT_HERSHEY_SIMPLEX
        self.font_size = 5                  # passed to OpenCV as fontScale
        self.font_thickness = 4
        # Text anchor (bottom-left corner of the text) relative to the box's top-left corner.
        self.font_x_offset = 0
        self.font_y_offset = -15
        self.isDrawFontRec = False          # fill a background rectangle behind the text
        self.isShowFont = False             # draw the label text at all
        # NOTE: despite its name, this flag toggles the *class name* in the label
        # text (kept for backward compatibility). The confidence is shown
        # whenever the label line carries a 6th column.
        self.isShowConfidence = False
        self.label_path = label_path
        self.img_path = img_path
        self.result_path = result_path
        self.label_files = self._list_files(label_path)
        self.img_files = self._list_files(img_path)

    @staticmethod
    def _sort_key(name):
        """Sort key: numeric stems ("2.jpg") in numeric order first, then other names alphabetically."""
        stem = os.path.splitext(name)[0]
        try:
            return (0, int(stem), name)
        except ValueError:
            # e.g. "classes.txt" or "img_01.jpg": previously crashed the constructor.
            return (1, 0, name)

    @classmethod
    def _list_files(cls, directory):
        """Return the regular files (no sub-directories) in *directory*, ordered by ``_sort_key``."""
        names = [n for n in os.listdir(directory) if os.path.isfile(os.path.join(directory, n))]
        return sorted(names, key=cls._sort_key)

    @staticmethod
    def _pair_by_stem(label_files, img_files):
        """Pair each image with the label file of the same base name.

        Returns a list of ``(img_file, label_file_or_None)`` in ``img_files``
        order. Matching by name (instead of zipping two sorted lists) keeps one
        missing or extra file from shifting every later pair. When several
        label files share a stem, a ``.txt`` one is preferred.
        """
        labels_by_stem = {}
        for name in label_files:
            stem, ext = os.path.splitext(name)
            if stem not in labels_by_stem or ext.lower() == '.txt':
                labels_by_stem[stem] = name
        return [(img, labels_by_stem.get(os.path.splitext(img)[0])) for img in img_files]

    @staticmethod
    def _parse_line(line, img_w, img_h):
        """Convert one label line to pixel coordinates.

        Returns ``(class_id, x_min, y_min, x_max, y_max, confidence)`` where
        ``confidence`` is ``None`` if the line has no 6th column, or ``None``
        for a blank line.

        Raises:
            ValueError: if the line has fewer than 5 fields or a field is not numeric.
        """
        fields = line.split()  # split() also tolerates repeated spaces, tabs and '\r'
        if not fields:
            return None
        if len(fields) < 5:
            raise ValueError('expected at least 5 fields, got {!r}'.format(line))
        try:
            cls_id = int(fields[0])
            rel_x, rel_y, rel_w, rel_h = (float(v) for v in fields[1:5])
            conf = float(fields[5]) if len(fields) > 5 else None
        except ValueError as err:
            raise ValueError('malformed label line {!r}'.format(line)) from err
        # NOTE(review): the +1 shifts boxes by one pixel; presumably it undoes a
        # -1 applied by whatever produced the labels. Confirm before relying on
        # pixel-exact boxes.
        x_center = rel_x * img_w + 1
        y_center = rel_y * img_h + 1
        half_w = 0.5 * rel_w * img_w
        half_h = 0.5 * rel_h * img_h
        return (cls_id,
                int(x_center - half_w), int(y_center - half_h),
                int(x_center + half_w), int(y_center + half_h),
                conf)

    def _label_text(self, cls_id, conf):
        """Build the text drawn next to a box: optional class name plus optional confidence."""
        parts = []
        if self.isShowConfidence:
            if 0 <= cls_id < len(self.classes):
                parts.append(str(self.classes[cls_id]))
            else:
                # Unknown id: show the raw id rather than crash (or wrap around on negatives).
                parts.append(str(cls_id))
        if conf is not None:
            parts.append('{:.3f}'.format(conf))
        return ' '.join(parts)

    def paint(self, imgName, pos):
        """Draw the boxes described by the label lines *pos* on image *imgName*.

        The image is read from ``img_path`` and the result is written to
        ``result_path`` under the same name. Unreadable images and malformed
        label lines are reported with ``print`` and skipped.
        """
        img = cv2.imread(os.path.join(self.img_path, imgName))
        if img is None:  # cv2.imread returns None instead of raising
            print('skip {}: cannot read image'.format(imgName))
            return
        img_h, img_w = img.shape[:2]
        for line in pos:
            try:
                box = self._parse_line(line, img_w, img_h)
            except ValueError as err:
                print('skip bad label line in {}: {}'.format(imgName, err))
                continue
            if box is None:
                continue
            cls_id, x_min, y_min, x_max, y_max, conf = box
            cv2.rectangle(img, (x_min, y_min), (x_max, y_max), self.rec_color, self.line_width)
            text = self._label_text(cls_id, conf)
            if not text:
                continue
            org_x = x_min + self.font_x_offset
            org_y = y_min + self.font_y_offset  # putText anchors at the text's bottom-left
            if self.isDrawFontRec:
                # Size the filled background from the real text extent so it covers the text.
                (text_w, text_h), baseline = cv2.getTextSize(
                    text, self.font, self.font_size, self.font_thickness)
                cv2.rectangle(img, (org_x, org_y - text_h - baseline),
                              (org_x + text_w, org_y + baseline), self.rec_color, -1)
            if self.isShowFont:
                cv2.putText(img, text, (org_x, org_y), self.font, self.font_size,
                            self.font_color, self.font_thickness)
        out_file = os.path.join(self.result_path, imgName)
        if not cv2.imwrite(out_file, img):  # imwrite reports failure via its return value
            print('failed to write {}'.format(out_file))

    def process(self):
        """Annotate every image that has a matching label file and save it to ``result_path``."""
        os.makedirs(self.result_path, exist_ok=True)
        for img_file, label_file in self._pair_by_stem(self.label_files, self.img_files):
            if label_file is None:
                print('skip {}: no matching label file'.format(img_file))
                continue
            print(img_file, label_file)
            with open(os.path.join(self.label_path, label_file), 'r', encoding='utf-8') as f:
                lines = f.read().splitlines()
            self.paint(img_file, lines)