【代码整理】基于COCO格式的pytorch Dataset类实现

news2024/11/27 8:39:34

import模块

import numpy as np
import torch
from functools import partial
from PIL import Image
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import random
import albumentations as A
from pycocotools.coco import COCO
import os
import cv2
import matplotlib.pyplot as plt

基于albumentations库自定义数据预处理/数据增强

class Transform():
    '''数据预处理/数据增强(基于albumentations库)
    '''
    def __init__(self, imgSize):
        maxSize = max(imgSize[0], imgSize[1])
        # 训练时增强
        self.trainTF = A.Compose([
                A.BBoxSafeRandomCrop(p=0.5),
                # 最长边限制为imgSize
                A.LongestMaxSize(max_size=maxSize),
                A.HorizontalFlip(p=0.5),
                # 参数:随机色调、饱和度、值变化
                A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, always_apply=False, p=0.5),
                # 随机明亮对比度
                A.RandomBrightnessContrast(p=0.2),   
                # 高斯噪声
                A.GaussNoise(var_limit=(0.05, 0.09), p=0.4),     
                A.OneOf([
                    # 使用随机大小的内核将运动模糊应用于输入图像
                    A.MotionBlur(p=0.2),   
                    # 中值滤波
                    A.MedianBlur(blur_limit=3, p=0.1),    
                    # 使用随机大小的内核模糊输入图像
                    A.Blur(blur_limit=3, p=0.1),  
                ], p=0.2),
                # 较短的边做padding
                A.PadIfNeeded(imgSize[0], imgSize[1], border_mode=cv2.BORDER_CONSTANT, value=[0,0,0]),
                A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ],
            bbox_params=A.BboxParams(format='coco', min_area=0, min_visibility=0.1, label_fields=['category_ids']),
            )
        # 验证时增强
        self.validTF = A.Compose([
                # 最长边限制为imgSize
                A.LongestMaxSize(max_size=maxSize),
                # 较短的边做padding
                A.PadIfNeeded(imgSize[0], imgSize[1], border_mode=0, mask_value=[0,0,0]),
                A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ],
            bbox_params=A.BboxParams(format='coco', min_area=0, min_visibility=0.1, label_fields=['category_ids']),
            )

自定义数据集读取类COCODataset实现


class COCODataset(Dataset):

    def __init__(self, annPath, imgDir, inputShape=[800, 600], trainMode=True):
        '''__init__() 为默认构造函数,传入数据集类别(训练或测试),以及数据集路径

        Args:
            :param annPath:     COCO annotation 文件路径
            :param imgDir:      图像的根目录
            :param inputShape: 网络要求输入的图像尺寸
            :param trainMode:   训练集/测试集

        Returns:
            FRCNNDataset
        '''      
        self.mode = trainMode
        self.tf = Transform(imgSize=inputShape)
        self.imgDir = imgDir
        self.annPath = annPath
        self.DataNums = len(os.listdir(imgDir))
        # 为实例注释初始化COCO的API
        self.coco=COCO(annPath)
        # 获取数据集中所有图像对应的imgId
        self.imgIds = list(self.coco.imgs.keys())

    def __len__(self):
        '''重载data.Dataset父类方法, 返回数据集大小
        '''
        return len(self.imgIds)

    def __getitem__(self, index):
        '''重载data.Dataset父类方法, 获取数据集中数据内容
           这里通过pycocotools来读取图像和标签
        '''   
        # 通过imgId获取图像信息imgInfo: 例:{'id': 12465, 'license': 1, 'height': 375, 'width': 500, 'file_name': '2011_003115.jpg'}
        imgId = self.imgIds[index]
        imgInfo = self.coco.loadImgs(imgId)[0]
        # 载入图像 (通过imgInfo获取图像名,得到图像路径)               
        image = Image.open(os.path.join(self.imgDir, imgInfo['file_name']))
        image = np.array(image.convert('RGB'))
        # 得到图像里包含的BBox的所有id
        imgAnnIds = self.coco.getAnnIds(imgIds=imgId)   
        # 通过BBox的id找到对应的BBox信息
        anns = self.coco.loadAnns(imgAnnIds) 
        # 获取BBox的坐标和类别
        labels, boxes = [], []
        for ann in anns:
            labelName = ann['category_id']
            labels.append(labelName)
            boxes.append(ann['bbox'])
        labels = np.array(labels)
        boxes = np.array(boxes)
        
        # 训练/验证时的数据增强各不相同
        if(self.mode):
            # albumentation的图像维度得是[W,H,C]
            transformed = self.tf.trainTF(image=image, bboxes=boxes, category_ids=labels)
        else:
            transformed = self.tf.validTF(image=image, bboxes=boxes, category_ids=labels)
        # 这里的box是coco格式(xywh)
        image, box, label = transformed['image'], transformed['bboxes'], transformed['category_ids']
        return image.transpose(2,0,1), np.array(box), np.array(label)

其他

# DataLoader中collate_fn参数使用
# 由于检测数据集每张图像上的目标数量不一
# 因此需要自定义的如何组织一个batch里输出的内容
def frcnn_dataset_collate(batch):
    images = []
    bboxes = []
    labels = []
    for img, box, label in batch:
        images.append(img)
        bboxes.append(box)
        labels.append(label)
    images = torch.from_numpy(np.array(images))
    return images, bboxes, labels



# 设置Dataloader的种子
# DataLoader中worker_init_fn参数使
# 为每个 worker 设置了一个基于初始种子和 worker ID 的独特的随机种子, 这样每个 worker 将产生不同的随机数序列,从而有助于数据加载过程的随机性和多样性
def worker_init_fn(worker_id, seed):
    worker_seed = worker_id + seed
    random.seed(worker_seed)
    np.random.seed(worker_seed)
    torch.manual_seed(worker_seed)


# 固定全局随机数种子
def seed_everything(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

batch数据集可视化

def visBatch(dataLoader:DataLoader):
    '''可视化训练集一个batch
    Args:
        dataLoader: torch的data.DataLoader
    Retuens:
        None     
    '''
    catName = {1:'person', 2:'bicycle', 3:'car', 4:'motorcycle', 5:'airplane', 6:'bus',
               7:'train', 8:'truck', 9:'boat', 10:'traffic light', 11:'fire hydrant',
               13:'stop sign', 14:'parking meter', 15:'bench', 16:'bird', 17:'cat', 18:'dog',
               19:'horse', 20:'sheep', 21:'cow', 22:'elephant', 23:'bear', 24:'zebra', 25:'giraffe',
               27:'backpack', 28:'umbrella', 31:'handbag', 32:'tie', 33:'suitcase', 34:'frisbee',
               35:'skis', 36:'snowboard', 37:'sports ball', 38:'kite', 39:'baseball bat',
               40:'baseball glove', 41:'skateboard', 42:'surfboard', 43:'tennis racket',
               44:'bottle', 46:'wine glass', 47:'cup', 48:'fork', 49:'knife', 50:'spoon', 51:'bowl',
               52:'banana', 53:'apple', 54:'sandwich', 55:'orange', 56:'broccoli', 57:'carrot',
               58:'hot dog', 59:'pizza', 60:'donut', 61:'cake', 62:'chair', 63:'couch',
               64:'potted plant', 65:'bed', 67:'dining table', 70:'toilet', 72:'tv', 73:'laptop',
               74:'mouse', 75:'remote', 76:'keyboard', 77:'cell phone', 78:'microwave',
               79:'oven', 80:'toaster', 81:'sink', 82:'refrigerator', 84:'book', 85:'clock',
               86:'vase', 87:'scissors', 88:'teddy bear', 89:'hair drier', 90:'toothbrush'}
    
    for step, batch in enumerate(dataLoader):
        images, boxes, labels = batch[0], batch[1], batch[2]
        # 只可视化一个batch的图像:
        if step > 0: break
        # 图像均值
        mean = np.array([0.485, 0.456, 0.406]) 
        # 标准差
        std = np.array([[0.229, 0.224, 0.225]]) 
        plt.figure(figsize = (8,8))
        for idx, imgBoxLabel in enumerate(zip(images, boxes, labels)):
            img, box, label = imgBoxLabel
            ax = plt.subplot(4,4,idx+1)
            img = img.numpy().transpose((1,2,0))
            # 由于在数据预处理时我们对数据进行了标准归一化,可视化的时候需要将其还原
            img = img * std + mean
            for instBox, instLabel in zip(box, label):
                x, y, w, h = round(instBox[0]),round(instBox[1]), round(instBox[2]), round(instBox[3])
                # 显示框
                ax.add_patch(plt.Rectangle((x, y), w, h, color='blue', fill=False, linewidth=2))
                # 显示类别
                ax.text(x, y, catName[instLabel], bbox={'facecolor':'white', 'alpha':0.5})
            plt.imshow(img)
            # 在图像上方展示对应的标签
            # 取消坐标轴
            plt.axis("off")
             # 微调行间距
            plt.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95, wspace=0.05, hspace=0.05)
        plt.show()

example

# for test only:
if __name__ == "__main__":
    # 固定随机种子
    seed = 23
    seed_everything(seed)
    # BatcchSize
    BS = 16
    # 图像尺寸
    imgSize = [800, 800]

    trainAnnPath = "E:/datasets/Universal/COCO2017/annotations/instances_train2017.json"
    testAnnPath = "E:/datasets/Universal/COCO2017/annotations/instances_val2017.json"
    imgDir =  "E:/datasets/Universal/COCO2017/train2017"
    # 自定义数据集读取类
    trainDataset = COCODataset(trainAnnPath, imgDir, imgSize, trainMode=True)
    trainDataLoader = DataLoader(trainDataset, shuffle=True, batch_size = BS, num_workers=2, pin_memory=True,
                                    collate_fn=frcnn_dataset_collate, worker_init_fn=partial(worker_init_fn, seed=seed))
    # validDataset = COCODataset(testAnnPath, imgDir, imgSize, trainMode=False)
    # validDataLoader = DataLoader(validDataset, shuffle=True, batch_size = BS, num_workers = 1, pin_memory=True, 
                                  # collate_fn=frcnn_dataset_collate, worker_init_fn=partial(worker_init_fn, seed=seed))



    print(f'训练集大小 : {trainDataset.__len__()}')
    visBatch(trainDataLoader)
    for step, batch in enumerate(trainDataLoader):
        images, boxes, labels = batch[0], batch[1], batch[2]
        # torch.Size([bs, 3, 800, 800])
        print(f'images.shape : {images.shape}')   
        # 列表形式,因为每个框里的实例数量不一,所以每个列表里的box数量不一
        print(f'len(boxes) : {len(boxes)}')     
        # 列表形式,因为每个框里的实例数量不一,所以每个列表里的label数量不一  
        print(f'len(labels) : {len(labels)}')     
        break

输出

在这里插入图片描述

images.shape : torch.Size([16, 3, 800, 800])
len(boxes) : 16
len(labels) : 16

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1399649.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Spring MVC精解:技术内幕与最佳实践

第1章:引言 大家好,我是小黑,咱们今天来聊聊Spring MVC,它是Spring的一个模块,专门用来构建Web应用程序。提供了一种轻量级的方式来构建动态网页。就像小黑我刚开始接触Java时候一样,可能对这些听起来很高…

GitHub 一周热点汇总第6期(2024/01/14-01/20)

GitHub一周热点汇总第6期 (2024/01/14-01/20) ,梳理每周热门的GitHub项目,这一周的热门项目中AI的比重难得的变低了,终于不像一个AI热门项目汇总了,一起来看看都有哪些项目吧。 #1Maybe 项目名称:Maybe - 个人理财应…

4496 蓝桥杯 求函数零点 简单

4496 蓝桥杯 求函数零点 简单 //C风格解法1&#xff0c;通过率100% #include <bits/stdc.h> // int a, b; 一定会自动初始化为 0int main(){int a 2, b 3; // 定义a&#xff0c;b&#xff0c;不会自动初始化&#xff0c;最好自己定义时初始化// windows环境下a值固定&…

Broadcom交换芯片56620架构

文章目录 架构1.系统逻辑视图2.逻辑芯片视图3.芯片框图4.MIIM&#xff08;Medium Independent Interface Management&#xff09;5.交换结构6.CAP 架构 1.系统逻辑视图 Ingress Chip作用&#xff1a; 解析报文128字节的头部&#xff08;MMU&#xff08;Memory Management Uni…

html5实现好看的年会邀请函源码模板

文章目录 1.设计来源1.1 邀请函主界面1.2 诚挚邀请界面1.3 关于我们界面1.4 董事长致词界面1.5 公司合作方界面1.6 活动流程界面1.7 加盟支持界面1.8 加盟流程界面1.9 加盟申请界面1.10 活动信息界面 2.效果和源码2.1 动态效果2.2 源码目录结构 源码下载 作者&#xff1a;xcLei…

dpwwn:03

靶场下载 https://download.vulnhub.com/dpwwn/dpwwn-03.zip 信息收集 # nmap -sn 192.168.1.0/24 -oN live.nmap Starting Nmap 7.94 ( https://nmap.org ) at 2024-01-17 21:18 CST Stats: 0:00:00 elapsed; 0 hosts completed (0 up), 255 undergoing ARP Ping Sc…

力扣:494. 目标和(动态规划)(01背包)

题目&#xff1a; 给你一个非负整数数组 nums 和一个整数 target 。 向数组中的每个整数前添加 ‘’ 或 ‘-’ &#xff0c;然后串联起所有整数&#xff0c;可以构造一个 表达式 例如&#xff0c;nums [2, 1] &#xff0c;可以在 2 之前添加 ‘’ &#xff0c;在 1 之前添加…

【设计模式】什么是外观模式并给出例子!

什么是外观模式&#xff1f; 外观模式是一种结构型设计模式&#xff0c;主要用于为复杂系统、库或框架提供一种简化的接口。这种模式通过定义一个包含单个方法的高级接口&#xff0c;来隐藏系统的复杂性&#xff0c;使得对外的API变得简洁并易于使用。 为什么要使用外观模式&a…

Leetcode的AC指南 —— 栈与队列:225.用队列实现栈

摘要&#xff1a; **Leetcode的AC指南 —— 栈与队列&#xff1a;225.用队列实现栈 **。题目介绍&#xff1a;请你仅使用两个队列实现一个后入先出&#xff08;LIFO&#xff09;的栈&#xff0c;并支持普通栈的全部四种操作&#xff08;push、top、pop 和 empty&#xff09;。 …

【flutter】完全自定义样式模态对话框

示例完成结果展示&#xff1a; 示例组件代码&#xff1a; context&#xff1a;上下文 title&#xff1a;提示标题&#xff0c;null时不显示 content&#xff1a;提示内容&#xff0c;null时不显示 cancelText&#xff1a;取消按钮文字&#xff0c;null时不显示取消按钮 confirm…

Canny边缘检测 双阈值检测理解

问题引入 我们用一个实际例子来引入问题 import cv2 import numpy as npimgcv2.imread("test.png",cv2.IMREAD_GRAYSCALE) # 修改图像大小 show cv2.resize(img,(500,500))v1cv2.Canny(show,120,250) v2cv2.Canny(show,50,100)# 连接图像 res np.hstack((v1,v2)…

MSPM0L1306例程学习-UART部分(2)

MSPM0L1306例程学习系列 1.背景介绍 写在前边的话&#xff1a; 这个系列比较简单&#xff0c;主要是围绕TI官网给出的SDK例程进行讲解和注释。并没有针对模块的具体使用方法进行描述。所有的例程均来自MSPM0 SDK的安装包&#xff0c;具体可到官网下载并安装: https://www.ti…

java枚举详细解释

枚举的基本认识 我们一般直接定义一个单独的枚举类 public enum 枚举类名{枚举项1,枚举项2,枚举项3 } 可以通过 枚举类名.枚举项 来访问该枚举项的 - 可以理解为 枚举项就是我们自己定义的一个数据类型,是独一无二的 接下来我们直接用一个例子来完全理解 加深理解 这里…

[C#]winform部署openvino官方提供的人脸检测模型

【官方框架地址】 https://github.com/sdcb/OpenVINO.NET 【框架介绍】 OpenVINO&#xff08;Open Visual Inference & Neural Network Optimization&#xff09;是一个由Intel推出的&#xff0c;针对计算机视觉和机器学习任务的开源工具套件。通过优化神经网络&#xff…

vtk qt切割stl模型

一直想实现对stl模型的某个方向进行平面切割 通过滑动slider然后对模型进行某一个方向的面切割。同时可以用鼠标对模型进行移动缩放&#xff0c;旋转等操作。然后可以加一些颜色点云显示等操作。 stl加载&#xff1a; QString selectFilePath QFileDialog::getOpenFileName…

孚盟云 多处SQL注入漏洞复现

0x01 产品简介 上海孚盟软件有限公司是一家外贸SaaS服务提供商,也是专业的外贸行业解决方案专业提供商。 全新的孚盟云产品,让用户可以用云模式实现信息化管理,让用户的异地办公更加流畅,大大降低中小企业在信息化上成本,用最小的投入享受大型企业级别的信息化服务,使中…

六、标准对话框、多应用窗体

一、标准对话框 Qt提供了一些常用的标准对话框&#xff0c;如打开文件对话框、选择颜色对话框、信息提示和确认选择对话框、标准输入对话框等。1、预定义标准对话框 &#xff08;1&#xff09;QFileDialog 文件对话框 QString getOpenFileName() 打开一个文件QstringList ge…

《JVM由浅入深学习九】 2024-01-15》JVM由简入深学习提升分(生产项目内存飙升分析)

目录 开头语内存飙升问题分析与案例问题背景&#xff1a;我华为云的一个服务器运行我的一个项目“csdn-automatic-triplet-0.0.1-SNAPSHOT.jar”&#xff0c;由于只是用来测试的服务器&#xff0c;只有2G&#xff0c;所以分配给堆的内存1024M查询内存使用&#xff08;top指令&a…

Self-RAG:通过自我反思学习检索、生成和批判

论文地址&#xff1a;https://arxiv.org/abs/2310.11511 项目主页&#xff1a;https://selfrag.github.io/ Self-RAG学习检索、生成和批评&#xff0c;以提高 LM 的输出质量和真实性&#xff0c;在六项任务上优于 ChatGPT 和检索增强的 LLama2 Chat。 问题&#xff1a;万能L…

活性白土数据研究:预计2029年将达到9.2亿美元

活性白土是用粘土(主要是膨润土)为原料&#xff0c;经无机酸化或盐或其他方法处理&#xff0c;再经水漂洗、干燥制成的吸附剂&#xff0c;外观为乳白色粉末&#xff0c;无臭&#xff0c;无味&#xff0c;无毒&#xff0c;吸附性能很强&#xff0c;能吸附有色物质、有机物质。广…