在使用 MMDetection 训练之前，需要先对数据集进行可视化检查，确认图像与标注是否对齐。
# 数据集可视化
import os
import matplotlib.pyplot as plt
from PIL import Image
# Preview the first few images of a dataset folder in a 4x2 grid.
# Placeholder: replace with the folder that holds your images.
image_dir = r"E:\****************************"
# Only files with these extensions are shown; anything else in the folder
# (annotation files, thumbnails, ...) would make Image.open fail.
image_exts = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp')
num_show = 8  # the 4x2 subplot grid below holds at most 8 images

plt.figure(figsize=(16, 12))
# Sort so the selection is reproducible: os.listdir returns names in
# arbitrary order.
image_paths = sorted(
    f for f in os.listdir(image_dir) if f.lower().endswith(image_exts)
)[:num_show]
for i, filename in enumerate(image_paths):
    # The context manager closes the file; convert() loads the pixels first,
    # so the returned image stays usable after the file is closed.
    with Image.open(os.path.join(image_dir, filename)) as img:
        image = img.convert("RGB")
    plt.subplot(4, 2, i + 1)
    plt.imshow(image)
    plt.title(filename)
    plt.xticks([])  # hide axis ticks
    plt.yticks([])
plt.tight_layout()
plt.show()
以上代码提供了数据集图片的预览功能，使用前需将其中的路径替换为自己的图片文件夹路径。
以下代码提供了 COCO 格式数据集的图片与标注联合显示功能：从数据集中随机选取 8 张图片进行展示，用于检查图片与标注框是否对齐。使用前需填入标注 JSON 文件的路径和图片所在文件夹的路径。
# COCO 数据集可视化
from pycocotools.coco import COCO
import numpy as np
import os.path as osp
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon
from PIL import Image
def apply_exif_orientation(image):
    """Return *image* transposed according to its EXIF Orientation tag.

    Orientation values 2-8 are mapped to the PIL transpose operation that
    undoes them; any other value (including a missing tag) leaves the image
    untouched.

    Args:
        image: A PIL image. Objects without a ``getexif`` method are
            returned unchanged.

    Returns:
        A new, transposed image, or ``image`` itself when no transpose
        applies. EXIF read failures are treated as "no EXIF".
    """
    orientation_tag = 274  # EXIF "Orientation" tag id (0x0112)

    if not hasattr(image, 'getexif'):
        return image

    try:
        exif_data = image.getexif()
    except Exception:
        # Best effort: unreadable EXIF must not stop the visualization.
        return image
    if exif_data is None:
        return image

    orientation = exif_data.get(orientation_tag)
    transpose_for = {
        2: Image.FLIP_LEFT_RIGHT,
        3: Image.ROTATE_180,
        4: Image.FLIP_TOP_BOTTOM,
        5: Image.TRANSPOSE,
        6: Image.ROTATE_270,
        7: Image.TRANSVERSE,
        8: Image.ROTATE_90,
    }
    op = transpose_for.get(orientation)
    return image if op is None else image.transpose(op)
def show_bbox_only(coco, anns, show_label_bbox=True, is_filling=True):
    """Draw COCO bounding boxes and category labels on the current axes.

    Args:
        coco: A ``pycocotools.coco.COCO`` object, used for ``getCatIds`` and
            ``loadCats`` (category id -> name).
        anns: Annotation dicts, e.g. from ``coco.loadAnns``. Each one needs
            ``'bbox'`` as ``[x_min, y_min, width, height]`` and
            ``'category_id'``.
        show_label_bbox: Draw a coloured background box behind each label.
        is_filling: Also fill each box with a translucent (alpha 0.4) colour.

    Returns:
        None. Draws onto ``plt.gca()``; does nothing when ``anns`` is empty
        or None.
    """
    if not anns:
        return
    ax = plt.gca()
    # Disable autoscaling so adding the box patches does not change the
    # current axes limits (e.g. those set by an earlier imshow call).
    ax.set_autoscale_on(False)

    # One random light colour per category id (components in [0.3, 1.0)).
    # Colours are regenerated on every call, so they differ between subplots.
    cat2color = {
        cat_id: (np.random.random((1, 3)) * 0.7 + 0.3).tolist()[0]
        for cat_id in coco.getCatIds()
    }

    polygons = []
    colors = []
    for ann in anns:
        color = cat2color[ann['category_id']]
        x, y, w, h = ann['bbox']
        corners = [[x, y], [x, y + h], [x + w, y + h], [x + w, y]]
        polygons.append(Polygon(np.array(corners)))
        colors.append(color)
        cat_name = coco.loadCats(ann['category_id'])[0]['name']
        ax.text(
            x,
            y,
            str(cat_name),
            color='white',
            bbox=dict(facecolor=color) if show_label_bbox else None,
        )

    if is_filling:
        ax.add_collection(PatchCollection(
            polygons, facecolor=colors, linewidths=0, alpha=0.4))
    # Outline pass. The face must be the string 'none' (transparent):
    # facecolor=None makes matplotlib fall back to its default patch colour,
    # which together with linewidths=0 paints every box solid and draws no
    # outline at all.
    ax.add_collection(PatchCollection(
        polygons, facecolor='none', edgecolors=colors, linewidths=2))
# Show up to 8 random images from a COCO-format dataset with their boxes.
# Placeholders: point these at your annotation JSON and image folder.
coco = COCO(r'E:\*******保存的json文件夹\test.json')
image_root = r'E:\保存的图片文件夹'
num_show = 8  # fills the 4x2 subplot grid

image_ids = coco.getImgIds()
np.random.shuffle(image_ids)  # different random sample on each run
plt.figure(figsize=(16, 12))
# Slicing (instead of range(8)) avoids an IndexError on datasets with
# fewer than num_show images.
for i, image_id in enumerate(image_ids[:num_show]):
    image_data = coco.loadImgs(image_id)[0]
    image_path = osp.join(image_root, image_data['file_name'])
    # With pycocotools, an empty catIds list applies no category filter;
    # iscrowd=0 keeps only non-crowd annotations.
    annotation_ids = coco.getAnnIds(
        imgIds=image_data['id'], catIds=[], iscrowd=0
    )
    annotations = coco.loadAnns(annotation_ids)
    ax = plt.subplot(4, 2, i + 1)
    # Apply the EXIF orientation to the freshly opened image, before
    # convert(), and close the file once the pixels have been loaded.
    with Image.open(image_path) as img:
        image = apply_exif_orientation(img).convert('RGB')
    ax.imshow(image)
    show_bbox_only(coco, annotations)
    # Title each subplot with its own image's file name from the COCO record.
    plt.title(image_data['file_name'])
    plt.xticks([])
    plt.yticks([])
plt.tight_layout()
plt.show()