使用coco数据集进行语义分割:数据预处理与损失函数

news2025/1/11 18:44:18

如何coco数据集进行目标检测的介绍已经有很多了,但是关于语义分割几乎没有。本文旨在说明如何处理 stuff_train2017.json    stuff_val2017.json    panoptic_train2017.json    panoptic_val2017.json,将上面那些json中的dict转化为图片的label mask,也就是制作图片像素的标签的ground truth。

首先下载图片和annotation文件(也就是json文件)

2017 Train Images和2017 Val Images解压得到这两个文件夹,里面放着的是所有图片。

annotation下载下来有4个json文件

以及两个压缩包

这两个压缩包里面是panoptic segmentation的mask,用于制作ground truth。

首先说明一下那四个json文件的含义,stuff是语义分割,panoptic是全景分割。二者的区别在与

语义分割只区别种类,全景分割还区分个体。上图中,中间的语义分割将所有的人类视作同一个种类看待,而右边的全景分割还将每一个个体区分出来。

做语义分割,可以使用stuff_train2017.json生成ground truth。但是即便是只做语义分割,不做全景分割,在只看语义分割的情况下,stuff的划分精度不如panoptic,因此建立即使是做语义分割也用panoptic_train2017.json。本文中将两种json都进行说明。

stuff_train2017.json

下面这段代码是处理stuff_train2017.json生成ground truth

from pycocotools.coco import COCO
import os
from PIL import Image
import numpy as np
from matplotlib import pyplot as plt


def convert_coco2mask_show(image_id):
    print(image_id)
    img = coco.imgs[image_id]
    image = np.array(Image.open(os.path.join(img_dir, img['file_name'])))
    plt.imshow(image, interpolation='nearest')

    # cat_ids = coco.getCatIds()
    cat_ids = list(range(183))
    anns_ids = coco.getAnnIds(imgIds=img['id'], catIds=cat_ids, iscrowd=None)
    anns = coco.loadAnns(anns_ids)
    mask = np.zeros((image.shape[0], image.shape[1]))
    for i in range(len(anns)):
        tmp = coco.annToMask(anns[i]) * anns[i]["category_id"]
        mask += tmp
    
    print(np.max(mask), np.min(mask))
    
    # 绘制二维数组对应的颜色图
    plt.figure(figsize=(8, 6))
    plt.imshow(mask, cmap='viridis', vmin=0, vmax=182)
    plt.colorbar(ticks=np.linspace(0, 182, 6), label='Colors')  # 添加颜色条
    # plt.savefig(save_dir + str(image_id) + ".jpg")
    plt.show()


if __name__ == '__main__':
    Dataset_dir = "/home/xxxx/Downloads/coco2017/"
    coco = COCO("/home/xxxx/Downloads/coco2017/annotations/stuff_train2017.json")
    # createIndex
    # coco = COCO("/home/robotics/Downloads/coco2017/annotations/annotations/panoptic_val2017.json")
    img_dir = os.path.join(Dataset_dir, 'train2017')
    save_dir = os.path.join(Dataset_dir, "Mask/stuff mask/train")
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    image_id = 9
    convert_coco2mask_show(image_id)
    
    # for keyi, valuei in coco.imgs.items():
    #     image_id = valuei["id"]
    #     convert_coco2mask_show(image_id)

解释一下上面的代码。

coco = COCO("/home/xxxx/Downloads/coco2017/annotations/stuff_train2017.json")

从coco的官方库中导入工具,读取json。会得到这样的字典结构

在 coco.imgs ,找到 "id" 对应的value,这个就是每个图片唯一的编号,根据这个编号,到 coco.anns 中去索引 segmentation

注意 coco.imgs 的 id 对应的是 anns 中的 image_id,这个也是 train2017 文件夹中图片的文件名。

# cat_ids = coco.getCatIds()
cat_ids = list(range(183))

网上的绝大多数教程都写的是上面我注释的那行代码,那行代码会得到91个类,那个只能用于图像检测,但对于图像分割任务,包括的种类是182个和一个未归类的0类一共183个类。

tmp = coco.annToMask(anns[i]) * anns[i]["category_id"]

上面那行代码就是提取每一个类所对应的id,将所有的类都加进一个二维数组,就制作完成了ground truth。

panoptic_train2017.json

上面完成了通过stuff的json文件自作ground truth,那么用panoptic的json文件制作ground truth,是不是只要把

coco = COCO("/home/xxxx/Downloads/coco2017/annotations/stuff_train2017.json")

换为

coco = COCO("/home/xxxx/Downloads/coco2017/annotations/annotations/panoptic_val2017.json")

就行了呢?

答案是不行!会报错

Traceback (most recent call last):
  File "/home/xxxx/.local/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3508, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-4-66f5c6e7dbe4>", line 1, in <module>
    coco = COCO("/home/xxxx/Downloads/coco2017/annotations/panoptic_val2017.json")
  File "/home/robotics/.local/lib/python3.8/site-packages/pycocotools/coco.py", line 86, in __init__
    self.createIndex()
  File "/home/xxxx/.local/lib/python3.8/site-packages/pycocotools/coco.py", line 96, in createIndex
    anns[ann['id']] = ann
KeyError: 'id'

用官方的库读官方的json还报错,真是巨坑!关键还没有官方的文档教你怎么用,只能一点点摸索,非常浪费时间精力。

制作panoptic的ground truth,用到了Meta公司的detectron2这个库

# Copyright (c) Facebook, Inc. and its affiliates.
import copy
import json
import os

from detectron2.data import MetadataCatalog
from detectron2.utils.file_io import PathManager


def load_coco_panoptic_json(json_file, image_dir, gt_dir, meta):
    """
    Args:
        image_dir (str): path to the raw dataset. e.g., "~/coco/train2017".
        gt_dir (str): path to the raw annotations. e.g., "~/coco/panoptic_train2017".
        json_file (str): path to the json file. e.g., "~/coco/annotations/panoptic_train2017.json".

    Returns:
        list[dict]: a list of dicts in Detectron2 standard format. (See
        `Using Custom Datasets </tutorials/datasets.html>`_ )
    """

    def _convert_category_id(segment_info, meta):
        if segment_info["category_id"] in meta["thing_dataset_id_to_contiguous_id"]:
            segment_info["category_id"] = meta["thing_dataset_id_to_contiguous_id"][
                segment_info["category_id"]
            ]
            segment_info["isthing"] = True
        else:
            segment_info["category_id"] = meta["stuff_dataset_id_to_contiguous_id"][
                segment_info["category_id"]
            ]
            segment_info["isthing"] = False
        return segment_info

    with PathManager.open(json_file) as f:
        json_info = json.load(f)

    ret = []
    for ann in json_info["annotations"]:
        image_id = int(ann["image_id"])
        # TODO: currently we assume image and label has the same filename but
        # different extension, and images have extension ".jpg" for COCO. Need
        # to make image extension a user-provided argument if we extend this
        # function to support other COCO-like datasets.
        image_file = os.path.join(image_dir, os.path.splitext(ann["file_name"])[0] + ".jpg")
        label_file = os.path.join(gt_dir, ann["file_name"])
        segments_info = [_convert_category_id(x, meta) for x in ann["segments_info"]]
        ret.append(
            {
                "file_name": image_file,
                "image_id": image_id,
                "pan_seg_file_name": label_file,
                "segments_info": segments_info,
            }
        )
    assert len(ret), f"No images found in {image_dir}!"
    assert PathManager.isfile(ret[0]["file_name"]), ret[0]["file_name"]
    assert PathManager.isfile(ret[0]["pan_seg_file_name"]), ret[0]["pan_seg_file_name"]
    return ret


if __name__ == "__main__":
    from detectron2.utils.logger import setup_logger
    from detectron2.utils.visualizer import Visualizer
    import detectron2.data.datasets  # noqa # add pre-defined metadata
    from PIL import Image
    import matplotlib.pyplot as plt
    import numpy as np

    logger = setup_logger(name=__name__)
    meta = MetadataCatalog.get("coco_2017_train_panoptic")

    dicts = load_coco_panoptic_json("/home/xxxx/Downloads/coco2017/annotations/panoptic_train2017.json",
                                    "/home/xxxx/Downloads/coco2017/train2017",
                                    "/home/xxxx/Downloads/coco2017/Mask/panoptic mask/panoptic_train2017", meta.as_dict())
    logger.info("Done loading {} samples.".format(len(dicts)))

    dirname = "coco-data-vis"
    os.makedirs(dirname, exist_ok=True)
    new_dic = {}
    num_imgs_to_vis = 100
    for i, d in enumerate(dicts):
        img = np.array(Image.open(d["file_name"]))
        visualizer = Visualizer(img, metadata=meta)
        pan_seg, segments_info = visualizer.draw_dataset_dict(d)
        seg_cat = {0: 0}
        for segi in segments_info:
            seg_cat[segi["id"]] = segi["category_id"]
        
        mapped_seg = np.vectorize(seg_cat.get)(pan_seg)
        
        # 保存数组为txt文件
        save_name = "/home/xxxx/Downloads/coco2017/Mask/panoptic label/train/" + str(d["image_id"]) + ".txt"
        np.savetxt(save_name, mapped_seg, fmt='%i')
        
        new_dic[d["image_id"]] = {"image_name": d["file_name"], "label_name": save_name}
        
        # # 将numpy数组转换为PIL Image对象
        # img1 = Image.fromarray(np.uint8(mapped_seg))
        # # 缩放图片
        # img_resized = img1.resize((224, 224))
        # # 将PIL Image对象转换回numpy数组
        # mapped_seg_resized = np.array(img_resized)
        
        # # 创建一个新的图形
        # plt.figure(figsize=(6, 8))
        # # 使用imshow函数来显示数组,并使用cmap参数来指定颜色映射
        # plt.imshow(mapped_seg_resized, cmap='viridis')
        # # 显示图形
        # plt.show()
        
        # fpath = os.path.join(dirname, os.path.basename(d["file_name"]))
        # # vis.save(fpath)
        
        if i + 1 >= num_imgs_to_vis:
            # 将字典转换为json字符串
            json_data = json.dumps(new_dic)
            
            # 将json字符串写入文件
            with open('/home/xxxx/Downloads/coco2017/data_for_train.json', 'w') as f:
                f.write(json_data)
            break
dicts = load_coco_panoptic_json("/home/xxxx/Downloads/coco2017/annotations/panoptic_train2017.json",
                                    "/home/xxxx/Downloads/coco2017/train2017",
                                    "/home/xxxx/Downloads/coco2017/Mask/panoptic mask/panoptic_train2017", meta.as_dict())

load_coco_panoptic_json的第三个参数,就是下载panoptic的annotations时,里面包含的两个压缩包。

上面这段代码,还需要修改一下detectron2库内部的文件

pan_seg, segments_info = visualizer.draw_dataset_dict(d)

进入 draw_dataset_dict这个函数内部,将两个中间结果拿出来

603行和604行是我自己加的代码。

这两个中间结果拿出来后,pan_seg是图片每个像素所属的种类,但是这个种类不是coco分类的那183个类

pan_seg中的数,要映射到segments_info中的 category_id,这个才是coco数据集所规定的183个类 。下图节选了183个中的前10个展示

mapped_seg = np.vectorize(seg_cat.get)(pan_seg)

上面这行代码就是完成pan_seg中的数,映射到segments_info中的 category_id。不要用两层的for循环,太低效了,numpy中有函数可以完成数组的映射。

这样就得到了语义分割的ground truth,也就是每个像素所属的种类。

通过以下代码训练网络

import argparse
import time
import os
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
import torch.optim as optim
from config import get_config
from build import build_model, build_transform, build_criterion
from utils import load_pretrained
from prepare_data import get_dataloader


CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "saved_models/my_model/")
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

def parse_option():
    parser = argparse.ArgumentParser("Swin Backbone", add_help=False)
    parser.add_argument("--cfg", type=str, default="configs/swin/swin_tiny_patch4_window7_224_22k.yaml")
    parser.add_argument("--pretrained", default="checkpoint/swin_tiny_patch4_window7_224_22k.pth")
    parser.add_argument("--device", default="cuda")
    args = parser.parse_args()
    
    config = get_config(args)
    
    return config


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    config = parse_option()
    model = build_model(config).to(device)
    load_pretrained(config, model)
    
    transform = build_transform(config)
    
    # img_addr = "/home/xxxx/Downloads/coco2017/train2017/000000000009.jpg"
    # img = Image.open(img_addr)
    # img = transform(img)
    # # img = img.numpy().transpose(1, 2, 0)
    # # plt.imshow(img)
    # # plt.show()
    # img = img.unsqueeze(0).cuda()
    # output = model(img)
    
    train_loader = get_dataloader(transform)
    criterion = build_criterion()
    
    num_epochs = 500
    for epoch in range(num_epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            
            output = output.permute(0, 2, 3, 1).contiguous()
            features = output.shape[-1]
            output = output.view(-1, features)  # (batchsize*width*height, features)
            target = target.view(-1)  # (batchsize*width*height)
            
            loss = criterion(output, target)
            optimizer = optim.AdamW(model.parameters(), lr=0.0001)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if batch_idx % 10 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                      .format(epoch+1, num_epochs, batch_idx+1, len(train_loader), loss.item()))
        
        if (epoch + 1) % 10 == 0:
            torch.save(model.state_dict(), CHECKPOINT_PATH)
            print(f"Model weights saved at epoch {epoch+1}.")

 网络输出的tensor的形状是(batch_size, 特征数,width, height)。将输出reshape成(batch_size*width*height, 特征数),将ground truth形状变成(batch_size*width*height,)这样的一个一维数组,数组中每个元素就是种类的index,代表该像素属于什么类。通过

criterion = nn.CrossEntropyLoss()

损失函数训练。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1283351.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Qt开发流程】之定时器事件与随机数示例

描述 QObject是所有Qt对象的基类&#xff0c;提供了Qt中基础的定时器支持。通过QObject::startTimer()函数&#xff0c;可以使用毫秒为单位的时间间隔来启动一个定时器。该函数返回一个唯一的整数定时器ID。该计时器现在将以规律的间隔触发&#xff0c;直到显式调用QObject::k…

谷歌用AI模型发现220万种新材料,研究能力超越人类!

谷歌旗下的AI研究机构DeepMind在全球顶级学术期刊《Nature》上发布了一篇论文&#xff0c;通过深度学习、计算机视觉、大数据等&#xff0c;开发了一个名为GNoME的图神经网络模型&#xff0c;主要用于材料发现。 研究团队通过GnoME便快速发现了220万个新的材料晶体结构&#x…

JVM==>图解字节码指令

一&#xff0c;原始代码 我们来看一下执行这段代码的具体流程 那执行这段代码中 JVM就会把已经编译好的.class文件加载到内存中&#xff0c;交给CPU运行 1&#xff09;常量池载入运行时常量池 我们发现 10 并没有被存入常量池中&#xff0c; 这是因为short范围以内的数字不会…

微机原理9

一、单项选择题(本大题共15小题,每小题3分、共45分。在每小题给出的四个备选项中,选出一个正确的答案,请将选定的答案填涂在答题纸的相应位置上。) 8088 系统的内存最大容量为 16MB. 其地址总线为&#xff08;&#xff09; A. 16 位 B. 20 位 C. 24 位 D. 32 位 2,以CPU为核心…

YITH WooCommerce Social Login跨境电商网站社交登录高级版插件

点击阅读YITH WooCommerce Social Login跨境电商网站社交登录高级版插件原文 YITH WooCommerce Social Login跨境电商网站社交登录高级版插件让您的用户节省时间并通过他们的社交资料之一登录或注册网站。 您如何从中受益&#xff1a; 用户无需填写表格、插入个人数据&#…

【数电笔记】06-码制

目录 说明&#xff1a; 二进制代码 1. 二 - 十进制码 2. 常用二 - 十进制代码表 2.1 例题 可靠性代码 1. 格雷码 2. 奇偶校验码 3. 8421奇偶校验码表 说明&#xff1a; 笔记配套视频来源&#xff1a;B站&#xff1b;本系列笔记并未记录所有章节&#xff0c;只对个人认…

计算机基础知识64

ForeignKey属性 to&#xff1a;设置要关联的表 related_name&#xff1a; 反向操作时&#xff0c;使用的字段名&#xff0c;用于代替原反向查询时的’表名_set’ related_query_name:反向查询操作时&#xff0c;使用的连接前缀&#xff0c;用于替换表名 to_field:设置要关联的表…

【数据分享】2015-2023年我国区县逐月二手房房价数据(Excel/Shp格式)

房价是一个城市发展程度的重要体现&#xff0c;一个城市的房价越高通常代表这个城市越发达&#xff0c;对于人口的吸引力越大&#xff01;因此&#xff0c;房价数据是我们在各项城市研究中都非常常用的数据&#xff01;之前我们分享过2015-2023年我国地级市逐月房价数据&#x…

关于你对 Zookeeper 的理解

看看普通人和高手是如何回答这个问题的&#xff1f; 普通人 Zookeeper 是一种开放源码的分布式应用程序协调服务 是一个分布式的小文件存储系统 一般对开发者屏蔽分布式应用开发过过程种的底层细节 用来解决分布式集群中应用系统的一致性问题 高手 对于 Zookeeper 的理解…

【ArcGIS Pro微课1000例】0047:深度学习--棕榈树提取全流程

一、创建训练样本 对汤加科洛瓦伊种植园每棵棕榈树的健康状况进行清查和评估,这需要花费大量的时间和劳动力。 为简化此过程,将在 ArcGIS Pro 中使用深度学习模型来识别树木,然后根据植被绿度的测量值计算其健康状况。 第一步是找到显示汤加科洛瓦伊的影像,该影像具有足够…

VQD视频质量诊断服务/图像质量诊断/视频流质量诊断/传统方法与深度学习结合的视频质量诊断

随着平安城市、大安防的发展&#xff0c;监控摄像机数量的不断增加&#xff0c;给监控系统的维护工作带来了新的挑战。如何及时了解前端视频设备的运行情况&#xff0c;发现故障并检测恶意遮挡与破坏的不法行为已成为视频监控系统运行的首要迫切问题。对于成千上万个监控摄像机…

Vue3 组合式实现 带连接线的Tree型 架构图(一级树形图)

创建组件名称 TreeNodeView.vue <template><div class"tree-node"><div class"node">{{ rootNodeName }}</div><div class"children" :style"childrenLineStyle"><div class"child-node"…

练习十一:简单卷积器的设计

简单卷积器的设计 1&#xff0c;任务目的&#xff1a;2&#xff0c;明确设计任务2.1,目前这部分代码两个文件没找到&#xff0c;见第5、6节&#xff0c;待解决中。 &#xff0c;卷积器的设计&#xff0c;RTL&#xff1a;con1.v4&#xff0c;前仿真和后仿真&#xff0c;测试信号…

一键自动修改和翻新OC源码,解决苹果审核4.3和马甲问题

ipaguard 自动修改/翻新/混淆/OC/iOS代码&#xff0c;自动替换类名&#xff0c;方法名 由来 网上有很多关于如何混淆iOS源码的方法&#xff0c;但是都不够智能&#xff0c;生成的方法类名要么千奇百怪&#xff0c;要么aaaabbbxxx这种完全毫无意义的名称&#xff0c;要么只能…

全网最新最全的自动化测试:python+pytest接口自动化-接口测试基础

接口定义 一般我们所说的接口即API&#xff0c;那什么又是API呢&#xff0c;百度给的定义如下&#xff1a; API&#xff08;Application Programming Interface&#xff0c;应用程序接口&#xff09;是一些预先定义的接口&#xff08;如函数、HTTP接口&#xff09;&#xff0c…

TEMU跨境平台与亚马逊检测认证几大认证您知道多少?

TEMU跨境平台与亚马逊检测认证几大认证您知道多少&#xff1f; TEMU跨境平台与亚马逊对于做外贸的人应该都不陌生,可是你是否知道产品入驻TEMU跨境平台与亚马逊需要办理的13大认证呢?如果你不知道,请认真阅读正面的内容,因为它关系着你的产品能否在TEMU跨境平台与亚马逊顺利上…

零基础学编程,中文编程工具构件之弹出菜单构件教程,中文编程工具下载

一、前言&#xff1a; 零基础自学编程&#xff0c;中文编程工具下载&#xff0c;中文编程工具构件之扩展系统菜单构件教程 编程系统化教程链接https://jywxz.blog.csdn.net/article/details/134073098?spm1001.2014.3001.5502 给大家分享一款中文编程工具&#xff0c;零基础…

二进制动态插桩工具intel PIN的学习笔记

前言 最近两周为了课程汇报学习了intel PIN这个动态插桩&#xff08;dynamic instrument&#xff09;工具&#xff0c;总体的学习感受还是挺累的。一方面&#xff0c;这个方向比较小众&#xff0c;相关的二手资料比较少&#xff0c;能参考的也就只有官方手册这种一手资料&…

12.4c++中的继承

#include <iostream>using namespace std;class Sofa { private:string way;int *score; public:Sofa(){}//有参构造函数Sofa(string way,int score):way(way),score(new int(score)){cout << "Sofa::有参构造函数" << endl;}//拷贝构造函数Sofa(c…

从声纹模型到语音合成:音频处理 AI 技术前沿 | 开源专题 No.45

facebookresearch/audiocraft Stars: 16.6k License: MIT AudioCraft 是一个用于音频生成的 PyTorch 库。它包含了两个最先进的 AI 生成模型 (AudioGen 和 MusicGen) 的推理和训练代码&#xff0c;可以产生高质量音频。该项目还提供了其他功能&#xff1a; MusicGen&#xf…