文章目录
- 1. COCO数据读取
- 1.1 COCO数据集目录结构
- 1.2 pycocotools的使用
- 1.3 COCODetection类
- `__init__`方法
- `__getitem__`方法
- 2. VOC数据读取
- 2.1 VOC数据集目录结构
- 2.2 VOCInstances类
- `__init__`方法
- 2.3 `__getitem__`方法
- 参考
1. COCO数据读取
1.1 COCO数据集目录结构
下载并解压COCO数据集,可得到如下文件夹结构:
├── coco2017: 数据集根目录
├── train2017: 所有训练图像文件夹(118287张)
├── val2017: 所有验证图像文件夹(5000张)
└── annotations: 对应标注文件夹
├── instances_train2017.json: 对应目标检测、分割任务的训练集标注文件
├── instances_val2017.json: 对应目标检测、分割任务的验证集标注文件
├── captions_train2017.json: 对应图像描述的训练集标注文件
├── captions_val2017.json: 对应图像描述的验证集标注文件
├── person_keypoints_train2017.json: 对应人体关键点检测的训练集标注文件
└── person_keypoints_val2017.json: 对应人体关键点检测的验证集标注文件夹
- 在实例分割任务中,
train2017
保存所有训练图像 val2017
保存所有的验证图像- 对于标注文件,只使用到了
instances_train2017.json
和instances_val2017.json
两个标签文件。
1.2 pycocotools的使用
参考博文:MS COCO数据集介绍以及pycocotools简单使用
官方有给出一个读取MS COCO数据集信息的API(当然,该API还有其他重要功能),下面是对应github的连接,里面有关于该API的使用demo:
https://github.com/cocodataset/cocoapi
读取每张图片的bbox信息
下面是使用pycocotools
读取图像以及对应bbox信息的简单示例:
import os
from pycocotools.coco import COCO
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
json_path = "/data/coco2017/annotations/instances_val2017.json"
img_path = "/data/coco2017/val2017"
# load coco data
coco = COCO(annotation_file=json_path)
# get all image index info
ids = list(sorted(coco.imgs.keys()))
print("number of images: {}".format(len(ids)))
# get all coco class labels
coco_classes = dict([(v["id"], v["name"]) for k, v in coco.cats.items()])
# 遍历前三张图像
for img_id in ids[:3]:
# 获取对应图像id的所有annotations idx信息
ann_ids = coco.getAnnIds(imgIds=img_id)
# 根据annotations idx信息获取所有标注信息
targets = coco.loadAnns(ann_ids)
# get image file name
path = coco.loadImgs(img_id)[0]['file_name']
# read image
img = Image.open(os.path.join(img_path, path)).convert('RGB')
draw = ImageDraw.Draw(img)
# draw box to image
for target in targets:
x, y, w, h = target["bbox"]
x1, y1, x2, y2 = x, y, int(x + w), int(y + h)
draw.rectangle((x1, y1, x2, y2))
draw.text((x1, y1), coco_classes[target["category_id"]])
# show image
plt.imshow(img)
plt.show()
通过pycocotools
读取的图像以及对应的targets
信息,配合matplotlib
库绘制标注图像如下:
读取每张图像的segmentation信息
下面是使用pycocotools
读取图像segmentation
信息的简单示例:
import os
import random
import numpy as np
from pycocotools.coco import COCO
from pycocotools import mask as coco_mask
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
random.seed(0)
json_path = "/data/coco2017/annotations/instances_val2017.json"
img_path = "/data/coco2017/val2017"
# random pallette
pallette = [0, 0, 0] + [random.randint(0, 255) for _ in range(255*3)]
# load coco data
coco = COCO(annotation_file=json_path)
# get all image index info
ids = list(sorted(coco.imgs.keys()))
print("number of images: {}".format(len(ids)))
# get all coco class labels
coco_classes = dict([(v["id"], v["name"]) for k, v in coco.cats.items()])
# 遍历前三张图像
for img_id in ids[:3]:
# 获取对应图像id的所有annotations idx信息
ann_ids = coco.getAnnIds(imgIds=img_id)
# 根据annotations idx信息获取所有标注信息
targets = coco.loadAnns(ann_ids)
# get image file name
path = coco.loadImgs(img_id)[0]['file_name']
# read image
img = Image.open(os.path.join(img_path, path)).convert('RGB')
img_w, img_h = img.size
masks = []
cats = []
for target in targets:
cats.append(target["category_id"]) # get object class id
polygons = target["segmentation"] # get object polygons
rles = coco_mask.frPyObjects(polygons, img_h, img_w)
mask = coco_mask.decode(rles)
if len(mask.shape) < 3:
mask = mask[..., None]
mask = mask.any(axis=2)
masks.append(mask)
cats = np.array(cats, dtype=np.int32)
if masks:
masks = np.stack(masks, axis=0)
else:
masks = np.zeros((0, height, width), dtype=np.uint8)
# merge all instance masks into a single segmentation map
# with its corresponding categories
target = (masks * cats[:, None, None]).max(axis=0)
# discard overlapping instances
target[masks.sum(0) > 1] = 255
target = Image.fromarray(target.astype(np.uint8))
target.putpalette(pallette)
plt.imshow(target)
plt.show()
通过pycocotools
读取的图像segmentation
信息,配合matplotlib库绘制标注图像如下:
读取人体关键点信息
在MS COCO任务中,对每个人体
都标注了17的关键点
,这17个关键点的部位分别如下:
["nose","left_eye","right_eye","left_ear","right_ear","left_shoulder","right_shoulder","left_elbow","right_elbow","left_wrist","right_wrist","left_hip","right_hip","left_knee","right_knee","left_ankle","right_ankle"]
在COCO给出的标注文件中,针对每个人体的标注格式如下所示。其中每3个值为一个关键点的相关信息,因为有17个关键点所以总共有51个数值。按照3个一组进行划分,前2个值代表关键点的x,y坐标,第3个值代表该关键点的可见度,它只会取{ 0 , 1 , 2 }
三个值。0表示该点一般是在图像外无法标注,1表示虽然该点不可见但大概能猜测出位置(比如人侧着站时虽然有一只耳朵被挡住了,但大概也能猜出位置),2表示该点可见。如果第3个值为0,那么对应的x,y也都等于0:
下面是使用pycocotools
读取图像keypoints
信息的简单示例:
import numpy as np
from pycocotools.coco import COCO
json_path = "/data/coco2017/annotations/person_keypoints_val2017.json"
coco = COCO(json_path)
img_ids = list(sorted(coco.imgs.keys()))
# 遍历前5张图片中的人体关键点信息(注意,并不是每张图片里都有人体信息)
for img_id in img_ids[:5]:
idx = 0
img_info = coco.loadImgs(img_id)[0]
ann_ids = coco.getAnnIds(imgIds=img_id)
anns = coco.loadAnns(ann_ids)
for ann in anns:
xmin, ymin, w, h = ann['bbox']
# 打印人体bbox信息
print(f"[image id: {img_id}] person {idx} bbox: [{xmin:.2f}, {ymin:.2f}, {xmin + w:.2f}, {ymin + h:.2f}]")
keypoints_info = np.array(ann["keypoints"]).reshape([-1, 3])
visible = keypoints_info[:, 2]
keypoints = keypoints_info[:, :2]
# 打印关键点信息以及可见度信息
print(f"[image id: {img_id}] person {idx} keypoints: {keypoints.tolist()}")
print(f"[image id: {img_id}] person {idx} keypoints visible: {visible.tolist()}")
idx += 1
终端输出信息如下,通过以下信息可知,验证集中前5张图片里只有一张图片包含人体关键点信息:
[image id: 139] person 0 bbox: [412.80, 157.61, 465.85, 295.62]
[image id: 139] person 0 keypoints: [[427, 170], [429, 169], [0, 0], [434, 168], [0, 0], [441, 177], [446, 177], [437, 200], [430, 206], [430, 220], [420, 215], [445, 226], [452, 223], [447, 260], [454, 257], [455, 290], [459, 286]]
[image id: 139] person 0 keypoints visible: [1, 2, 0, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
[image id: 139] person 1 bbox: [384.43, 172.21, 399.55, 207.95]
[image id: 139] person 1 keypoints: [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0]]
[image id: 139] person 1 keypoints visible: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
1.3 COCODetection类
__init__
方法
def __init__(self, root, dataset="train", transforms=None, years="2017"):
super(CocoDetection, self).__init__()
assert dataset in ["train", "val"], 'dataset must be in ["train", "val"]'
anno_file = f"instances_{dataset}{years}.json"
assert os.path.exists(root), "file '{}' does not exist.".format(root)
self.img_root = os.path.join(root, f"{dataset}{years}")
assert os.path.exists(self.img_root), "path '{}' does not exist.".format(self.img_root)
self.anno_path = os.path.join(root, "annotations", anno_file)
assert os.path.exists(self.anno_path), "file '{}' does not exist.".format(self.anno_path)
self.mode = dataset
self.transforms = transforms
self.coco = COCO(self.anno_path)
# 获取coco数据索引与类别名称的关系
# 注意在object80中的索引并不是连续的,虽然只有80个类别,但索引还是按照stuff91来排序的
data_classes = dict([(v["id"], v["name"]) for k, v in self.coco.cats.items()])
max_index = max(data_classes.keys()) # 90
# 将缺失的类别名称设置成N/A
coco_classes = {}
for k in range(1, max_index + 1):
if k in data_classes:
coco_classes[k] = data_classes[k]
else:
coco_classes[k] = "N/A"
if dataset == "train":
json_str = json.dumps(coco_classes, indent=4)
with open("coco91_indices.json", "w") as f:
f.write(json_str)
self.coco_classes = coco_classes
ids = list(sorted(self.coco.imgs.keys()))
if dataset == "train":
# 移除没有目标,或者目标面积非常小的数据
valid_ids = coco_remove_images_without_annotations(self.coco, ids)
self.ids = valid_ids
else:
self.ids = ids
root
:解压后COCO2017
文件夹的根目录,dataset
: 指定读取训练集(train)还是验证集(val);trainsform
: 指定的数据增强方式;years
: coco数据集的年份,默认设置为2017
- 构建实例分割标注文件的名称,对应annotations文件夹下的
instance_train2017.json
和instance_val2017.json
,所以我们可以根据传入dataset对应train还是val,构建标注文件名。
anno_file = f"instances_{dataset}{years}.json"
- 构建
img_root
,anno_path
, 并assert判断是否存在,如果不存在则报错提示。 - 实例化
COCO
类,传入标注文件路径anno_path
- 接下来获取coco类的索引与类的关系,在COCO数据集中它的类别是按照
stuff91
类别来排序的,通过遍历self.coco.cats.itmes()
获取类别与索引的关系。调用max_index = max(data_classes.keys())
可以看到最大的索引为90
。data_classes.keys()
可以看出类别的索引并不是连续的,11直接到13,25直接到27。
dict_keys([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90])
- 接下来就重新构建了索引与类别的关系,将1,90索引中缺失的类别设置为
N/A
for k in range(1, max_index + 1):
if k in data_classes:
coco_classes[k] = data_classes[k]
else:
coco_classes[k] = "N/A"
并将coco_classs键值对,写入到coco91_indices.json
文件中
- 接下来保存所有图片的id,并移除图片中没有目标,或者面积非常小的图片
def coco_remove_images_without_annotations(dataset, ids):
"""
删除coco数据集中没有目标,或者目标面积非常小的数据
refer to:
https://github.com/pytorch/vision/blob/master/references/detection/coco_utils.py
:param dataset:
:param cat_list:
:return:
"""
// 长宽<1
def _has_only_empty_bbox(anno):
return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)
def _has_valid_annotation(anno):
# if it's empty, there is no annotation
if len(anno) == 0:
return False
# if all boxes have close to zero area, there is no annotation
if _has_only_empty_bbox(anno):
return False
return True
valid_ids = []
for ds_idx, img_id in enumerate(ids):
ann_ids = dataset.getAnnIds(imgIds=img_id, iscrowd=None)
anno = dataset.loadAnns(ann_ids)
if _has_valid_annotation(anno):
valid_ids.append(img_id)
return valid_ids
__getitem__
方法
- 根据传入的索引
index
,获得img_id,根据img_id获取当前图片对应的ann_ids
, 根据anno_ids
获得该张图片的所有标注信息coco_target。设置断点查看获取的coco_target数据。
- 首先
coco_target
是个list列表,每个元素是一个dict字典,每个字典包含有:segmentation
,area
,iscrowd
,image_id
,is_crowed
等,其中is_crowed
为0表示单独的目标,为1的话代表可能有多个目标重叠在一起。 - 这里需要关注segmentation的信息,coco数据集对每个分割目标,
通过一个个点标注出分割的多边形轮廓
。后面需要一个个点
组成的多边形轮廓转换
为我们所需要的mask信息
。segmentation标注的一个个点的信息示例如下:
- 接下来读取图片转换为RGB,并利用
parse_targets
将coco_target转换为我们需要的target输出。
def parse_targets(self,
img_id: int,
coco_targets: list,
w: int = None,
h: int = None):
assert w > 0
assert h > 0
# 只筛选出单个对象的情况
anno = [obj for obj in coco_targets if obj['iscrowd'] == 0]
boxes = [obj["bbox"] for obj in anno]
# guard against no boxes via resizing
boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
# [xmin, ymin, w, h] -> [xmin, ymin, xmax, ymax]
boxes[:, 2:] += boxes[:, :2]
boxes[:, 0::2].clamp_(min=0, max=w)
boxes[:, 1::2].clamp_(min=0, max=h)
classes = [obj["category_id"] for obj in anno]
classes = torch.tensor(classes, dtype=torch.int64)
area = torch.tensor([obj["area"] for obj in anno])
iscrowd = torch.tensor([obj["iscrowd"] for obj in anno])
segmentations = [obj["segmentation"] for obj in anno]
masks = convert_coco_poly_mask(segmentations, h, w)
# 筛选出合法的目标,即x_max>x_min且y_max>y_min
keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
boxes = boxes[keep]
classes = classes[keep]
masks = masks[keep]
area = area[keep]
iscrowd = iscrowd[keep]
target = {}
target["boxes"] = boxes
target["labels"] = classes
target["masks"] = masks
target["image_id"] = torch.tensor([img_id])
# for conversion to coco api
target["area"] = area
target["iscrowd"] = iscrowd
return target
- 只保留iscrowd=0,也就是目标之间没有重叠的。iscrwed=1,一般是比较难处理
- 将输出的boxes的每个bdbox的形式由
[xmin,ymin,w,h]
,转换为[xmin,ymin,xmax,ymax]
,然后利用clamp将x坐标限制在0,w之间,将y坐标限制在0,h之间
boxes[:, 2:] += boxes[:, :2]
boxes[:, 0::2].clamp_(min=0, max=w)
boxes[:, 1::2].clamp_(min=0, max=h)
构建mask
- 然后获取标注文件的目标的类别信息classes,iscrowed信息,segmentation信息。然后将segmentation信息转换为masks。
classes = [obj["category_id"] for obj in anno]
classes = torch.tensor(classes, dtype=torch.int64)
area = torch.tensor([obj["area"] for obj in anno])
iscrowd = torch.tensor([obj["iscrowd"] for obj in anno])
segmentations = [obj["segmentation"] for obj in anno]
masks = convert_coco_poly_mask(segmentations, h, w)
def convert_coco_poly_mask(segmentations, height, width):
masks = []
for polygons in segmentations:
rles = coco_mask.frPyObjects(polygons, height, width)
mask = coco_mask.decode(rles)
if len(mask.shape) < 3:
mask = mask[..., None]
mask = torch.as_tensor(mask, dtype=torch.uint8)
mask = mask.any(dim=2)
masks.append(mask)
if masks:
masks = torch.stack(masks, dim=0)
else:
# 如果mask为空,则说明没有目标,直接返回数值为0的mask
masks = torch.zeros((0, height, width), dtype=torch.uint8)
return masks
- 首先通过遍历
segmentations
,获取每个目标的多边形信息 - 然后通过
coco_mask
将多边形坐标转换为mask, 使用这个方法需要导入from pycocotools import mask as coco_mask
。
for polygons in segmentations:
rles = coco_mask.frPyObjects(polygons, height, width)
mask = coco_mask.decode(rles)
- 在mask出打断点,可以看出其shape大小为(480,640,1), 查看mask0通道的信息
make[:,:,0]
,可以看到这里的mask蒙版,背景
对应的是0
,前景
对应的是1
填充。如果mark的shape小于3,也就是说只有高宽没有通道信息,就添加一个channel信息。mask =mask[...,None]
- 通道方向只要有一个值为1,就认为是前景 mask =mask.any(dim=2), 然后将所有mask进行stack拼接,并返回。如果没有mask,直接使用torch.zeros构建一个全是背景的masks
def convert_coco_poly_mask(segmentations, height, width):
masks = []
for polygons in segmentations:
rles = coco_mask.frPyObjects(polygons, height, width)
mask = coco_mask.decode(rles)
if len(mask.shape) < 3:
mask = mask[..., None]
mask = torch.as_tensor(mask, dtype=torch.uint8)
mask = mask.any(dim=2)
masks.append(mask)
if masks:
masks = torch.stack(masks, dim=0)
else:
# 如果mask为空,则说明没有目标,直接返回数值为0的mask
masks = torch.zeros((0, height, width), dtype=torch.uint8)
return masks
- 然后筛选合适的targes,即x_max > x_min并且y_max > y_min, 然后针对这些合法的目标提取出boxes,classes,masks,area,iscrowed,并保存到targe字典中进行返回。
keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
boxes = boxes[keep]
classes = classes[keep]
masks = masks[keep]
area = area[keep]
iscrowd = iscrowd[keep]
target = {}
target["boxes"] = boxes
target["labels"] = classes
target["masks"] = masks
target["image_id"] = torch.tensor([img_id])
# for conversion to coco api
target["area"] = area
target["iscrowd"] = iscrowd
最后__getitem__
返回image,以及我们构建好的target信息。
2. VOC数据读取
2.1 VOC数据集目录结构
下载VOC2012数据集,并解压得到如下的文件目录结构
VOCdevkit
└── VOC2012
├── Annotations 所有的图像标注信息(XML文件)
├── ImageSets
│ ├── Action 人的行为动作图像信息
│ ├── Layout 人的各个部位图像信息
│ │
│ ├── Main 目标检测分类图像信息
│ │ ├── train.txt 训练集(5717)
│ │ ├── val.txt 验证集(5823)
│ │ └── trainval.txt 训练集+验证集(11540)
│ │
│ └── Segmentation 目标分割图像信息
│ ├── train.txt 训练集(1464)
│ ├── val.txt 验证集(1449)
│ └── trainval.txt 训练集+验证集(2913)
│
├── JPEGImages 所有图像文件
├── SegmentationClass 语义分割png图(基于类别)
└── SegmentationObject 实例分割png图(基于目标)
在实例分割任务中,我们需要用到Annotation
文件夹下的XML
标注信息,还需要使用到 ImageSets->Segmentation
下的train.txt以及val.txt, 这两个文件记录训练和验证时所需采用的图像名称(名称没有包含后缀)。还需要使用到JPEGImages
所有需要使用到的训练图片。最后还需要用到实例分割的png图片SegmentationObject
, 参考博文:
参考博文: PASCAL VOC2012数据集介绍
在标注图片中,每个目标的像素究竟代表的是什么含义?
,下图可以很好的解释:
首先,需要看下原图对应的xml
标注文件,在xml文件中给出了目标的信息
,以上图的标注文件为例,总共标注了4个目标(目标1,目标2,目标3,目标4)。根据目标1
它的boundingbox信息,可以知道它对应的是分割图片上的红色目标,红色目标它的像素值都是1,刚好和xml中标注每个目标的顺序保持一致的
。同理第二个目标小飞机,对应分割区域的像素值都是为2的,同理目标3,目标4也是这样。
2.2 VOCInstances类
相关代码在my_dataset_voc.py
中,定义了一个VOCInstances类
__init__
方法
- 传入参数:
voc_root
,year
,txt_name
,transform
,year代表voc数据的年份,txt_name指的是segmentation下的train.txt或者val.txt - 获取
image_dir,xml_dir,mask_dir
,分别对应VOC2012下的JPEGImages,Annotations,Segmentationobject
文件夹的路径,并assert判断文件是否存在,不存在则报错提醒。 - 获取所有images的路径images_url,所有的xmls的路径,所有mask的路径。
# 检查图片、xml文件以及mask是否都在
images_path = [os.path.join(image_dir, x + ".jpg") for x in file_names]
xmls_path = [os.path.join(xml_dir, x + '.xml') for x in file_names]
masks_path = [os.path.join(mask_dir, x + ".png") for x in file_names]
1) 解析xml中的box信息
- 读取
xml
文件为字符串,然后利用etree.fromstring
转换为xml的Element 对象。并通过parse_xml_to_dict
转换为dict。
with open(xml_path) as fid:
xml_str = fid.read()
xml = etree.fromstring(xml_str)
obs_dict = parse_xml_to_dict(xml)["annotation"] # 将xml文件解析成字典
其中关于 parse_xml_to_dict
的代码如下:
def parse_xml_to_dict(xml):
if len(xml) == 0: # 遍历到底层,直接返回tag对应的信息
return {xml.tag: xml.text}
result = {}
for child in xml:
child_result = parse_xml_to_dict(child) # 递归遍历标签信息
if child.tag != 'object':
result[child.tag] = child_result[child.tag]
else:
if child.tag not in result: # 因为object可能有多个,所以需要放入列表里
result[child.tag] = []
result[child.tag].append(child_result[child.tag])
return {xml.tag: result}
- 接下来通过
parse_objects
方法,将解析得到字典传入给parse_objects
,解析出字典当中每个目标的boundingbox。
def parse_objects(data: dict, xml_path: str, class_dict: dict, idx: int):
"""
解析出bboxes、labels、iscrowd以及ares等信息
Args:
data: 将xml解析成dict的Annotation数据
xml_path: 对应xml的文件路径
class_dict: 类别与索引对应关系
idx: 图片对应的索引
Returns:
"""
boxes = []
labels = []
iscrowd = []
assert "object" in data, "{} lack of object information.".format(xml_path)
for obj in data["object"]:
xmin = float(obj["bndbox"]["xmin"])
xmax = float(obj["bndbox"]["xmax"])
ymin = float(obj["bndbox"]["ymin"])
ymax = float(obj["bndbox"]["ymax"])
# 进一步检查数据,有的标注信息中可能有w或h为0的情况,这样的数据会导致计算回归loss为nan
if xmax <= xmin or ymax <= ymin:
print("Warning: in '{}' xml, there are some bbox w/h <=0".format(xml_path))
continue
boxes.append([xmin, ymin, xmax, ymax])
labels.append(int(class_dict[obj["name"]]))
if "difficult" in obj:
iscrowd.append(int(obj["difficult"]))
else:
iscrowd.append(0)
# convert everything into a torch.Tensor
boxes = torch.as_tensor(boxes, dtype=torch.float32)
labels = torch.as_tensor(labels, dtype=torch.int64)
iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
image_id = torch.tensor([idx])
area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
return {"boxes": boxes,
"labels": labels,
"iscrowd": iscrowd,
"image_id": image_id,
"area": area}
- 遍历object获得box信息,并判断
xmax <= xmin or ymax <= ymin:
是否为true,如果为true说明获得的box面积为负,需要跳过该目标。如果box面积不为负,则继续获取object的label,iscrowd,area等信息。 - 最后将这张图片获取的
labels,iscrowd,image_id,area
打包为一个字典,并返回。
2) 获取mask信息
- 计算该张图片的目标数:
obs_bboxes["boxes"].shape[0]
- 利用
Image.open
读取分割图片,以P
格式(调色板)模式读取的,读取的是一张单通
道的图片,并转换为numpy格式。将mask中像素值为255
的区域忽略掉
(255区域为背景或者难分割的区域)
# 读取SegmentationObject并检查是否和bboxes信息数量一致
instances_mask = Image.open(mask_path)
instances_mask = np.array(instances_mask)
instances_mask[instances_mask == 255] = 0 # 255为背景或者忽略掉的地方,这里为了方便直接设置为背景(0)
需要检查一下标注的bbox个数是否和instances个数一致
。前文介绍过分割图中每个目标的像素值等于xml中每个目标的顺序。也就是说有多少个目标,分割图片中最大的像素就是多少。- 如果xml目标数和instances的分割目标数不一样就跳过,如果一样的话,就保存mask信息。
num_instances = instances_mask.max()
if num_objs != num_instances:
print(f"warning: num_boxes:{num_objs} and num_instances:{num_instances} do not correspond. "
f"skip image:{img_path}")
continue
self.images_path.append(img_path)
self.xmls_path.append(xml_path)
self.xmls_info.append(obs_dict)
self.masks_path.append(mask_path)
self.objects_bboxes.append(obs_bboxes)
self.masks.append(instances_mask)
2.3 __getitem__
方法
- 根据
index
索引,获得图片路径,并通过Image.open
去读取它,并转换为RGB
3通道。 - 获得该图片的target信息,它有bbox组成,注意没有包括mask信息。然后将该图片的所有mask信息也加入到target中。
- 如果transoform不为None的话,就进行数据增强,并返回img和target。
def __getitem__(self, idx):
"""
Args:
idx (int): Index
Returns:
tuple: (image, target) where target is the image segmentation.
"""
img = Image.open(self.images_path[idx]).convert('RGB')
target = self.objects_bboxes[idx]
masks = self.parse_mask(idx)
target["masks"] = masks
if self.transforms is not None:
img, target = self.transforms(img, target)
return img, target
其中parse_mask的代码如下,将mask中不同目标单独保存为1个channel存放
def parse_mask(self, idx: int):
mask = self.masks[idx]
c = mask.max() # 有几个目标最大索引就等于几
masks = []
# 对每个目标的mask单独使用一个channel存放
for i in range(1, c+1):
masks.append(mask == i)
masks = np.stack(masks, axis=0)
return torch.as_tensor(masks, dtype=torch.uint8)
参考
1. PASCAL VOC2012数据集介绍
2. MS COCO数据集介绍以及pycocotools简单使用