在对FCN/UNET/deeplabv3等语义分割时,标准的要求是对每一个像素点分开标记,即不允许出现重叠覆盖的情形:如下图所示
- 但不可避免的人工标注时会出现一定的标注重叠/重复/覆盖
- 甚至有的时候需要标注就是重复的,例如需要识别面板上赃物的情形,标记了面板和脏污,标注是重叠的,但是实际存在固定的上下覆盖叠加关系.可以通过一定的方法处理出这种关系.
下面以pytorch官方的实例代码为例
github链接: pytorch-segmentation
与之相关的就是transforms处理这一段:完整的处理流程如下
coco_utils.FilterAndRemapCocoCategories -> 过滤不要的标签,以及重新标记
coco_utils.ConvertCocoPolysToMask -> 多边形标记转mask标记,以及合标记,处理覆盖关系
presets.SegmentationPresetTrain -> 其他常见的预处理,数组增强等
transforms.RandomResize ->
transforms.RandomHorizontalFlip ->
transforms.RandomCrop ->
transforms.PILToTensor ->
transforms.ToDtype ->
transforms.Normalize ->
处理不忽略重复
有关重叠的就是ConvertCocoPolysToMask
有关重复的核心代码如下:
#将每个掩码与其对应的类别ID相乘,以便将类别信息合并到掩码中。
#沿着第一个维度(即对象实例的维度)取最大值。这样做的结果是,每个像素位置上的值将是其对应位置上的所有掩码和类别ID乘积中的最大值。由于每个位置通常只有一个掩码是活动的(值为1,其他为0),所以这将选择index最大的类别ID。
target, _ = (masks * cats[:, None, None]).max(dim=0)
# discard overlapping instances
# 如果多个掩码在相同的位置都为1(即存在重叠的实例),则将这些位置的像素值设置为255(通常表示“忽略”或“不关心”的类别)。
# masks.sum(0)计算了每个像素位置上掩码的和。如果和大于1,说明存在重叠。
target[masks.sum(0) > 1] = 255:
这两行中,第一行选择了index最大的类别ID,第二行将重复的置为了255
如果我们不需要忽略重叠,则直接注释掉第二行就好,这样大的index就可以覆盖小的
附:ConvertCocoPolysToMask类代码
class ConvertCocoPolysToMask:
def __call__(self, image, anno):
w, h = image.size
segmentations = [obj["segmentation"] for obj in anno]
cats = [obj["category_id"] for obj in anno]
if segmentations:
masks = convert_coco_poly_to_mask(segmentations, h, w)
cats = torch.as_tensor(cats, dtype=masks.dtype)
# merge all instance masks into a single segmentation map
# with its corresponding categories
target, _ = (masks * cats[:, None, None]).max(dim=0)
# discard overlapping instances
target[masks.sum(0) > 1] = 255
else:
target = torch.zeros((h, w), dtype=torch.uint8)
target = Image.fromarray(target.numpy())
return image, target
自定义的覆盖顺序
在FilterAndRemapCocoCategories中,可以重新映射index,以处理自定义的覆盖顺序
## 注意!!注意!!此参数意味着有多少标记会被识别
# 例: 在coco数据集中的category_id标记如为1,2,3,4, CAT_LIST 写为[3,4]
# 则将抛弃1,2的标记,然后将重新映射3->0,4->1. mask的信息就会写为0和1.
# 当使用不忽略重复时,还会影响到叠加关系,CAT_LIST后面的标记会覆盖前面的(即4会覆盖3),否则会设定为255
CAT_LIST = [1, 2]
coco_utils.FilterAndRemapCocoCategories(CAT_LIST, remap=True)
class FilterAndRemapCocoCategories:
def __init__(self, categories, remap=True):
self.categories = categories
self.remap = remap
def __call__(self, image, anno):
anno = [obj for obj in anno if obj["category_id"] in self.categories]
if not self.remap:
return image, anno
anno = copy.deepcopy(anno)
for obj in anno:
obj["category_id"] = self.categories.index(obj["category_id"])
return image, anno