Pytorch:torch.utils.data.DataLoader()

news2025/1/23 10:32:18

如果读者正在从事深度学习的项目,通常大部分时间都花在了处理数据上,而不是神经网络上。因为数据就像是网络的燃料:它越合适,结果就越快、越准确!神经网络表现不佳的主要原因之一可能是由于数据不佳或理解不足。因此,以更直观的方式理解、预处理数据并将其加载到网络中非常重要。
参考:https://zhuanlan.zhihu.com/p/596730297

DataLoader加载和迭代数据集

Dataloader本质是一个迭代器对象,也就是可以通过for batch_idx,batch_dict in dataloader 来提取数据集,提取的数量由batch_size 参数决定,得到这一batch的数据后,就可以喂入网络开始训练或者推理了。
在迭代的过程中,dataloader会自动调用dataset中的__getitem__ 函数,以获取一帧数据(item)

from torch.utils.data import DataLoader

DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    collate_fn=None,
    pin_memory=False,
 )

以U-Net中的代码为例:
具体详见:U-Net代码复现

loader_args = dict(batch_size=batch_size, num_workers=os.cpu_count(), pin_memory=True)
train_loader = DataLoader(train_set, shuffle=True, **loader_args)
val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

1. 数据集

**dataset (Dataset) ** – dataset from which to load the data.
即自定义的数据集,非常重要,因为dataloader会调用dataset的一些重载函数(e.g. getitem && len )

2. 对数据进行批处理

batch_size (int, optional)how many samples per batch to load(default: 1).

3. 在 CUDA 张量上加载数据

pin_memory(bool, optional)If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them. If your data elementsare a custom type, or your collate_fn returns a batch that is a custom type,see the example below.

pin_memory参数直接将数据集加载为 CUDA 张量。它是一个可选参数,接受一个布尔值;如果设置为True,会在返回张量之前张量复制到 CUDA 固定内存中。这样在GPU训练过程中,数据从内存到GPU的复制可以使用异步的方式进行,从而提高数据读取的效率。

通常情况下,当使用GPU训练模型时,数据读取会成为整个训练过程的瓶颈之一。使用pin_memory可以将数据在CPU和GPU之间进行传输时的复制时间减少,从而提高数据加载的速度,加速训练过程。

需要注意的是,使用pin_memory会占用更多的内存空间,因此在内存资源紧张的情况下,需要谨慎使用。同时,在某些情况下(例如数据集比较小的情况下),使用pin_memory并不会带来明显的加速效果。

4.允许多进程

num_workers (int, optional)how many subprocesses to use for dataloading. 0 means that the data will be loaded in the main process.(default: 0)
这也是一个很有意思的参数,按照官方的说法, num_workers 用于设置数据加载过程中使用的子进程数。其默认值为0,即在主进程中进行数据加载,而不使用额外的子进程。

以下是我看到的一个解释,原文链接:https://blog.csdn.net/vonct/article/details/130263743
下面说一下个人的理解,在初始化 dataloader对象时,会根据num_workers创建子线程用于加载数据(主线程数+子线程=num_workers)。每个worker或者说线程都有自己负责的dataset范围(下面统称worker)

每当迭代 dataloader 对象时,工人们(workers)就开始干活了:将数据从数据源(如硬盘)加载到内存(数据加载),当一个worker读取(调用__getitem__)到足够的数据(看你在dataset中怎么定义一个item了)后,会将这些数据封装成一个(即一帧),并将其放到该worker独有的内存队列中。 要注意的是,每次迭代时,worker会尽可能地读数据,直到自己的队列被填满。

当所有workers的队列都被填满时,一个名为sampler的线程将会被创建,它的作用就是收集各workers队列中队首的 ,把他们放到一个各线程共享内存的缓冲队列中,并调用 collate_fn 函数来将 batch_size 个 整合,最后返回给迭代的输出。

这时候大家肯定会有点疑惑,那当迭代到后期时,需要读取的样本都已经在队列中了,是不是意味着这时候工人们已经在休息了?根据chatgpt的回答:是的!下面以一张图来帮助大家理解

在这里插入图片描述

5.合并数据集

collate_fn (Callable, optional)merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.

整合多个样本到一个batch时需要调用的函数,当 getitem 返回的不是tensor而是字典之类时,需要进行 collate_fn的重载,同时可以进行数据的进一步处理以满足pytorch的输入要求。
以U-Net为例:

def __getitem__(self, idx):
        name = self.ids[idx]
        mask_file = list(self.mask_dir.glob(name + self.mask_suffix + '.*'))
        img_file = list(self.images_dir.glob(name + '.*'))

        assert len(img_file) == 1, f'Either no image or multiple images found for the ID {name}: {img_file}'
        assert len(mask_file) == 1, f'Either no mask or multiple masks found for the ID {name}: {mask_file}'
        mask = load_image(mask_file[0])
        img = load_image(img_file[0])

        assert img.size == mask.size, \
            f'Image and mask {name} should be the same size, but are {img.size} and {mask.size}'

        img = self.preprocess(self.mask_values, img, self.scale, is_mask=False)
        mask = self.preprocess(self.mask_values, mask, self.scale, is_mask=True)

        return {
            'image': torch.as_tensor(img.copy()).float().contiguous(),
            'mask': torch.as_tensor(mask.copy()).long().contiguous()
        }

getitem 返回的是一个包含image和mask的 data_dict 字典,这时候就需要调用自定义的collate_fn来进行打包(待补充。。。)

6.数据采样

sampler (Sampler or Iterable, optional) – defines the strategy to draw samples from the dataset. Can be any Iterable with len implemented. If specified, shufflemust not be specified.

sampler的主要作用是控制样本的采样顺序,并提供样本的索引。在默认情况下,dataloader使用的是SequentialSampler,它按照数据集的顺序依次提取样本,但在某些情况下,我们可能需要自定义采样顺序。比如说想从队尾提取数据。

比如,当我们处理非常大的数据集时,为了提高训练效率,可能需要对数据进行分布式采样,这时候就需要使用DistributedSampler。DistributedSampler会将数据集划分成多个子集,每个子集分配给不同的进程进行采样。在这种情况下,如果使用默认的SequentialSampler,可能会导致各个进程采样到相同的数据,从而降低训练效率。

此外,还有一些自定义的sampler,比如随机采样器(RandomSampler)和加权采样器(WeightedRandomSampler),它们可以按照不同的采样策略对数据集进行采样,从而满足不同的训练需求。

因此,根据不同的训练需求,我们可能需要自定义sampler来控制数据的采样顺序。

原文链接:https://blog.csdn.net/vonct/article/details/130263743

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1265884.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【人工智能】人工智能的技术研究与安全问题的深入讨论

前言 人工智能(Artificial Intelligence),英文缩写为AI。 它是研究、开发用于模拟、延伸和扩展人的智能的理论、方法、技术及应用系统的一门新的技术科学。人工智能是新一轮科技革命和产业变革的重要驱动力量。 📕作者简介&#x…

vscode注释插件「koroFileHeader」

前言 在vscode上进行前端开发,有几个流行的注释插件: Better CommentsTodo TreekoroFileHeaderDocument ThisAuto Comment Blocks 在上面的插件中我选择 koroFileHeader 做推荐,原因一是使用人数比较多(最多的是 Better Commen…

029 - STM32学习笔记 - ADC(三) 独立模式单通道DMA采集

029 - STM32学习笔记 - 单通道DMA采集(三) 单通道ADC采集在上节中学习完了,这节在上节的内容基础上,学习单通道DMA采集。程序代码以上节的为基础,需要删除NVIC配置函数、中段服务子程序、R_ADC_Mode_Config()函数中使能…

探索Python内置类属性__repr__:展示对象的魅力与实用性

概要 在Python中,每个对象都有一个内置的__repr__属性,它提供了对象的字符串表示形式。这个特殊的属性在调试、日志记录和交互式会话等场景中非常有用。本文将详细介绍__repr__属性的使用教程,包括定义、常见应用场景和注意事项,…

机器人向前冲

欢迎来到程序小院 机器人向前冲 玩法:一直走动的机器人,点击鼠标左键进行跳跃,跳过不同的匝道,掉下去即为游戏接续, 碰到匝道铁钉游戏结束,一直往前冲吧^^。开始游戏https://www.ormcc.com/play/gameStart…

C++基础 -10- 类的构造函数

类的构造函数类型一 使用this指针给类内参数赋值 class rlxy {public:int a;rlxy(int a, int b, int c){this->aa;this->bb;this->cc;cout << "rlxy" << endl;}protected:int b;private:int c; };int main() {rlxy ss(10, 20, 30); }类的构造…

使用Accelerate库在多GPU上进行LLM推理

大型语言模型(llm)已经彻底改变了自然语言处理领域。随着这些模型在规模和复杂性上的增长&#xff0c;推理的计算需求也显著增加。为了应对这一挑战利用多个gpu变得至关重要。 所以本文将在多个gpu上并行执行推理&#xff0c;主要包括&#xff1a;Accelerate库介绍&#xff0c;…

在Rust中处理命令行参数和环境变量

1.摘要 Rust的命令行和环境变量处理在标准库中提供了一整套实现方法, 在本文中除了探索标准库的使用方法之外, 也在不断适应Rust独有的语法特点。在本文中, 我们通过标准库函数的返回值熟悉了迭代器的使用方法, 操作迭代器精确控制保存的内容, 包括字符串和键值对的使用方法。…

SpringBoot整合EasyExcel实现复杂Excel表格的导入导出功能

文章目录 &#x1f389;SpringBoot整合EasyExcel实现复杂Excel表格的导入&导出功能 ☆* o(≧▽≦)o *☆嗨~我是IT陈寒&#x1f379;✨博客主页&#xff1a;IT陈寒的博客&#x1f388;该系列文章专栏&#xff1a;架构设计&#x1f4dc;其他专栏&#xff1a;Java学习路线 Jav…

mysql 性能排查

mysql 下常见遇到的问题有&#xff0c;mysql连接池耗尽&#xff0c;死锁、慢查、未提交的事务。等等我们可能需要看&#xff1b;我们想要查看的可能有 1.当前连接池连接了哪些客户端&#xff0c;进行了哪些操作 2.当前造成死锁的语句有哪些&#xff0c;是哪个客户端上的&#x…

2023网络安全产业图谱

1. 前言 2023年7月10日&#xff0c;嘶吼安全产业研究院联合国家网络安全产业园区&#xff08;通州园&#xff09;正式发布《嘶吼2023网络安全产业图谱》。 嘶吼安全产业研究院根据当前网络安全发展规划与趋势发布《嘶吼2023网络安全产业图谱》调研&#xff0c;旨在进一步了解…

2020年6月16日 Go生态洞察:泛型的下一步

&#x1f337;&#x1f341; 博主猫头虎&#xff08;&#x1f405;&#x1f43e;&#xff09;带您 Go to New World✨&#x1f341; &#x1f984; 博客首页——&#x1f405;&#x1f43e;猫头虎的博客&#x1f390; &#x1f433; 《面试题大全专栏》 &#x1f995; 文章图文…

选择跨网数据摆渡系统时,你最关注的功能是哪些?

为什么要选择跨网数据摆渡系统呢&#xff1f;因为做了网络隔离后&#xff0c;要有数据交互。那为什么要做网络隔离呢&#xff1f;主要还是安全方面的考虑&#xff0c;一般有以下几个原因&#xff1a; 1、数据安全保护&#xff1a;对于一些重要数据&#xff0c;比如代码数据、隐…

leetCode 39.组合总和 + 回溯算法 + 剪枝 + 图解 + 笔记

39. 组合总和 - 力扣&#xff08;LeetCode&#xff09; 给你一个 无重复元素 的整数数组 candidates 和一个目标整数 target &#xff0c;找出 candidates 中可以使数字和为目标数 target 的 所有 不同组合 &#xff0c;并以列表形式返回。你可以按 任意顺序 返回这些组合 can…

IDEA 配置 gradle6.8.3 解决导入gradle项目下载太慢问题

由于平时用的是springboot 2.7 这里下载gradle-6.8.3 Gradle官网地址&#xff1a;https://services.gradle.org/distributions/ 1.下载gradle后&#xff0c;配置环境变量 GRADLE_HOME {gradle 文件路径} GRADLE_USER_HOME {jar下载路径&#xff0c;可以放maven jar保存路径…

RabbitMQ高级特性2 、TTL、死信队列和延迟队列

MQ高级特性 1.削峰 设置 消费者 测试 添加多条消息 拉取消息 每隔20秒拉取一次 一次拉取五条 然后在20秒内一条一条消费 TTL Time To Live&#xff08;存活时间/过期时间&#xff09;。 当消息到达存活时间后&#xff0c;还没有被消费&#xff0c;会被自动清除。 RabbitMQ…

浏览器触发下载Excel文件-Java实现

目录 1:引入maven 2:代码实现 3.导出通讯录信息到Excel文件 4.生成并下载Excel文件部分解释 1:引入maven 添加依赖:首先,在你的项目中添加EasyExcel库的依赖。你可以在项目的构建文件(如Maven的pom.xml)中添加以下依赖项:<dependency><groupId>com.alib…

医疗影像数据集—CT、X光、骨折、阿尔茨海默病MRI、肺部、肿瘤疾病等图像数据集

最近收集了一大波关于CT、X光等医疗方面的数据集包含骨折、阿尔茨海默病MRI、肺部疾病等类型的医疗影像数据&#xff0c;废话不多说&#xff0c;给大家逐一介绍&#xff01;&#xff01; 1、彩色预处理阿尔茨海默病MRI(磁共振成像)图像数据集 彩色预处理阿尔茨海默病MRI(磁共…

平凯星辰携手教育部教育管理信息中心,助力普惠教育数字化

近日&#xff0c;企业级开源分布式数据库厂商平凯星辰与教育部教育管理信息中心达成合作&#xff0c;TiDB 分布式数据库为全国中小学管理服务平台提供全栈服务。双方将携手深入探索领先的数据库技术在教育行业的新场景与新应用&#xff0c;既夯实教育数字化底座&#xff0c;助力…

“抓机遇,促发展”2024亚洲国际人工智能展览会(世亚智博会)

随着人工智能技术的飞速发展&#xff0c;我们正在见证一个全新的时代。2024年即将到来&#xff0c;这一年是人工智能创新将重塑传统界限的一年。从全球领先的科技大国到各类企业&#xff0c;人工智能技术正在以前所未有的速度融入我们的日常生活&#xff0c;推动行业走向未来&a…