COCO_03 制作COCO格式数据集 dataset 与 dataloader

news2025/1/15 22:43:51

文章目录

  • 1 引言
  • 2 pycocotools介绍
  • 3 Dataset 构建
  • 4 Dataloader 构建
    • 4.1 解决batch中tensor维度不一致的打包问题
    • 4.2 collate_fn()函数分析
  • Appendix
    • A. convert_coco_poly_mask
    • B. COCO_Transform
  • 参考

1 引言

在之前的文章中,我们认识了COCO数据集的基本格式https://blog.csdn.net/qq_44776065/article/details/128695821和制作了分割数据集 制作COCO格式目标检测和分割数据集https://blog.csdn.net/qq_44776065/article/details/128697177,那么接下来如何读取数据集,并展示结果呢?接下来我们解决这个问题

2 pycocotools介绍

pycocotools是官方给出的解析COCO格式数据集的API,帮助我们对COCO格式数据集进行操作,官方API:https://github.com/cocodataset/cocoapi,在PythonAPI中有Demo,可以下载后运行

安装pycocotools(本人安装时Linuxwindows都可使用)

pip install pycocotools

重要属性

  • 图片的字典信息:coco.imgs
  • 标注的字典信息:coco.anns
  • 类别的字典信息:coco.cats

重要API,getload
在这里插入图片描述
基本思想:先获取ID,再加载信息

获取ID:

  • 获取所有图片的ID:getImgIds()指定ID回返回指定的ID
  • 根据imgIdscatIds获取标注ID:getAnnIds(imgIds=[],catsIds=[])
  • 获取类别ID:getCatIds()

加载信息:

  • 加载图片信息loadImgs(img_id),获取的是字典信息,获取路径信息为: loadImgs(img_id)[0]["file_name"]
  • 加载标注信息loadAnns(ann_ids)ann_ids来自筛选的的标注id
  • 加载类别信息loadCats(cat_id)

例子:初始化COCO对象,并获取图片ID

from pycocotools.coco import COCO
import os

dataset_root = "D:MyDataset/my_coco"
anno_file = "my_annotations.json"

anno_path = os.path.join(dataset_root, anno_file)
anno = COCO(anno_path)

image_ids = anno.getImgIds()
 

3 Dataset 构建

基本流程:

  1. 初始化:初始化COCO数据集,并获取所有图片的ID
  2. 获取图片信息,根据index获取图片ID,再根据图片ID(或者类别ID)获取标注ID
  3. 获取标注信息,根据标注ID,加载标注信息
  4. 对标注信息进行处理,转化为tensor

初始化:

from PIL import Image
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
from pycocotools.coco import COCO

from utils import convert_coco_poly_mask, draw_gt
import  utils.coco_transform as coco_transform


class SegDatasetCOCO(Dataset):

    def __init__(self, dataset_root, p_anno_filename, category, transforms) -> None:
        super(SegDatasetCOCO).__init__()

        # 根据数据集和p_dir获取标注文件
        assert os.path.exists(dataset_root), "{0} does not exists".format(dataset_root)
        anno_root = os.path.join(dataset_root, p_anno_filename)
        self.patient_dir = p_anno_filename[0: -17]

        self.transforms = transforms
        self.category = category

        # 加载COCO数据
        self.anno = COCO(annotation_file=anno_root)
        
        # 获取其中的数据
        # self.ids = list(self.anno.imgs.keys())
        self.ids = self.anno.getImgIds()
        self.dataset_root = dataset_root

        # 输出目录信息
        print(f"Dataset Info Name: {self.patient_dir}")
        print(f"Dataset Info dataset len: {len(self.ids)}")

获取单个batch:

def __getitem__(self, index):

    # 获取图片和标注
    img_id = self.ids[index]

    # 读取图片
    filename = self.anno.loadImgs(img_id)[0]["file_name"]
    filepath = os.path.join(self.dataset_root, filename)
    images = Image.open(filepath).convert("L")

    # 获取标注
    w, h = images.size
    # 根据图片ID和类别ID获取标注ID
    anno_ids = self.anno.getAnnIds(imgIds=img_id, catIds=self.category)
    coco_targets = self.anno.loadAnns(anno_ids)

    # # 选择标签
    # coco_targets = [item for item in coco_targets if item["category_id"] == self.category]

    target = self.parse_targets(img_id=img_id, coco_targets=coco_targets, w=w, h=h)

    # 返回处理后的数据
    if self.transforms is not None:
        images, target = self.transforms(images, target)

    return images, target

对标注信息处理:

def parse_targets(self,
                  img_id: int,
                  coco_targets: list,
                  w: int = None,
                  h: int = None):

    assert w > 0, "w 不合法"
    assert h > 0, "h 不合法"

    # 只筛选出单个对象的情况
    anno = [obj for obj in coco_targets if obj['iscrowd'] == 0]
    boxes = [obj["bbox"] for obj in anno]

    # 转化为tensor格式, box的格式: [xmin, ymin, w, h] -> [xmin, ymin, xmax, ymax]
    boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
    boxes[:, 2:] += boxes[:, :2]
    boxes[:, 0::2].clamp_(min=0, max=w)
    boxes[:, 1::2].clamp_(min=0, max=h)

    # 类别标签
    classes = [obj["category_id"] for obj in anno]
    classes = torch.tensor(classes, dtype=torch.int64)

    # 面积
    area = torch.tensor([obj["area"] for obj in anno])
    iscrowd = torch.tensor([obj["iscrowd"] for obj in anno])

    # 分割标签转化为图片
    segmentations = [obj["segmentation"] for obj in anno]
    masks = convert_coco_poly_mask(segmentations, h, w)


    # 筛选出合法的目标,即 x_max>x_min 且 y_max>y_min
    keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
    boxes = boxes[keep]
    classes = classes[keep]
    masks = masks[keep]
    area = area[keep]
    iscrowd = iscrowd[keep]

    target = {}
    target["boxes"] = boxes
    target["labels"] = classes
    target["masks"] = masks
    target["image_id"] = torch.tensor([img_id])
    target["area"] = area
    target["iscrowd"] = iscrowd
    return target

4 Dataloader 构建

创建DatasetDataLoader

dataset_root = r"D:\Learning\OCT\oct-dataset-master\dataset\dataset_stent_coco"
p_anno_filename = "P9_1_IMG002_annotations.json"
category = 2

transforms = coco_transform.Compose([coco_transform.ToTensor()])

dataset = SegDatasetCOCO(
    dataset_root=dataset_root, 
    p_anno_filename=p_anno_filename, 
    category=category,
    transforms=transforms
    )


dataset_loader = DataLoader(
    dataset=dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=dataset.collate_fn
    )

4.1 解决batch中tensor维度不一致的打包问题

数据集读取需要特殊处理,原因是默认的batch组装无法将结果进行打包,原因是每一张图片的mask的维度不一致,根据目标的个数确定mask的个数

@staticmethod
def collate_fn(batch):
    return tuple(zip(*batch))

4.2 collate_fn()函数分析

batch数据格式,数据均为tensor:

image, {"bbox": [[1, 2, ,3 4], ...], "classes": [1, ...], "mask": [[[1,0, 0], [0, 0, 0], [1, 1, 1,1]], ...], "area": [100.0, ...]}

原理分析:

if __name__ == "__main__":
    a1 = ["a", [1, 2, 3]]    
    a2 = ["b", [3, 4]]   # 第二个的元素维度不一致
    b = [a1, a2]
    c = zip(*(b))
    for i in c:
        print(i)
    pass
# ('a', 'b')
# ([1, 2, 3], [3, 4])

使用*解开a迭代器, 将维度不一致的当作一个元素, 使用zip将两个迭代器对应位置的元素进行组合, 完成batch的合并

如果有不同类的元素

if __name__ == "__main__":
    a1 = ["a", [1, 2, 3]]    
    a2 = ["b", [3, 4]]
    a3 = ["c", {"array": [5, 6]}]
    b = [a1, a2, a3]
    c = tuple(zip(*(b)))
    for i in c:
        print(i)
    pass
# ('a', 'b', 'c')
# ([1, 2, 3], [3, 4], {'array': [5, 6]})

即使多个batch中有不同的元素,这样的情况一般不会出现,常常出现的问题是batch中某个数据维度不一致

Appendix

A. convert_coco_poly_mask

  1. 使用coco_maskpolygon信息转化为rle格式,关于RLE格式,参考:<https: >
  2. 对rle格式进行进行解码,转换为图片mask
  3. 保证mask 维度为3,为打包成batch准备,batch中图片格式:B, C, W, H

from pycocotools import mask as coco_mask

def convert_coco_poly_mask(segmentations, height, width):
    
    masks = []
    for polygons in segmentations:
        rles = coco_mask.frPyObjects(polygons, height, width)
        mask = coco_mask.decode(rles)
        if len(mask.shape) < 3:
            mask = mask[..., None]
        mask = torch.as_tensor(mask, dtype=torch.uint8)
        mask = mask.any(dim=2)   # 有1则为前景
        masks.append(mask)

    if masks:
        masks = torch.stack(masks, dim=0)

    else:
        # 如果mask为空,则说明没有目标,直接返回数值为0的mask
        masks = torch.zeros((0, height, width), dtype=torch.uint8)
    return masks

B. COCO_Transform

再次封装torchvision.transforms.ToTensor等函数,从而对imagetarget同时处理

import random
from torchvision.transforms import functional as F

class Compose(object):
    """组合多个transform函数"""
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target


class ToTensor(object):
    """将PIL图像转为Tensor"""
    def __call__(self, image, target):
        image = F.to_tensor(image)
        return image, target


class RandomHorizontalFlip(object):
    """随机水平翻转图像以及bboxes"""
    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, image, target):
        if random.random() < self.prob:
            height, width = image.shape[-2:]
            image = image.flip(-1)  # 水平翻转图片
            bbox = target["boxes"]
            # bbox: xmin, ymin, xmax, ymax
            bbox[:, [0, 2]] = width - bbox[:, [2, 0]]  # 翻转对应bbox坐标信息
            target["boxes"] = bbox
            if "masks" in target:
                target["masks"] = target["masks"].flip(-1)
        return image, target


参考

COCO数据集介绍:https://blog.csdn.net/qq_37541097/article/details/113247318

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/166185.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【设计模式】创建型模式·工厂模式

设计模式学习之旅(四) 查看更多可关注后查看主页设计模式DayToDay专栏 一.引子 需求&#xff1a;设计一个咖啡店点餐系统。 设计一个咖啡类&#xff08;Coffee&#xff09;&#xff0c;并定义其两个子类&#xff08;美式咖啡【AmericanCoffee】和拿铁咖啡【LatteCoffee】&…

NoSQLBooster for MongoDB 8.0.1 Crack

最智能的 MongoDB IDE NoSQLBooster 是 MongoDB Server 3.6-6.0 的跨平台 GUI 工具&#xff0c;它提供内置的 MongoDB 脚本调试器、全面的服务器监控工具、链接流畅查询、SQL 查询、查询代码生成器、任务调度、ES2020 支持和高级 IntelliSense经验。新版本 8.0 现已推出&#x…

Laravel文档阅读笔记-How to Build a Rest API with Laravel: A Beginners Guide①

随着移动端和JavaScript框架的发展&#xff0c;比如React和Vue&#xff0c;Restful风格的API越来越流行。使用Restful风格的好处就是一个后端程序可以与多个版本的前端用户界面关联。 Laravel提供了创建Rest API的环境和生态。 首先得导入依赖包比如Laravel Passport和Larave…

MySQL中给字符串字段加索引

文章目录前言一、前缀索引和普通索引二、前缀索引对覆盖索引的影响三、优化前缀索引前言 学完了MySQL索引部分&#xff0c;我们清楚的认识到给子段添加索引可以快速的进行查询&#xff0c;节约时间。但是索引有很多。那么对于字段怎么加索引&#xff0c;加什么索引。加到索引不…

linux基本功系列之useradd命令实战

文章目录一. useradd 命令介绍二. 语法格式及常用选项三. 参考案例3.1 不加任何参数创建用户3.2 创建不能登录系统且没有家目录的用户3.3 创建一个用户&#xff0c;ID为23333.4 创建一个用户并指定其附加组3.5 创建用户并账户过期时间3.6 与useradd相关的目录文件总结前言&…

InfluxDB的查询优化

首先&#xff0c;在学习influxDB的查询优化之前&#xff0c;我们要先学习下InfluxDB的解释器profiler&#xff08;类似于mysql的Explain语句&#xff0c;不一样的是&#xff0c;sql&#xff0c;hivesql是提前查看执行计划等&#xff0c;Influx是在当前查询的最后一页两张表&…

力扣(LeetCode)382. 链表随机节点(2023.01.15)

给你一个单链表&#xff0c;随机选择链表的一个节点&#xff0c;并返回相应的节点值。每个节点 被选中的概率一样 。 实现 Solution 类&#xff1a; Solution(ListNode head) 使用整数数组初始化对象。 int getRandom() 从链表中随机选择一个节点并返回该节点的值。链表中所有…

WhatsApp居然有3个版本?深度详解区别!外贸圈获客神器用起来!

近两年&#xff0c;外贸圈用WhatsApp来营销获客&#xff0c;越来越火。不少走在前头的外贸人&#xff0c;已经尝到了甜头。但也有不少后来者&#xff0c;站在门外张望的时候&#xff0c;整个人都是蒙圈的。❓听说动不动要整几十个账号&#xff0c;还要花老长时间养号&#xff1…

《Linux Shell脚本攻略》学习笔记-第六章

6.1 简介 你开发应用程序的时间越长&#xff0c;就越能体会到有一个能够跟踪程序修订历史的软件是多重要。 大多数Linux发行版中都包含了Git。如果你的系统中还没有安装&#xff0c;可以通过yum或者apt-get获取。 6.2 创建新的git仓库 git中的所有项目都需要有一个用于保存项目…

MyBatis-Plus字段加密解密

项目创建POM依赖 <dependency><!--MyBatis-Plus 企业级模块--><groupId>com.baomidou</groupId><artifactId>mybatis-mate-starter</artifactId><version>1.2.8</version> </dependency> <!-- https://mvnrepository…

规划之路:SLAM学习经验分享

针对想学SLAM的提问&#xff0c;我觉得我还是有一定的发言权。作为一个刚入坑SLAM一年多的初学者&#xff0c;首先想说的就是这个研究方向比较广&#xff0c;大方向按搭载传感器分为激光SLAM和视觉SLAM两种&#xff0c;激光SLAM搭载激光雷达&#xff0c;视觉SLAM搭载单目、双目…

[NSSRound#6 Team]Web学习

[NSSRound#6 Team]Web学习 文章目录[NSSRound#6 Team]Web学习前言一、[NSSRound#6 Team]check(V1)二、[NSSRound#6 Team]check(Revenge)总结前言 日常做点题娱乐下&#xff0c;刷到了[NSSRound#6 Team]中是三道web题&#xff0c;学习到了不少&#xff0c;记录下知识点。 提示&…

C语言综合练习6:制作贪吃蛇

1 初始化界面 因为还没学QT&#xff0c;我们就使用终端界面替代。 这里我们假设界面中没有障碍物&#xff0c;我们只需要设定界面的高宽就行&#xff0c;这是蛇的移动范围&#xff0c;我们可以写两个宏来规定界面的高宽 新建一个snake.c的文件 #define _CRT_SECURE_NO_WARNIN…

快出数量级的性能是怎样炼成的

前言&#xff1a;今天学长跟大家讲讲《快出数量级的性能是怎样炼成的》&#xff0c;废话不多说&#xff0c;直接上干货~我们之前做过一些性能优化的案例&#xff0c;不算很多&#xff0c;还没有失手过。少则提速数倍&#xff0c;多则数十倍&#xff0c;极端情况还有提速上千倍的…

关于IDEA配置本地tomcat部署项目找不到项目工件的问题解答

文章目录一 原因分析二 解决方案三 具体的操作方法3.1 打开项目结构找到工件3.2 添加具体的工件内容3.3 配置本地tomcat一 原因分析 可能是之前的项目再次打开后&#xff0c;没有及时配置项目结构中的工件信息&#xff0c;导致配置tomcat中看不到工件的信息 二 解决方案 解决…

react组件优化,当父组件数据变化与子组件无关时,控制子组件不重新渲染

首先 我们来建立一个场景 我们创建一个react项目 然后创建一个父组件 这里我要叫 record.jsx 参考代码如下 import React from "react"; import Subset from "./subset";export default class record extends React.Component{constructor(props){super(…

工作的同时,我也在这里做副业

文章目录一、什么是独自开&#xff1f;二、独自开能给我们带来什么利益&#xff1f;三、如何使用独自开&#xff1f;3.1、用户任务报价步骤13.2、用户任务报价步骤2四、未来的愿景一、什么是独自开&#xff1f; 独自开&#xff0c;全称独自开发一套系统&#xff0c;是基于商品…

CTP开发(2)行情模块的开发

我在做CTP开发之前&#xff0c;也参考了不少其他的资料&#xff0c;发现他们都是把行情和交易做在同一个工程里的。我呢之前也做过期货相关的交易平台&#xff0c;感觉这种把行情和交易做在一起的方法缺乏可扩展性。比如我开了多个CTP账户&#xff0c;要同时交易&#xff0c;这…

springMVC的学习拦截器之验证用户登录案例

文章目录实现思路关于环境和配置文件pomspring的配置文件关于idea的通病/常见500错误的避坑实现步骤编写登陆页面编写Controller处理请求编写登录成功的页面编写登录拦截器实现思路 有一个登录页面&#xff0c;需要写一个controller访问页面登陆页面提供填写用户名和密码的表单…

UE4c++日记1(允许 创类、蓝图读写/调用/只读、分类、输出日志打印语句)

目录 1允许创建基于xx的蓝图类 2允许蓝图读写/允许蓝图调用/只读 读写调用 只读 3为变量/函数分类 4输出日志打印一段话 1.先创建一个蓝图类 2.构建对象 3.写提示代码&#xff0c;生成解决方案 4.运行&#xff0c;打开“输出日志” 5.总结 创类-实例化对象&#xff08;构建…