Mask-RCNN(3) : 自定义数据集读取(VOC COCO)以及pycocotools的使用

news2024/11/29 2:44:39

文章目录

    • 1. COCO数据读取
      • 1.1 COCO数据集目录结构
      • 1.2 pycocotools的使用
      • 1.3 COCODetection类
        • `__init__`方法
        • `__getitem__`方法
    • 2. VOC数据读取
      • 2.1 VOC数据集目录结构
      • 2.2 VOCInstances类
        • `__init__`方法
        • 2.3 `__getitem__`方法
    • 参考

1. COCO数据读取

1.1 COCO数据集目录结构

下载并解压COCO数据集,可得到如下文件夹结构:

├── coco2017: 数据集根目录
     ├── train2017: 所有训练图像文件夹(118287)
     ├── val2017: 所有验证图像文件夹(5000)
     └── annotations: 对应标注文件夹
              ├── instances_train2017.json: 对应目标检测、分割任务的训练集标注文件
              ├── instances_val2017.json: 对应目标检测、分割任务的验证集标注文件
              ├── captions_train2017.json: 对应图像描述的训练集标注文件
              ├── captions_val2017.json: 对应图像描述的验证集标注文件
              ├── person_keypoints_train2017.json: 对应人体关键点检测的训练集标注文件
              └── person_keypoints_val2017.json: 对应人体关键点检测的验证集标注文件夹
  • 在实例分割任务中,train2017保存所有训练图像
  • val2017保存所有的验证图像
  • 对于标注文件,只使用到了 instances_train2017.jsoninstances_val2017.json两个标签文件。

1.2 pycocotools的使用

参考博文:MS COCO数据集介绍以及pycocotools简单使用

官方有给出一个读取MS COCO数据集信息的API(当然,该API还有其他重要功能),下面是对应github的连接,里面有关于该API的使用demo:
https://github.com/cocodataset/cocoapi

读取每张图片的bbox信息
下面是使用pycocotools读取图像以及对应bbox信息的简单示例:

import os
from pycocotools.coco import COCO
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt

json_path = "/data/coco2017/annotations/instances_val2017.json"
img_path = "/data/coco2017/val2017"

# load coco data
coco = COCO(annotation_file=json_path)

# get all image index info
ids = list(sorted(coco.imgs.keys()))
print("number of images: {}".format(len(ids)))

# get all coco class labels
coco_classes = dict([(v["id"], v["name"]) for k, v in coco.cats.items()])

# 遍历前三张图像
for img_id in ids[:3]:
    # 获取对应图像id的所有annotations idx信息
    ann_ids = coco.getAnnIds(imgIds=img_id)

    # 根据annotations idx信息获取所有标注信息
    targets = coco.loadAnns(ann_ids)

    # get image file name
    path = coco.loadImgs(img_id)[0]['file_name']

    # read image
    img = Image.open(os.path.join(img_path, path)).convert('RGB')
    draw = ImageDraw.Draw(img)
    # draw box to image
    for target in targets:
        x, y, w, h = target["bbox"]
        x1, y1, x2, y2 = x, y, int(x + w), int(y + h)
        draw.rectangle((x1, y1, x2, y2))
        draw.text((x1, y1), coco_classes[target["category_id"]])

    # show image
    plt.imshow(img)
    plt.show()

通过pycocotools读取的图像以及对应的targets信息,配合matplotlib库绘制标注图像如下:
在这里插入图片描述
读取每张图像的segmentation信息
下面是使用pycocotools读取图像segmentation信息的简单示例:


import os
import random

import numpy as np
from pycocotools.coco import COCO
from pycocotools import mask as coco_mask
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt

random.seed(0)

json_path = "/data/coco2017/annotations/instances_val2017.json"
img_path = "/data/coco2017/val2017"

# random pallette
pallette = [0, 0, 0] + [random.randint(0, 255) for _ in range(255*3)]

# load coco data
coco = COCO(annotation_file=json_path)

# get all image index info
ids = list(sorted(coco.imgs.keys()))
print("number of images: {}".format(len(ids)))

# get all coco class labels
coco_classes = dict([(v["id"], v["name"]) for k, v in coco.cats.items()])

# 遍历前三张图像
for img_id in ids[:3]:
    # 获取对应图像id的所有annotations idx信息
    ann_ids = coco.getAnnIds(imgIds=img_id)
    # 根据annotations idx信息获取所有标注信息
    targets = coco.loadAnns(ann_ids)

    # get image file name
    path = coco.loadImgs(img_id)[0]['file_name']
    # read image
    img = Image.open(os.path.join(img_path, path)).convert('RGB')
    img_w, img_h = img.size

    masks = []
    cats = []
    for target in targets:
        cats.append(target["category_id"])  # get object class id
        polygons = target["segmentation"]   # get object polygons
        rles = coco_mask.frPyObjects(polygons, img_h, img_w)
        mask = coco_mask.decode(rles)
        if len(mask.shape) < 3:
            mask = mask[..., None]
        mask = mask.any(axis=2)
        masks.append(mask)

    cats = np.array(cats, dtype=np.int32)
    if masks:
        masks = np.stack(masks, axis=0)
    else:
        masks = np.zeros((0, height, width), dtype=np.uint8)

    # merge all instance masks into a single segmentation map
    # with its corresponding categories
    target = (masks * cats[:, None, None]).max(axis=0)
    # discard overlapping instances
    target[masks.sum(0) > 1] = 255
    target = Image.fromarray(target.astype(np.uint8))

    target.putpalette(pallette)
    plt.imshow(target)
    plt.show()

通过pycocotools读取的图像segmentation信息,配合matplotlib库绘制标注图像如下:

在这里插入图片描述
读取人体关键点信息
在MS COCO任务中,对每个人体都标注了17的关键点,这17个关键点的部位分别如下:

["nose","left_eye","right_eye","left_ear","right_ear","left_shoulder","right_shoulder","left_elbow","right_elbow","left_wrist","right_wrist","left_hip","right_hip","left_knee","right_knee","left_ankle","right_ankle"]

在COCO给出的标注文件中,针对每个人体的标注格式如下所示。其中每3个值为一个关键点的相关信息,因为有17个关键点所以总共有51个数值。按照3个一组进行划分,前2个值代表关键点的x,y坐标,第3个值代表该关键点的可见度,它只会取{ 0 , 1 , 2 } 三个值。0表示该点一般是在图像外无法标注,1表示虽然该点不可见但大概能猜测出位置(比如人侧着站时虽然有一只耳朵被挡住了,但大概也能猜出位置),2表示该点可见。如果第3个值为0,那么对应的x,y也都等于0:

下面是使用pycocotools读取图像keypoints信息的简单示例:

import numpy as np
from pycocotools.coco import COCO

json_path = "/data/coco2017/annotations/person_keypoints_val2017.json"
coco = COCO(json_path)
img_ids = list(sorted(coco.imgs.keys()))

# 遍历前5张图片中的人体关键点信息(注意,并不是每张图片里都有人体信息)
for img_id in img_ids[:5]:
    idx = 0
    img_info = coco.loadImgs(img_id)[0]
    ann_ids = coco.getAnnIds(imgIds=img_id)
    anns = coco.loadAnns(ann_ids)
    for ann in anns:
        xmin, ymin, w, h = ann['bbox']
        # 打印人体bbox信息
        print(f"[image id: {img_id}] person {idx} bbox: [{xmin:.2f}, {ymin:.2f}, {xmin + w:.2f}, {ymin + h:.2f}]")
        keypoints_info = np.array(ann["keypoints"]).reshape([-1, 3])
        visible = keypoints_info[:, 2]
        keypoints = keypoints_info[:, :2]
        # 打印关键点信息以及可见度信息
        print(f"[image id: {img_id}] person {idx} keypoints: {keypoints.tolist()}")
        print(f"[image id: {img_id}] person {idx} keypoints visible: {visible.tolist()}")
        idx += 1

终端输出信息如下,通过以下信息可知,验证集中前5张图片里只有一张图片包含人体关键点信息:

[image id: 139] person 0 bbox: [412.80, 157.61, 465.85, 295.62]
[image id: 139] person 0 keypoints: [[427, 170], [429, 169], [0, 0], [434, 168], [0, 0], [441, 177], [446, 177], [437, 200], [430, 206], [430, 220], [420, 215], [445, 226], [452, 223], [447, 260], [454, 257], [455, 290], [459, 286]]
[image id: 139] person 0 keypoints visible: [1, 2, 0, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
[image id: 139] person 1 bbox: [384.43, 172.21, 399.55, 207.95]
[image id: 139] person 1 keypoints: [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0]]
[image id: 139] person 1 keypoints visible: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

1.3 COCODetection类

__init__方法

 def __init__(self, root, dataset="train", transforms=None, years="2017"):
        super(CocoDetection, self).__init__()
        assert dataset in ["train", "val"], 'dataset must be in ["train", "val"]'
        anno_file = f"instances_{dataset}{years}.json"
        assert os.path.exists(root), "file '{}' does not exist.".format(root)
        self.img_root = os.path.join(root, f"{dataset}{years}")
        assert os.path.exists(self.img_root), "path '{}' does not exist.".format(self.img_root)
        self.anno_path = os.path.join(root, "annotations", anno_file)
        assert os.path.exists(self.anno_path), "file '{}' does not exist.".format(self.anno_path)

        self.mode = dataset
        self.transforms = transforms
        self.coco = COCO(self.anno_path)

        # 获取coco数据索引与类别名称的关系
        # 注意在object80中的索引并不是连续的,虽然只有80个类别,但索引还是按照stuff91来排序的
        data_classes = dict([(v["id"], v["name"]) for k, v in self.coco.cats.items()])
        max_index = max(data_classes.keys())  # 90
        # 将缺失的类别名称设置成N/A
        coco_classes = {}
        for k in range(1, max_index + 1):
            if k in data_classes:
                coco_classes[k] = data_classes[k]
            else:
                coco_classes[k] = "N/A"

        if dataset == "train":
            json_str = json.dumps(coco_classes, indent=4)
            with open("coco91_indices.json", "w") as f:
                f.write(json_str)

        self.coco_classes = coco_classes

        ids = list(sorted(self.coco.imgs.keys()))
        if dataset == "train":
            # 移除没有目标,或者目标面积非常小的数据
            valid_ids = coco_remove_images_without_annotations(self.coco, ids)
            self.ids = valid_ids
        else:
            self.ids = ids

  • root:解压后COCO2017文件夹的根目录,dataset: 指定读取训练集(train)还是验证集(val);trainsform: 指定的数据增强方式;years: coco数据集的年份,默认设置为2017
  • 构建实例分割标注文件的名称,对应annotations文件夹下的instance_train2017.jsoninstance_val2017.json ,所以我们可以根据传入dataset对应train还是val,构建标注文件名。
 anno_file = f"instances_{dataset}{years}.json"
  • 构建img_root,anno_path, 并assert判断是否存在,如果不存在则报错提示。
  • 实例化COCO类,传入标注文件路径anno_path
  • 接下来获取coco类的索引与类的关系,在COCO数据集中它的类别是按照stuff91类别来排序的,通过遍历self.coco.cats.itmes()获取类别与索引的关系。调用max_index = max(data_classes.keys())可以看到最大的索引为90data_classes.keys()可以看出类别的索引并不是连续的,11直接到13,25直接到27。
dict_keys([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90])
  • 接下来就重新构建了索引与类别的关系,将1,90索引中缺失的类别设置为N/A
for k in range(1, max_index + 1):
      if k in data_classes:
          coco_classes[k] = data_classes[k]
      else:
          coco_classes[k] = "N/A"

并将coco_classs键值对,写入到coco91_indices.json文件中

  • 接下来保存所有图片的id,并移除图片中没有目标,或者面积非常小的图片
def coco_remove_images_without_annotations(dataset, ids):
    """
    删除coco数据集中没有目标,或者目标面积非常小的数据
    refer to:
    https://github.com/pytorch/vision/blob/master/references/detection/coco_utils.py
    :param dataset:
    :param cat_list:
    :return:
    """
    // 长宽<1
    def _has_only_empty_bbox(anno):
        return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)

    def _has_valid_annotation(anno):
        # if it's empty, there is no annotation
        if len(anno) == 0:
            return False
        # if all boxes have close to zero area, there is no annotation
        if _has_only_empty_bbox(anno):
            return False

        return True

    valid_ids = []
    for ds_idx, img_id in enumerate(ids):
        ann_ids = dataset.getAnnIds(imgIds=img_id, iscrowd=None)
        anno = dataset.loadAnns(ann_ids)

        if _has_valid_annotation(anno):
            valid_ids.append(img_id)

    return valid_ids

__getitem__方法

  • 根据传入的索引index,获得img_id,根据img_id获取当前图片对应的ann_ids, 根据anno_ids获得该张图片的所有标注信息coco_target。设置断点查看获取的coco_target数据。
    在这里插入图片描述
  • 首先coco_target是个list列表,每个元素是一个dict字典,每个字典包含有:segmentation,area,iscrowd,image_id,is_crowed等,其中is_crowed为0表示单独的目标,为1的话代表可能有多个目标重叠在一起。
  • 这里需要关注segmentation的信息,coco数据集对每个分割目标,通过一个个点标注出分割的多边形轮廓。后面需要一个个点组成的多边形轮廓转换为我们所需要的mask信息。segmentation标注的一个个点的信息示例如下:
    在这里插入图片描述
  • 接下来读取图片转换为RGB,并利用parse_targets将coco_target转换为我们需要的target输出。
 def parse_targets(self,
                      img_id: int,
                      coco_targets: list,
                      w: int = None,
                      h: int = None):
        assert w > 0
        assert h > 0

        # 只筛选出单个对象的情况
        anno = [obj for obj in coco_targets if obj['iscrowd'] == 0]

        boxes = [obj["bbox"] for obj in anno]

        # guard against no boxes via resizing
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        # [xmin, ymin, w, h] -> [xmin, ymin, xmax, ymax]
        boxes[:, 2:] += boxes[:, :2]
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)

        classes = [obj["category_id"] for obj in anno]
        classes = torch.tensor(classes, dtype=torch.int64)

        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor([obj["iscrowd"] for obj in anno])

        segmentations = [obj["segmentation"] for obj in anno]
        masks = convert_coco_poly_mask(segmentations, h, w)

        # 筛选出合法的目标,即x_max>x_min且y_max>y_min
        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]
        classes = classes[keep]
        masks = masks[keep]
        area = area[keep]
        iscrowd = iscrowd[keep]

        target = {}
        target["boxes"] = boxes
        target["labels"] = classes
        target["masks"] = masks
        target["image_id"] = torch.tensor([img_id])

        # for conversion to coco api
        target["area"] = area
        target["iscrowd"] = iscrowd

        return target
  • 只保留iscrowd=0,也就是目标之间没有重叠的。iscrwed=1,一般是比较难处理
  • 将输出的boxes的每个bdbox的形式由[xmin,ymin,w,h],转换为[xmin,ymin,xmax,ymax],然后利用clamp将x坐标限制在0,w之间,将y坐标限制在0,h之间
 boxes[:, 2:] += boxes[:, :2]
 boxes[:, 0::2].clamp_(min=0, max=w)
 boxes[:, 1::2].clamp_(min=0, max=h)

构建mask

  • 然后获取标注文件的目标的类别信息classes,iscrowed信息,segmentation信息。然后将segmentation信息转换为masks。
classes = [obj["category_id"] for obj in anno]
classes = torch.tensor(classes, dtype=torch.int64)

area = torch.tensor([obj["area"] for obj in anno])
iscrowd = torch.tensor([obj["iscrowd"] for obj in anno])

segmentations = [obj["segmentation"] for obj in anno]
masks = convert_coco_poly_mask(segmentations, h, w)
def convert_coco_poly_mask(segmentations, height, width):
    masks = []
    for polygons in segmentations:
        rles = coco_mask.frPyObjects(polygons, height, width)
        mask = coco_mask.decode(rles)
        if len(mask.shape) < 3:
            mask = mask[..., None]
        mask = torch.as_tensor(mask, dtype=torch.uint8)
        mask = mask.any(dim=2)
        masks.append(mask)
    if masks:
        masks = torch.stack(masks, dim=0)
    else:
        # 如果mask为空,则说明没有目标,直接返回数值为0的mask
        masks = torch.zeros((0, height, width), dtype=torch.uint8)
    return masks
  • 首先通过遍历segmentations,获取每个目标的多边形信息
  • 然后通过coco_mask将多边形坐标转换为mask, 使用这个方法需要导入from pycocotools import mask as coco_mask
for polygons in segmentations:
       rles = coco_mask.frPyObjects(polygons, height, width)
       mask = coco_mask.decode(rles)
  • 在mask出打断点,可以看出其shape大小为(480,640,1), 查看mask0通道的信息make[:,:,0],可以看到这里的mask蒙版,背景对应的是0前景对应的是1填充。如果mark的shape小于3,也就是说只有高宽没有通道信息,就添加一个channel信息。mask =mask[...,None]
  • 通道方向只要有一个值为1,就认为是前景 mask =mask.any(dim=2), 然后将所有mask进行stack拼接,并返回。如果没有mask,直接使用torch.zeros构建一个全是背景的masks
def convert_coco_poly_mask(segmentations, height, width):
    masks = []
    for polygons in segmentations:
        rles = coco_mask.frPyObjects(polygons, height, width)
        mask = coco_mask.decode(rles)
        if len(mask.shape) < 3:
            mask = mask[..., None]
        mask = torch.as_tensor(mask, dtype=torch.uint8)
        mask = mask.any(dim=2)
        masks.append(mask)
    if masks:
        masks = torch.stack(masks, dim=0)
    else:
        # 如果mask为空,则说明没有目标,直接返回数值为0的mask
        masks = torch.zeros((0, height, width), dtype=torch.uint8)
    return masks
  • 然后筛选合适的targes,即x_max > x_min并且y_max > y_min, 然后针对这些合法的目标提取出boxes,classes,masks,area,iscrowed,并保存到targe字典中进行返回。
keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
boxes = boxes[keep]
classes = classes[keep]
masks = masks[keep]
area = area[keep]
iscrowd = iscrowd[keep]

target = {}
target["boxes"] = boxes
target["labels"] = classes
target["masks"] = masks
target["image_id"] = torch.tensor([img_id])

# for conversion to coco api
target["area"] = area
target["iscrowd"] = iscrowd

最后__getitem__返回image,以及我们构建好的target信息。

2. VOC数据读取

2.1 VOC数据集目录结构

下载VOC2012数据集,并解压得到如下的文件目录结构

VOCdevkit
    └── VOC2012
         ├── Annotations               所有的图像标注信息(XML文件)
         ├── ImageSets
         │   ├── Action                人的行为动作图像信息
         │   ├── Layout                人的各个部位图像信息
         │   │
         │   ├── Main                  目标检测分类图像信息
         │   │     ├── train.txt       训练集(5717)
         │   │     ├── val.txt         验证集(5823)
         │   │     └── trainval.txt    训练集+验证集(11540)
         │   │
         │   └── Segmentation          目标分割图像信息
         │         ├── train.txt       训练集(1464)
         │         ├── val.txt         验证集(1449)
         │         └── trainval.txt    训练集+验证集(2913)
         │
         ├── JPEGImages                所有图像文件
         ├── SegmentationClass         语义分割png图(基于类别)
         └── SegmentationObject        实例分割png图(基于目标)

在实例分割任务中,我们需要用到Annotation文件夹下的XML标注信息,还需要使用到 ImageSets->Segmentation下的train.txt以及val.txt, 这两个文件记录训练和验证时所需采用的图像名称(名称没有包含后缀)。还需要使用到JPEGImages所有需要使用到的训练图片。最后还需要用到实例分割的png图片SegmentationObject, 参考博文:

参考博文: PASCAL VOC2012数据集介绍

在这里插入图片描述

图1 原图

在这里插入图片描述

图2 分割图

在标注图片中,每个目标的像素究竟代表的是什么含义?,下图可以很好的解释:
在这里插入图片描述
首先,需要看下原图对应的xml标注文件,在xml文件中给出了目标的信息,以上图的标注文件为例,总共标注了4个目标(目标1,目标2,目标3,目标4)。根据目标1它的boundingbox信息,可以知道它对应的是分割图片上的红色目标,红色目标它的像素值都是1,刚好和xml中标注每个目标的顺序保持一致的。同理第二个目标小飞机,对应分割区域的像素值都是为2的,同理目标3,目标4也是这样。

2.2 VOCInstances类

相关代码在my_dataset_voc.py中,定义了一个VOCInstances类

__init__方法

  • 传入参数: voc_root, year,txt_name,transform,year代表voc数据的年份,txt_name指的是segmentation下的train.txt或者val.txt
  • 获取image_dir,xml_dir,mask_dir,分别对应VOC2012下的JPEGImages,Annotations,Segmentationobject文件夹的路径,并assert判断文件是否存在,不存在则报错提醒。
  • 获取所有images的路径images_url,所有的xmls的路径,所有mask的路径。
# 检查图片、xml文件以及mask是否都在
 images_path = [os.path.join(image_dir, x + ".jpg") for x in file_names]
 xmls_path = [os.path.join(xml_dir, x + '.xml') for x in file_names]
 masks_path = [os.path.join(mask_dir, x + ".png") for x in file_names]

1) 解析xml中的box信息

  • 读取xml文件为字符串,然后利用etree.fromstring转换为xml的Element 对象。并通过 parse_xml_to_dict转换为dict。
with open(xml_path) as fid:
      xml_str = fid.read()
      xml = etree.fromstring(xml_str)
      obs_dict = parse_xml_to_dict(xml)["annotation"]  # 将xml文件解析成字典

其中关于 parse_xml_to_dict的代码如下:

def parse_xml_to_dict(xml):
    if len(xml) == 0:  # 遍历到底层,直接返回tag对应的信息
        return {xml.tag: xml.text}

    result = {}
    for child in xml:
        child_result = parse_xml_to_dict(child)  # 递归遍历标签信息
        if child.tag != 'object':
            result[child.tag] = child_result[child.tag]
        else:
            if child.tag not in result:  # 因为object可能有多个,所以需要放入列表里
                result[child.tag] = []
            result[child.tag].append(child_result[child.tag])
    return {xml.tag: result}
  • 接下来通过parse_objects方法,将解析得到字典传入给parse_objects,解析出字典当中每个目标的boundingbox。
def parse_objects(data: dict, xml_path: str, class_dict: dict, idx: int):
    """
    解析出bboxes、labels、iscrowd以及ares等信息
    Args:
        data: 将xml解析成dict的Annotation数据
        xml_path: 对应xml的文件路径
        class_dict: 类别与索引对应关系
        idx: 图片对应的索引

    Returns:

    """
    boxes = []
    labels = []
    iscrowd = []
    assert "object" in data, "{} lack of object information.".format(xml_path)
    for obj in data["object"]:
        xmin = float(obj["bndbox"]["xmin"])
        xmax = float(obj["bndbox"]["xmax"])
        ymin = float(obj["bndbox"]["ymin"])
        ymax = float(obj["bndbox"]["ymax"])

        # 进一步检查数据,有的标注信息中可能有w或h为0的情况,这样的数据会导致计算回归loss为nan
        if xmax <= xmin or ymax <= ymin:
            print("Warning: in '{}' xml, there are some bbox w/h <=0".format(xml_path))
            continue

        boxes.append([xmin, ymin, xmax, ymax])
        labels.append(int(class_dict[obj["name"]]))
        if "difficult" in obj:
            iscrowd.append(int(obj["difficult"]))
        else:
            iscrowd.append(0)

    # convert everything into a torch.Tensor
    boxes = torch.as_tensor(boxes, dtype=torch.float32)
    labels = torch.as_tensor(labels, dtype=torch.int64)
    iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
    image_id = torch.tensor([idx])
    area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

    return {"boxes": boxes,
            "labels": labels,
            "iscrowd": iscrowd,
            "image_id": image_id,
            "area": area}

  • 遍历object获得box信息,并判断xmax <= xmin or ymax <= ymin:是否为true,如果为true说明获得的box面积为负,需要跳过该目标。如果box面积不为负,则继续获取object的label,iscrowd,area等信息。
  • 最后将这张图片获取的labels,iscrowd,image_id,area打包为一个字典,并返回。

2) 获取mask信息

  • 计算该张图片的目标数:obs_bboxes["boxes"].shape[0]
  • 利用Image.open读取分割图片,以P格式(调色板)模式读取的,读取的是一张单通道的图片,并转换为numpy格式。将mask中像素值为255的区域忽略掉(255区域为背景或者难分割的区域)
# 读取SegmentationObject并检查是否和bboxes信息数量一致
instances_mask = Image.open(mask_path)
instances_mask = np.array(instances_mask)
instances_mask[instances_mask == 255] = 0  # 255为背景或者忽略掉的地方,这里为了方便直接设置为背景(0)
  • 需要检查一下标注的bbox个数是否和instances个数一致。前文介绍过分割图中每个目标的像素值等于xml中每个目标的顺序。也就是说有多少个目标,分割图片中最大的像素就是多少。
  • 如果xml目标数和instances的分割目标数不一样就跳过,如果一样的话,就保存mask信息。
 num_instances = instances_mask.max()
 if num_objs != num_instances:
     print(f"warning: num_boxes:{num_objs} and num_instances:{num_instances} do not correspond. "
           f"skip image:{img_path}")
     continue

 self.images_path.append(img_path)
 self.xmls_path.append(xml_path)
 self.xmls_info.append(obs_dict)
 self.masks_path.append(mask_path)
 self.objects_bboxes.append(obs_bboxes)
 self.masks.append(instances_mask)

2.3 __getitem__方法

  • 根据index索引,获得图片路径,并通过Image.open去读取它,并转换为RGB3通道。
  • 获得该图片的target信息,它有bbox组成,注意没有包括mask信息。然后将该图片的所有mask信息也加入到target中。
  • 如果transoform不为None的话,就进行数据增强,并返回img和target。
def __getitem__(self, idx):
        """
        Args:
            idx (int): Index

        Returns:
            tuple: (image, target) where target is the image segmentation.
        """
        img = Image.open(self.images_path[idx]).convert('RGB')
        target = self.objects_bboxes[idx]
        masks = self.parse_mask(idx)
        target["masks"] = masks

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

其中parse_mask的代码如下,将mask中不同目标单独保存为1个channel存放

def parse_mask(self, idx: int):
     mask = self.masks[idx]
     c = mask.max()  # 有几个目标最大索引就等于几
     masks = []
     # 对每个目标的mask单独使用一个channel存放
     for i in range(1, c+1):
         masks.append(mask == i)
     masks = np.stack(masks, axis=0)
     return torch.as_tensor(masks, dtype=torch.uint8)

参考

1. PASCAL VOC2012数据集介绍
2. MS COCO数据集介绍以及pycocotools简单使用

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/344892.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MySQL 6:MySQL存储过程、存储函数

MySQL 5.0 版本开始支持存储过程。存储过程是一组SQL语句&#xff0c;功能强大&#xff0c;可以实现一些复杂的逻辑功能&#xff0c;类似于JAVA语言中的方法&#xff1b;存储是数据库SQL语言层面的代码封装和复用。 存储过程有输入输出参数&#xff0c;可以声明变量&#xff0…

Android Monkey

1、Monkey&#xff08;Monkey是发送伪随机用户事件的工具&#xff09;介绍&#xff1a; Monkey测试是Android平台自动化测试的一种手段&#xff0c;通过Monkey程序模拟用户触摸屏幕、滑动Trackball、按键等操作来对设备上的程序进行压力测试&#xff0c;检测程序多久的时间会发…

安灯(andon)系统是车间现场管理的必备工具

安灯&#xff08;andon&#xff09;系统应用越来越广泛&#xff0c;不单单局限于汽车行业&#xff0c;更多生产型企业意识到了提高工作效率的重要性&#xff0c;提高工作效率根本的能提高生产水平&#xff0c;提高产量&#xff0c;而且安灯&#xff08;andon&#xff09;系统不…

python(16)--类

一、类的基本操作1.定义一个类格式&#xff1a;class Classname( )&#xff1a;内容&#x1f48e;鄙人目前还是一名学生&#xff0c;最熟悉的也就是学校了&#xff0c;所以就以学校为例子来建立一个类吧class School():headline"帝国理工大学"def schoolmotto(self):…

java 代码块 万字详解

概述 : 特点 : 格式 : 情景 : 细节 : 演示 : 英文 : //v&#xff0c;新版编辑器无手动添加目录的功能&#xff0c;PC端阅读建议通过侧边栏进行目录跳转&#xff1b;移动端建议用PC端阅读。&#x1f602;一、概述 :代码块&#xff0c;也称为初始化块&#xff0c;属于类中的成员&…

Vue3 如何实现一个带遮罩的 dialog 对话框

theme: mk-cute 开启掘金成长之旅&#xff01;这是我参与「掘金日新计划 12 月更文挑战」的第7天&#xff0c;点击查看活动详情 前言&#xff1a; 今天在项目中遇到了很多很多需要弹出一个对话框的场景&#xff0c;由于之前全都是通过 v-if 来控制这个组件的显示与否&#x…

【python游戏】让我们一起制作地球联邦阵营的战机,保护希望水晶,为人类的希望而战。

前言 嗨喽~大家好呀&#xff0c;这里是魔王呐 ❤ ~! 随着人类太空科技的飞速发展&#xff0c;希望水晶被越来越多的科学家当做核心能源来开发使用。 人类社会也因为水晶资源的争夺&#xff0c;开始逐渐分化成两派。 留在地球的普通人成立地球联邦&#xff0c;移居卫星的新人…

我对平衡二叉树的理解(比喻的方式)

传销是一种恶性的行销方式&#xff0c;主要手段就是激励其中的成员拉人头。 有个奇怪的传销组织&#xff0c;他们的传销规则是这样的&#xff1a; 每个人最多可以带着2人进该组织&#xff0c;其中1个年纪比自己大&#xff0c;另1个年纪比自己小新人都是由创始人找到。假如年纪…

中文关键词提取算法

中文关键词提取算法 如何提取query或者文档的关键词&#xff1f; 一般有两种解决思路&#xff1a; 有监督方法&#xff0c;把关键词提取问题当做分类问题&#xff0c;文本分词后标记各词的重要性打分&#xff0c;然后挑出重要的topK个词&#xff1b;无监督方法&#xff0c;使…

likeshop单商户SaaS版V1.8.2说明!

likeshop单商户SaaS版V1.8.2主要更新如下&#xff1a; 新增 前端登录引导用户填写头像昵称 PC端—注册页面显示服务协议和隐私政策 PC端—首次进入商城弹出协议提示 PC端—结算页新增门店自提的配送方式 后台—PC端菜单导航栏的跳转链接支持添加自定义链接 ​​ ​​ ​ 优…

2022年“网络安全”赛项宜昌市选拔赛 任务书

2022年“网络安全”赛项宜昌市选拔赛 任务书 任务书 一、竞赛时间 共计3小时。 二、竞赛阶段 竞赛阶段 任务阶段 竞赛任务 竞赛时间 分值 第一阶段单兵模式系统渗透测试 任务一 数据库服务渗透测试 任务二 Wireshark数据包分析 任务三 Windows操作系统渗透测试 任务四 系统漏…

腾讯云企业网盘正式入驻数字工具箱

腾讯技术公益继腾讯电子签等入驻后&#xff0c;上线近半年的腾讯技术公益数字工具箱再次迎来新成员——腾讯云企业网盘&#xff0c;现已正式接受公益机构申请公益权益。腾讯云企业网盘&#xff08;https://pan.tencent.com&#xff09;是由腾讯云推出的一款安全、高效、开放的企…

python+flask开发mock服务

目录 什么是mock&#xff1f; 什么时候需要用到mock&#xff1f; 如何实现&#xff1f; pythonflask自定义mock服务的步骤 一、环境搭建 1、安装flask插件 2、验证插件 二、mock案例 1、模拟 返回结果 2、模拟 异常响应状态码 3、模拟登录&#xff0c;从jmeter中获取…

Kafka 消费者

与生产者对应的是消费者&#xff0c;应用程序可以通过 KafkaConsumer 来订阅主题&#xff0c;并从订阅主题中拉取消息。 消息者与消费组 消费者&#xff08;Consumer&#xff09;负责订阅 Kafka 中的主题&#xff08;Topic&#xff09;&#xff0c;并且从订阅的主题上拉取消息…

低代码开发平台|制造管理-生产过程管理搭建指南

1、简介1.1、案例简介本文将介绍&#xff0c;如何搭建制造管理-生产过程。1.2、应用场景先填充工序信息&#xff0c;再设置工艺路线对应的工序&#xff1b;工序信息及工艺路线列表报表展示的是所有工序、工艺路线信息&#xff0c;可进行新增对应数据的操作。2、设置方法2.1、表…

女生做大数据有发展前景吗?

当前大数据发展前景非常不错&#xff0c;且大数据领域对于人才类型的需求比较多元化&#xff0c;女生学习大数据也会有比较多的工作机会。大数据是一个交叉学科涉及到的知识量比较大学习有一定的难度&#xff0c;女生比较适合大数据采集和大数据分析方向的工作岗位。 大数据采…

【沁恒WCH CH32V307V-R1与Arduino的串口通讯】

【沁恒WCH CH32V307V-R1的单线半双工模式串口通讯】1. 前言2. 软件配置2.1 安装MounRiver Studio3. UASRT项目测试3.1 打开UASRT工程3.2 CH307串口发送数据到Arduino实验3.3 CH307串口接收数据Arduino实验5. 小结1. 前言 本例演示了采用CH307串口3与Arduino软串口收发通信&…

Python的深、浅拷贝到底是怎么回事?一篇解决问题

嗨害大家好鸭&#xff01;我是小熊猫~ 一、赋值 Python中&#xff0c; 对象的赋值都是进行对象引用&#xff08;内存地址&#xff09;传递, 赋值&#xff08;&#xff09;&#xff0c; 就是创建了对象的一个新的引用&#xff0c; 修改其中任意一个变量都会影响到另一个 will …

第七届蓝桥杯省赛——5分小组

题目&#xff1a;9名运动员参加比赛&#xff0c;需要分3组进行预赛。有哪些分组的方案呢&#xff1f;我们标记运动员为 A,B,C,... I下面的程序列出了所有的分组方法。该程序的正常输出为&#xff1a;ABC DEF GHIABC DEG FHIABC DEH FGIABC DEI FGHABC DFG EHIABC DFH EGIABC DF…

VFIO软件依赖——VFIO协议

文章目录背景PCI设备模拟PCI设备抽象VFIO协议实验Q&A背景 在虚拟化应用场景中&#xff0c;虚拟机想要在访问PCI设备时达到IO性能最优&#xff0c;最直接的方法就是将物理设备暴露给虚拟机&#xff0c;虚拟机对设备的访问不经过任何中间层的转换&#xff0c;没有虚拟化的损…