如何创建一个pytorch的训练数据加载器(train_loader)用于批量加载训练数据

news2025/2/28 15:52:24

Talk is cheap,show me the code!  哈哈,先上几段常用的代码,以语义分割的DRIVE数据集加载为例:

DRIVE数据集的目录结构如下,下载链接DRIVE,如果官网下不了,到Kaggle官网可以下到:

1. 定义DriveDataset类,每行代码都加了注释,其中collate_fn()看不懂没关系:

import os
from PIL import Image
import numpy as np
from torch.utils.data import Dataset


class DriveDataset(Dataset):  # 继承Dataset类
    def __init__(self, root: str, train: bool, transforms=None):
        super(DriveDataset, self).__init__()
        self.flag = "training" if train else "test"  # 根据train这个布尔类型确定需要处理的是训练集还是测试集
        data_root = os.path.join(root, "DRIVE", self.flag)  # 得到数据集根目录
        assert os.path.exists(data_root), f"path '{data_root}' does not exists."   # 判断路径是否存在
        self.transforms = transforms   # 初始化图像变换操作
        img_names = [i for i in os.listdir(os.path.join(data_root, "images")) if i.endswith(".tif")]   # 遍历图像文件夹获取每个图像的文件名
        self.img_list = [os.path.join(data_root, "images", i) for i in img_names]  # 获取图像路径
        self.manual = [os.path.join(data_root, "1st_manual", i.split("_")[0] + "_manual1.gif")  # 获取手动标签的路径
                       for i in img_names]
        # 检查手动标签文件是否存在
        for i in self.manual:
            if os.path.exists(i) is False:
                raise FileNotFoundError(f"file {i} does not exists.")
        # 获取分割的ROI区域掩码
        self.roi_mask = [os.path.join(data_root, "mask", i.split("_")[0] + f"_{self.flag}_mask.gif")
                         for i in img_names]
        # check files
        for i in self.roi_mask:
            if os.path.exists(i) is False:
                raise FileNotFoundError(f"file {i} does not exists.")

    def __getitem__(self, idx):
        img = Image.open(self.img_list[idx]).convert('RGB')  # 加载图像,并转换为RGB模式
        manual = Image.open(self.manual[idx]).convert('L')   # 加载手动标注图像,并转换为灰度模式
        manual = np.array(manual) / 255   # 进行归一化操作
        roi_mask = Image.open(self.roi_mask[idx]).convert('L')  # 加载ROI图像,并转换为灰度模式
        roi_mask = 255 - np.array(roi_mask)   # 对图像数组取反,使用这个方法将背景和前景颜色反转,白色是255,黑色是0,反转后ROI变成了内黑外白
        mask = np.clip(manual + roi_mask, a_min=0, a_max=255)   # 将手动标注图像和反转后的ROI图像相加,使用np.clip()将像素值控制在0-255范围,

        # 这里转回PIL的原因是,transforms中是对PIL数据进行处理
        mask = Image.fromarray(mask)

        if self.transforms is not None:
            img, mask = self.transforms(img, mask)

        return img, mask

    # 获取图像数据集长度
    def __len__(self):
        return len(self.img_list)

    # 用于将批量的图像和标签数据合并为一个批张量。
    @staticmethod
    def collate_fn(batch):
        images, targets = list(zip(*batch))  # 将批量数据拆分为图像和标签两个列表
        batched_imgs = cat_list(images, fill_value=0)  # 使用 cat_list() 函数将图像和标签列表合并成张量。用于将列表中的 PIL 图像数据堆叠成张量,
        batched_targets = cat_list(targets, fill_value=255)
        return batched_imgs, batched_targets


def cat_list(images, fill_value=0):
    max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))  # 找到图像中最大的形状,以元组形式返回给max_size
    batch_shape = (len(images),) + max_size   # 计算出堆叠后的张量形状,包括批量大小和图像大小两个维度
    batched_imgs = images[0].new(*batch_shape).fill_(fill_value)  # 创建一个新的空白张量 batched_imgs,其形状与 batch_shape 相同,并将其填充为指定的填充值 fill_value
    for img, pad_img in zip(images, batched_imgs):  # 使用 zip() 函数将输入列表中的每个图像与其对应的空白张量进行拼接,以得到一个完整的张量。
        pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img)  # 将每个图像按照其实际大小插入到空白张量的左上角,以保持图像的相对位置不变。
    return batched_imgs

2. 构建训练集和验证集对象。调用上述自定义的DriveDataset数据集类,通过传入不同的参数来区分训练集和验证集。arg.data_path表示数据集所在的路径,transforms参数则是表示对数据进行预处理的操作,包括图像增强和归一化等,mean和std是规范化处理时用到的均值和标准差。get_transform这个类下面会介绍。train=True表示构建训练集对象,train=False表示构建验证集对象。

train_dataset = DriveDataset(args.data_path,
                                 train=True,
                                 transforms=get_transform(train=True, mean=mean, std=std))
val_dataset = DriveDataset(args.data_path,
                               train=False,
                               transforms=get_transform(train=False, mean=mean, std=std))

3.这是定义图像预处理方式,包括训练集和测试集的图像和标签的预处理方式,每行代码的具体作用注释有介绍。

# 定义训练集图像的预处理方式
class SegmentationPresetTrain:
    def __init__(self, base_size, crop_size, hflip_prob=0.5, vflip_prob=0.5,
                 mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        min_size = int(0.5 * base_size)
        max_size = int(1.2 * base_size)

        trans = [T.RandomResize(min_size, max_size)]   # 对图像的短边(长和宽中最短的)进行随机缩放以适应不同图像输入尺寸,缩放范围为【min_size, max_size】
        if hflip_prob > 0:
            trans.append(T.RandomHorizontalFlip(hflip_prob))  # 加入水平翻转
        if vflip_prob > 0:
            trans.append(T.RandomVerticalFlip(vflip_prob))  # 加入垂直翻转
        trans.extend([
            T.RandomCrop(crop_size),   # 对图像进行随机裁剪
            T.ToTensor(),  # 将数组矩阵转换为tensor类型,规范化到【0,1】范围
            T.Normalize(mean=mean, std=std),  # 加入图像归一化,并定义均值和标准差,RGB三通道的
        ])
        # trans是一个列表类型,包含各种了变换,将这些变换组成一个compose变换,注意transforms.Compose()函数需要接收一个列表类型
        self.transforms = T.Compose(trans)

    # 使用__call__()函数来调用transforms变换
    def __call__(self, img, target):
        return self.transforms(img, target)  # target是指标签图像,img是指待分割图像


# 定义验证集的图像预处理组合类,比较简单,只有张量化和规范化两个操作,这里规范化使用的是ImageNet推荐的参数,注意这种做法是针对彩色图像
class SegmentationPresetEval:
    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self.transforms = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ])

    def __call__(self, img, target):
        return self.transforms(img, target)

# 定义一个函数根据数据集的类型来调用对应的数据集处理类
def get_transform(train, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    base_size = 565
    crop_size = 480
    # 检查train是否为True
    if train:
        return SegmentationPresetTrain(base_size, crop_size, mean=mean, std=std)
    else:
        return SegmentationPresetEval(mean=mean, std=std)

4. 定义训练时所使用的线程数目,这里如果时windows系统训练出错的话,建议把num_workers直接设置为0就可以解决。定义训练数据加载器train_loader,用于批量加载训练数据。在代码中,使用torch.utils.data.DataLoader类来创建数据加载器,构造函数的参数包括:

  • train_dataset:训练数据集,应该是一个符合 PyTorch Dataset 接口的对象。
  • batch_size:每个批次的样本数量。
  • num_workers:用于数据加载的线程数。
  • shuffle:是否在每个时期(epoch)重新洗牌数据。一般只在训练集用
  • pin_memory:是否将数据加载到固定的内存区域,可以加速数据传输。
  • collate_fn:用于将样本列表转换为批次张量的函数。

num_workers = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])    # 如果batch_size>1, 线程数num_workers取min(cpu核数,batch_size)
train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               num_workers=num_workers,
                                               shuffle=True,
                                               pin_memory=True,
                                               collate_fn=train_dataset.collate_fn)

val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             num_workers=num_workers,
                                             pin_memory=True,
                                             collate_fn=val_dataset.collate_fn)

至此数据集加载器创建完成!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1382519.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Qt OpenGL - 网格式的直角坐标系

Qt OpenGL - 网格式的直角坐标系 引言一、绘制3D网格1.1 绘制平行于y轴的线段1.2 绘制平行于三个轴的线段1.3 绘制不同的3D网格 二、网格式的直角坐标系三、参考链接 引言 在OpenGL进行3D可视化,只绘制三条坐标轴略显单薄,而绘制网格形式的坐标系则能更清…

更换为mainwindow.ui更新工程架构

文章目录 前言一、新建带mainwindow.ui的工程1.新建工程2. 添加工程模块添加opencv的库3.添加资源3.1工程上添加资源3.2引用资源 4.添加曲线文件4.1 复制关键文件到新工程4.2 新进显示曲线的ui带.h的为了方面名字取一样4.3添加曲线显示控件4.4 添加工具 5. 添加曲线.h文件内容6…

大数据深度学习ResNet深度残差网络详解:网络结构解读与PyTorch实现教程

文章目录 大数据深度学习ResNet深度残差网络详解:网络结构解读与PyTorch实现教程一、深度残差网络(Deep Residual Networks)简介深度学习与网络深度的挑战残差学习的提出为什么ResNet有效? 二、深度学习与梯度消失问题梯度消失问题…

Apache-Common-Pool2中对象池的使用方式

最近在工作中,对几个产品的技术落地进行梳理。这个过程中发现一些朋友对如何使用Apache的对象池存在一些误解。所以在写作“业务抽象”专题的空闲时间里,本人觉得有必要做一个关于对象池的知识点和坑点讲解。Apache Common-Pool2 组件最重要的功能&#…

nvm安装高版本Nodejs报错

文章概叙 之前使用1.1.17版本的nvm,切换使用18的Nodejs的时候报错,经过短暂的思考,决定使用1.1.12的nvm的无聊故事。 吐槽 今天的故事比较无奈,由于某些原因,现在需要做rn的开发,至于为啥不是flutter&am…

《工具录》dig

工具录 1:dig2:选项介绍3:示例4:其他 本文以 kali-linux-2023.2-vmware-amd64 为例。 1:dig dig 是域名系统(DNS)查询工具,常用于域名解析和网络故障排除。比 nslookup 有更强大的功…

一张图总结架构设计的40个黄金法则

尼恩说在前面 在40岁老架构师 尼恩的读者交流群(50)中,很多小伙伴拿到非常优质的架构机会,常常找尼恩求助: 尼恩,我这边有一个部门技术负责人资深架构师的机会,非常难得, 但是有一个大厂高P在抢&#xff0…

为什么很多公司选择不升级JDK版本,仍然使用JDK8?

在讨论为什么许多公司选择不升级JDK版本,而继续使用JDK 8时,我们需要从多个角度来分析这个问题。以下是根据您提供的背景信息进行的一些分析和真实案例。 本文已收录于,我的技术网站 ddkk.com,有大厂完整面经,工作技术…

H5网站封装成App的高效转换之旅

在移动互联网时代,App(应用程序)和H5(HTML5网站)是两种常见的移动解决方案。App通常提供更流畅的用户体验和更丰富的功能,而H5网站则以其开发成本低、更新快捷和无需安装等优势受到青睐。尽管如此&#xff…

【java八股文】之Spring系列篇

1、你怎么理解Spring? Spring是个轻量级的框架,简化了应用的开发程序,提高开发人员的系统维护性,不过配置消息比较繁琐,所以后面才出选了SpringBoot的框架。 Spring的核心组件 : Spring Core 、 Spring Con…

Video接口介绍

屏库 https://m.panelook.cn/index_cn.php Open LDI, open lvds display interface OpenLDI and LVDS是兼容的, 是一种电平 https://www.ti2k.com/178597.html MIPI DSI/Camera crosLink FPD-LINK(Flat panel display link)是National(TI) LVDS技术, …

Openstack云计算(六)Openstack环境对接ceph

一、实施步骤: (1)客户端也要有cent用户: useradd cent && echo "123" | passwd --stdin cent echo -e Defaults:cent !requiretty\ncent ALL (root) NOPASSWD:ALL | tee /etc/sudoers.d/ceph chmod 440 /et…

[足式机器人]Part2 Dr. CAN学习笔记-Advanced控制理论 Ch04-12+13 不变性原理+非线性系统稳定设计

本文仅供学习使用 本文参考: B站:DR_CAN Dr. CAN学习笔记-Advanced控制理论 Ch04-1213 不变性原理非线性系统稳定设计 1. Invariance Princilpe-LaSalle;s Theorem不变性原理2. Nonlinear Basic Feedback Stabilization 非线性系统稳定设计 1. Invarianc…

Java内存模型之重排序

文章目录 1.什么是重排序2.重排序的好处3.重排序的三种情况4.用volatile修正重排序问题 1.什么是重排序 首先来看一个代码案例,尝试分析一下 x 和 y 的运行结果。 import java.util.concurrent.CountDownLatch;/*** 演示重排序的现象,直到达到某个条件…

css深度选择器 /deep/

一、/deep/的含义和使用 /deep/ 是一种 CSS 深度选择器,也被称为深度组合器或者阴影穿透组合器,主要用在 Web 组件样式封装中。 在 Vue.js 或者 Angular 中,使用了样式封装技术使得组件的样式不会影响到全局,也就是说组件内部的…

java数据结构与算法:单链表 SinglyLinkedList

单链表 SinglyLinkedList 创建实现类并实现方法 package com.lhs;public class SinglyLinkedList<E> implements List<E>{// 头节点private Node<E> first;// 尾节点private Node<E> last;// 节点数量private int size;public static class Node<…

PDCA/绩效管理活动

现代绩效管理理论认为&#xff0c;绩效管理活动是一个连续的过程&#xff0c;是指管理者用来确保自己下属员工的工作行为和工作产出与组织的目标保持一致的手段及过程。人们通常用一个循环过程来描述绩效管理的整个过程。我们认为&#xff0c;一个组织的员工绩效管理活动由四个…

Dockerfile的ADD和COPY

文章目录 环境ADD规则校验远程文件checksum添加Git仓库添加私有Git仓库ADD --link COPYCOPY --parent 使用ADD还是COPY&#xff1f;参考 环境 RHEL 9.3Docker Community 24.0.7 ADD ADD 指令把 <src> 的文件、目录、或URL链接的文件复制到 <dest> 。 ADD 有两种…

element表格数据,表头上(下)角标,html字符串渲染

1. 问题描述 在动态渲染的element表格中&#xff0c;表头和表中数据是一个含有html的字符串&#xff0c;需要渲染 2. 效果 3. 代码 const columns ref([{ text: 差值<sub>-3</sub> / 10<sup>-6</sup>℃<sup>-1</sup>, value: aallowEr…