pytorch中Dataset、Dataloader、Sampler、collate_fn相互关系和使用说明

news2024/11/22 21:27:00

参考: https://blog.csdn.net/Chinesischguy/article/details/103198921

参考: https://zhuanlan.zhihu.com/p/76893455

参考:https://blog.csdn.net/lilai619/article/details/118784730

参考:https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader

        本博客旨在介绍PyTorch深度学习框架中Dataset、Dataloader、Sampler、collate_fn组件之间相互关系,以及如何自定义各组件。这些组件是深度学习项目中不可或缺的组成部分,对于理解和使用PyTorch框架进行深度学习任务至关重要。

        在PyTorch深度学习框架中,Dataset、Dataloader、Sampler和collate_fn是数据加载和处理过程中非常重要的组成部分。它们之间的调用关系如下:

  1. Dataset:定义了数据集的接口,用于读取和处理数据。通常情况下,Dataset是从文件或数据库中读取数据的集合,它可以对数据进行预处理、增强等操作,并返回一个可迭代的对象,用于后续的数据加载过程。

  2. Dataloader:实现了数据集的批量加载功能。Dataloader可以根据Dataset返回的可迭代对象,将数据分成多个batch,并按照指定的采样方式(如随机采样、分层采样等)进行采样。同时,Dataloader还可以自动调整batch size、设置数据加载器状态等。

  3. Sampler:定义了数据集中每个batch所包含的数据的位置索引。通常情况下,Sampler是在数据加载之前设置的一个对象,它可以根据用户指定的要求(如按照类别、标签等)对数据集进行采样,并返回每个batch所包含的数据的位置索引。

  4. collate_fn:用于将一个batch中的数据进行拼接和整理。通常情况下,collate_fn是在Dataloader创建时设置的一个函数,它可以根据Dataset返回的可迭代对象和Sampler返回的位置索引,将不同长度的输入数据转换为统一的形状,并返回一个新的tensor作为batch的数据。

        综上所述,Dataset、Dataloader、Sampler和collate_fn之间是相互协作的,它们共同完成了数据加载和处理的过程。具体来说,Dataset提供了数据集的接口和一些基本的操作;Dataloader实现了数据的批量加载和一些高级的功能;Sampler根据用户指定的要求对数据集进行采样;collate_fn负责将不同长度的输入数据转换为统一的形状。本文将讨论这四个组件的使用方法,并提供一些自定义各组件的技术实践经验。我们将从以下几个方面来探讨:

        1. Dataset的使用方法和自定义技巧;

        2. Sampler的使用方法和自定义技巧;

        3. collate_fn的使用方法和自定义技巧。

DataLoader, Sampler, Dataset三者的关系

        1. Sampler提供indicies

        2. Dataset根据indicies提供data,使用__getitem__方法

        3. DataLoader将上面两个组合起来,提供最终的batch训练数据,其中collate_fn可以对batch中的数据做额外的处理

自定义Dataset

        在PyTorch中,可以通过继承torch.utils.data.Dataset类来自定义数据集(Dataset)类。自定义的数据集类可以包含自己的数据加载和预处理方法,以及一些额外的元数据。

import torch
from torch.utils.data import Dataset, DataLoader, Sampler, BatchSampler
import torchvision
from torchvision.io import read_image
import random
import numpy as np
from matplotlib import pyplot as plt
from collections import Counter


class MyDataset(Dataset):
    """
        加载磁盘上的图像文件,并进行transform变换,返回变换后的图片和与之对应的标签编号
    """

    def __init__(self, filenames, labels, transforms_pipeline=None):
        super().__init__()
        # 所有图像的路径列表
        self.filenames = filenames
        # 所有图片对应的label标签编号,从0开始
        self.labels = labels
        # 图像预处理
        self.transforms_pipeline = transforms_pipeline

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        filepath = self.filenames[idx]
        img = read_image(filepath, mode=torchvision.io.ImageReadMode.RGB)
        if self.transforms_pipeline:
            img = self.transforms_pipeline(img)
        return img, self.labels[idx]

        以上代码自定义了一个Dataset类用于加载训练数据,训练数据中cat和dog目录下分别存储的是猫和狗的图片。

         使用以下代码片段测试自定义的Dataset数据加载情况:

transforms_pipeline = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((224, 224)),
    ]
)

# 图像存放位置,其中包含两个目录,cat和dog,cat下存放猫的图片,dog下存放狗的图片
data_path = "XXX"
image_folder = torchvision.datasets.ImageFolder(data_path)
# image_folder.samples 中存放的是图像数据的文件路径和类别索引编号(从0开始编号)
random.shuffle(image_folder.samples)
# image_folder.classes image_folder.samples中存放的类别索引编号相对应
classes = image_folder.classes
# 用于存放图像路径列表
filenames = []
# 用于存放图像对应的类别
labels = []
for image_path, label in image_folder.samples:
    # print(image_path, label)
    filenames.append(image_path)
    labels.append(label)
print(filenames, labels)

# 使用自定义Dataset类加载磁盘上的图上数据
my_dataset = MyDataset(filenames, labels, transforms_pipeline)
img, label = my_dataset[10]
print(img.shape, label)

自定义Sampler

        在PyTorch中,可以通过继承torch.utils.data.Sampler类来自定义采样器(Sampler)类。自定义的采样器类可以控制数据集中每个样本的采样方式,例如随机采样、分块采样等。

class MySampler(Sampler):
    """
        自定义Sampler,在__iter__函数中定义indices的生成方式,也叫生成顺序
    """

    def __init__(self, labels):
        self.labels = np.array(labels)
        self.image_ids = []

    def __iter__(self):
        """
            在每个batch中包含的每个类别的数量相等
        :return:
        """
        indices = []
        counter = Counter(self.labels)
        # 统计数据量最多的类别
        most_common = counter.most_common(1)[0][1]
        # 统计每张图片在filenames这个列表中对应的索引编号
        for c in range(len(counter)):
            indices.append(np.where(self.labels == c)[0].tolist())

        # 所有类别通过复制的方式与最多的类别对齐
        for indice in indices:
            if len(indice) < most_common:
                indice.extend(random.choices(indice, k=most_common - len(indice)))
            random.shuffle(indice)

        # 依次从所有类别中分别取一张图片组成batch
        for ids in zip(*indices):
            self.image_ids.extend(list(ids))

        return iter(self.image_ids)

    def __len__(self):
        return len(self.image_ids)

        以上自定义Sampler控制在返回训练样本编号的逻辑,使得每个batch中的各类别数据量相等,Sampler返回训练样本的编号,然后使用Dataset的__getitem__方法取出对应的样本。

        使用以下代码片段测试自定义的Sampler的数据采样情况:

my_sampler = MySampler([1, 2, 3, 4, 1, 2, 3, 4, 0, 0, 0])
sample_labels = []
for x in my_sampler:
    print(x)
    sample_labels.append(my_sampler.labels[x])
print(sample_labels)
print(len(my_sampler))

自定义collate_fn函数

        在PyTorch中,自定义collate_fn函数可以用于对数据集中的数据进行整合和处理。当使用自定义采样器(Sampler)加载数据时,collate_fn函数会被自动调用来整合每个batch的数据。

def collate_fn(batch_data):
    """
        对batch中的图像使用mixup,并返回mixup之后的结果
    :param batch_data:
    :return:
    """

    def mixup_data(x, y, alpha=1.0, use_cuda=False):
        if alpha > 0:
            lam = np.random.beta(alpha, alpha)
        else:
            lam = 1

        batch_size = x.size()[0]

        if use_cuda:
            index = torch.randperm(batch_size).cuda()
        else:
            index = torch.randperm(batch_size)

        mixed_x = lam * x + (1 - lam) * x[index, :]
        y_a, y_b = y, y[index]

        return mixed_x, y_a, y_b, lam

    batch_img = []
    batch_label = []
    for img, label in batch_data:
        batch_img.append(img)
        batch_label.append(label)

    batch_img = torch.stack(batch_img, dim=0)
    batch_label = torch.tensor(batch_label)
    # print(batch_img.shape, batch_label.shape)

    batch_img, batch_label_a, batch_label_b, batch_lam = mixup_data(batch_img, batch_label)
    return batch_img, batch_label_a, batch_label_b, batch_lam

        在以上自定义collate_fn函数中,我们在每个batch批量样本之间使用mixup数据增强,并返回mixup之后的增强数据以及对应的标签和参数。

自定义Dataset、Sampler、collate_fn,以及使用Dataloader的完整代码

# coding:utf-8

import torch
from torch.utils.data import Dataset, DataLoader, Sampler, BatchSampler
import torchvision
from torchvision.io import read_image
import random
import numpy as np
from matplotlib import pyplot as plt
from collections import Counter


class MyDataset(Dataset):
    """
        加载磁盘上的图像文件,并进行transform变换,返回变换后的图片和与之对应的标签编号
    """

    def __init__(self, filenames, labels, transforms_pipeline=None):
        super().__init__()
        # 所有图像的路径列表
        self.filenames = filenames
        # 所有图片对应的label标签编号,从0开始
        self.labels = labels
        # 图像预处理
        self.transforms_pipeline = transforms_pipeline

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        filepath = self.filenames[idx]
        img = read_image(filepath, mode=torchvision.io.ImageReadMode.RGB)
        if self.transforms_pipeline:
            img = self.transforms_pipeline(img)
        return img, self.labels[idx]


def collate_fn(batch_data):
    """
        对batch中的图像使用mixup,并返回mixup之后的结果
    :param batch_data:
    :return:
    """

    def mixup_data(x, y, alpha=1.0, use_cuda=False):
        if alpha > 0:
            lam = np.random.beta(alpha, alpha)
        else:
            lam = 1

        batch_size = x.size()[0]

        if use_cuda:
            index = torch.randperm(batch_size).cuda()
        else:
            index = torch.randperm(batch_size)

        mixed_x = lam * x + (1 - lam) * x[index, :]
        y_a, y_b = y, y[index]

        return mixed_x, y_a, y_b, lam

    batch_img = []
    batch_label = []
    for img, label in batch_data:
        batch_img.append(img)
        batch_label.append(label)

    batch_img = torch.stack(batch_img, dim=0)
    batch_label = torch.tensor(batch_label)
    # print(batch_img.shape, batch_label.shape)

    batch_img, batch_label_a, batch_label_b, batch_lam = mixup_data(batch_img, batch_label)
    return batch_img, batch_label_a, batch_label_b, batch_lam


class MySampler(Sampler):
    """
        自定义Sampler,在__iter__函数中定义indices的生成方式,也叫生成顺序
    """

    def __init__(self, labels):
        self.labels = np.array(labels)
        self.image_ids = []

    def __iter__(self):
        """
            在每个batch中包含的每个类别的数量相等
        :return:
        """
        indices = []
        counter = Counter(self.labels)
        # 统计数据量最多的类别
        most_common = counter.most_common(1)[0][1]
        # 统计每张图片在filenames这个列表中对应的索引编号
        for c in range(len(counter)):
            indices.append(np.where(self.labels == c)[0].tolist())

        # 所有类别通过复制的方式与最多的类别对齐
        for indice in indices:
            if len(indice) < most_common:
                indice.extend(random.choices(indice, k=most_common - len(indice)))
            random.shuffle(indice)

        # 依次从所有类别中分别取一张图片组成batch
        for ids in zip(*indices):
            self.image_ids.extend(list(ids))

        return iter(self.image_ids)

    def __len__(self):
        return len(self.image_ids)


## 测试自定义Sampler
# my_sampler = MySampler([1, 2, 3, 4, 1, 2, 3, 4, 0, 0, 0])
# sample_labels = []
# for x in my_sampler:
#     print(x)
#     sample_labels.append(my_sampler.labels[x])
# print(sample_labels)
# print(len(my_sampler))


transforms_pipeline = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((224, 224)),
    ]
)

# 图像存放位置,其中包含两个目录,cat和dog,cat下存放猫的图片,dog下存放狗的图片
data_path = r"C:\WorkDir\PythonWorkspace\MusicRecognition\mixup-cifar10-main\data\cat_and_dog"
image_folder = torchvision.datasets.ImageFolder(data_path)
# image_folder.samples 中存放的是图像数据的文件路径和类别索引编号(从0开始编号)
random.shuffle(image_folder.samples)
# image_folder.classes image_folder.samples中存放的类别索引编号相对应
classes = image_folder.classes
# 用于存放图像路径列表
filenames = []
# 用于存放图像对应的类别
labels = []
for image_path, label in image_folder.samples:
    # print(image_path, label)
    filenames.append(image_path)
    labels.append(label)
print(filenames, labels)

# 使用自定义Dataset类加载磁盘上的图上数据
my_dataset = MyDataset(filenames, labels, transforms_pipeline)
# img, label = my_dataset[10]
# print(img.shape, label)

# 使用自定义collate_fn函数,在每个batch中进行mixup图片增强,并返回增强后的图片数据、标签、以及mixup系数
dataloader = DataLoader(
    my_dataset,
    batch_size=8,  # batch_size要能整除类别数
    shuffle=False,  # 使用sampler时,shuffle参数要设置为False
    sampler=MySampler(labels),  # 自定义Sampler,返回的batch中每种类别的数量相等
    batch_sampler=None,
    collate_fn=collate_fn  # 自定义collate_fn,其中执行mixup数据增强
)

for batch_img, batch_label_a, batch_label_b, batch_lam in dataloader:
    print(batch_img.shape, batch_label_a.shape, batch_label_b.shape, batch_lam)
    # batch中包含每个类别的数量相等,猫和狗都是4张
    # {0: 4, 1: 4}
    print(Counter(batch_label_a.detach().cpu().numpy().tolist()))
    break

for idx, img in enumerate(batch_img):
    plt.imshow(img.permute(1, 2, 0).int().clamp(min=0, max=255).detach().cpu().numpy())
    plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/600543.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

06 【Vue数据监视 v-model双向绑定】

1.Vue数据监视 1.1 问题演示 先来个案例引入一下&#xff1a; <!-- 准备好一个容器--> <div id"root"><h2>人员列表</h2><button click"updateMei">更新马冬梅的信息</button><ul><li v-for"(p,inde…

Markdown笔记应用程序Note Mark

什么是 Note Mark Note Mark 是一种轻量、快速、简约&#xff0c;基于网络的 Markdown 笔记应用程序。具有时尚且响应迅速的网络用户界面。 安装 在群晖上以 Docker 方式安装。 ghcr.io 镜像下载 官方的镜像没有发布在 docker hub&#xff0c;而是在 ghcr.io&#xff0c;所以…

总结了几百个ChatGPT模型的调教经验,确定不来看看?

目录 前言 chatgpt调教指南 提示词 1.清晰的问题或请求&#xff1a; 2.上下文设置&#xff1a; 3.具体的主题或领域&#xff1a; 4.陈述性问题&#xff1a; 5.追问和澄清&#xff1a; 6.限定问题范围&#xff1a; 角色扮演 充当 Linux 终端 担任产品经理 充当 SQL…

技术帖——飞凌嵌入式RK3588开发板推理模型转换及测试

RKNN&#xff08;Rockchip Neural Network&#xff09;是一种用于嵌入式设备的深度学习推理框架&#xff0c;它提供了一个端到端的解决方案&#xff0c;用于将训练好的深度学习模型转换为在嵌入式设备上运行的可执行文件。使用RKNN框架可以在嵌入式设备上高效地运行深度学习模型…

易基因:DNA羟甲基化和TET酶在胎盘发育和妊娠结局中的作用 | 深度综述

大家好&#xff0c;这里是专注表观组学十余年&#xff0c;领跑多组学科研服务的易基因。 胎盘是支持哺乳动物胚胎和胎儿发育所必需的临时器官。了解滋养层细胞分化和胎盘功能的分子机制可能有助于改善产科并发症的诊断和治疗。印迹基因是调控胎盘发育的基础&#xff0c;表观遗…

chatgpt赋能python:使用Python编写数据接口:如何让您的网站更具吸引力和效率

使用Python编写数据接口&#xff1a;如何让您的网站更具吸引力和效率 在当今数字时代&#xff0c;大多数公司都希望能够从用户生成的数据中收集和分析信息&#xff0c;以了解他们的客户群体并提高他们的营销策略。为此&#xff0c;开发数据接口成为了一项对于互联网公司不可或…

新文本检测算法TextFuseNet

TextFuseNet: Scene Text Detection with Richer Fused Features 自然场景中任意形状文本检测是一项极具挑战性的任务&#xff0c;与现有的仅基于有限特征表示感知文本的文本检测方法不同&#xff0c;本文提出了一种新的框架&#xff0c;即 TextFuseNet &#xff0c;以利用融合…

网络开发过程详细知识点

网络生命周期至少包括系统构思与计划、分析和设计、运行和维护的过程。 常见的迭代周期分为四阶段周期、五阶段周期、六阶段周期。 网络开发过程根据五阶段迭代周期模型可被分为五个阶段&#xff1a; 需求分析、现有网络分析、确定网络逻辑结构、确定网络物理结构、安装与维护。…

linux实践php8.2加laravel-cotane和swoole服务器

php8.2 composer -v 报错&#xff1a; Deprecation Notice: strlen(): Passing null to parameter #1 ($string) of type string is deprecated in phar:///usr/bin/composer/vendor/symfony/console/Descriptor/TextDescriptor.php:290 解决方法可以升级下composer&#xff1…

计算机中小数的存储

十进制小数怎么转成二进制小数&#xff1f;怎么在计算机中存储float&#xff1f; 计算机中存储的二进制小数&#xff08;float&#xff09;怎么转成十进制小数&#xff1f;

法规标准-ISO 20900标准解读

ISO 20900是做什么的&#xff1f; ISO 20900全名为智能交通系统-部分自动泊车系统(PAPS)-性能要求和试验程序&#xff0c;其中主要是对PAPS系统的功能要求、性能要求及测试步骤进行了介绍 PAPS类型 I类型PAPS系统反应 II类型PAPS系统反应 一般要求 运行期间的最大速度 系统…

05 【绑定样式 条件渲染 列表渲染】

1.绑定样式 1.1 class样式 写法 :classxxx xxx可以是字符串、对象、数组。 所以分为三种写法:字符串写法、对象写法、数组写法。 1.1.1 字符串写法 字符串写法适用于: 类名不确定,要动态获取 <style>.normal{background-color: skyblue;} </style><!-- 准备…

DeSTSeg:用于异常检测的分割网络引导去噪学生教师模型(CVPR2023)

文章目录 DeSTSeg: Segmentation Guided Denoising Student-Teacher for Anomaly Detection摘要本文方法Synthetic Anomaly GenerationDenoising Student-Teacher Network分割网络推理 实验结果消融实验 DeSTSeg: Segmentation Guided Denoising Student-Teacher for Anomaly D…

倾斜摄影三维模型数据的几何坐标变换与点云重建并行计算技术探讨

倾斜摄影三维模型数据的几何坐标变换与点云重建并行计算技术探讨 倾斜摄影三维模型数据的几何坐标变换和点云重建是一项大规模计算密集型任务&#xff0c;需要消耗大量的计算资源。并行计算技术可以将这些任务分解为多个子任务&#xff0c;并在多个CPU或GPU上同时运行&#xff…

一起学SF框架系列4.6-模块context-AbstractApplicationContext

org.springframework.context.ApplicationContext接口表示Spring IoC容器&#xff0c;负责实例化、配置和组装bean。容器通过读取配置元数据来获取关于实例化、配置和组装哪些对象的指令。配置元数据以XML、Java注释或Java代码表示。它允许您表达组成应用程序的对象以及这些对象…

微信小程序的登录流程

一、背景 传统的web开发实现登陆功能&#xff0c;一般的做法是输入账号密码、或者输入手机号及短信验证码进行登录。 服务端校验用户信息通过之后&#xff0c;下发一个代表登录态的 token 给客户端&#xff0c;以便进行后续的交互,每当token过期&#xff0c;用户都需要重新登…

深度学习训练营N1周:Pytorch文本分类入门

&#x1f368; 本文为&#x1f517;365天深度学习训练营 中的学习记录博客&#x1f356; 原作者&#xff1a;K同学啊 | 接辅导、项目定制 NLP的功能&#xff1a; 本周使用AG News数据集进行文本分类。实现过程分为前期准备、代码实战、使用测试数据集评估模型和总结四个部分。…

chatgpt赋能python:Python冒泡排序详解

Python冒泡排序详解 介绍 Python是一门强大的编程语言&#xff0c;它在数据科学、机器学习、Web开发等领域都有广泛的应用。其中&#xff0c;排序算法是编程中一个重要的话题&#xff0c;冒泡排序也是最基本的排序算法之一。本文将详解Python冒泡排序的实现方法和优化技巧&am…

chatgpt赋能python:利用Python编写模拟器:一种循序渐进的方法

利用Python编写模拟器&#xff1a;一种循序渐进的方法 模拟器是一种用于模拟计算机硬件或软件的程序。它模拟了真实设备的功能&#xff0c;可以帮助开发人员进行测试和调试&#xff0c;以及提供一种环境来设计和验证新的算法和协议。Python是一种广泛使用的编程语言&#xff0…

计讯物联宝贝王手工大赛投票结果正式揭晓,速速围观!

在孩子的想象世界中&#xff0c; 生活中的可爱 可以是专属六一的蛋糕&#xff0c; 可以是创意手绘手摇扇&#xff0c; 可以是萌萌可爱的花束&#xff0c; 可以是未来超智能机器人&#xff0c; 可以是无人航天器模型…… 他们的想象&#xff0c; 是尚未被世俗沾染的赤忱之…