torch.utils.data

news2025/1/20 11:54:25

整体架构

平时使用 pytorch 加载数据时大概是这样的:

import numpy as np
from torch.utils.data import Dataset, DataLoader

class ExampleDataset(Dataset):
	def __init__(self):
		self.data = [1, 2, 3, 4, 5]

	def __getitem__(self, idx):
		return self.data[idx]

	def __len__(self):
		return len(self.data)

def collate_fn(batch):
	return np.array(batch)

dataset = ExampleDataset()  # create the dataset
dataloader = DataLoader(
	dataset=dataset,
	batch_size=2,
	shuffle=True,
	num_workers=4,
	collate_fn=collate_fn
)
for datapoint in dataloader:
	print(datapoint)
  1. 继承 Dataset 类,定义一个迭代器,包含两个魔法方法:__getitem__(self, idx)__len__(self),分别实现如何获取一条数据和如何设定数据长度;
  2. 定义 collate_fn 函数,设定如何组织一个 batch
  3. 实例化 Dataset,并和 collate_fn 一起传入 DataLoader,参数 batch_size 设置批大小、shuffle 设置是否打乱、num_workers 设置并行加载数据的进程数。

然而,背后到底干了什么,我们不清楚,甚至遇到 DataLoader 的如 samplerbatch_samplerworker_init_fn 的其他参数,就会懵逼。那就看一看官方文档,了解一下 torch.utils.data 是如何工作的。


上图是数据加载的整体框架图,官网说 DataLoader 组合datasetsampler,多个 workers 根据 dataset 提供的数据副本sampler 提供的 keys 并行地加载数据,并通过 collate_fn 组成 batch 供用户迭代。需要注意的有:

  1. 每个 worker 持有数据的一个副本,故占用内存主线程内存 * num_workers”;
  2. 即使用户不提供 sampler 对象 (通常不提供),DataLoader 也会根据 shuffle 参数创建一个默认的 sampler 对象;一旦提供了,其前路的 shuffle 参数不能为 True (不提供就好);
  3. 即使用户不提供 batch_sampler 对象 (通常不提供),DataLoader 也会根据 batch_sampler, drop_last 参数创建一个默认的 batch_sampler 对象;一旦提供了,其前路的 shuffle, drop_last 不能为 Truebatch_size 必须为 1 1 1sampler 必须为 None,因为创建 BatchSampler 时已经有了这些参数;

    本质上是把创建 batch_sampler 的活拉出来由用户在 DataLoader 外自定义地做了。

Dataset

分为两种:map-styleiterable-style。前者的数据可通过 [idx or key] 访问,后者的数据只能通过迭代器 next 一个个访问。所以上面架构中的采样器是对于 map-style 数据集说的iterable-style 的数据集的访问顺序由迭代器决定。

Sampler

torch.utils.data.Sampler 的子类或 Iterable,两个例子:

class AccedingSequenceLengthSampler(tu_data.Sampler[int]):
	def __init__(self, data: List[str]) -> None:
		super().__init__()
		self.data = data

	def __len__(self) -> int:
		return len(self.data)

	def __iter__(self) -> Iterator[int]:
		"""
		:return: 实现了按数据长短顺序访问数据集
		"""
		sizes = torch.tensor([len(x) for x in self.data])
		yield from torch.argsort(sizes).tolist()


class AccedingSequenceLengthBatchSampler(tu_data.Sampler[List[int]]):
	def __init__(self, data: List[str], batch_size: int) -> None:
		super().__init__()
		self.data = data
		self.batch_size = batch_size

	def __len__(self) -> int:
		return (len(self.data) + self.batch_size - 1) // self.batch_size

	def __iter__(self) -> Iterator[List[int]]:
		sizes = torch.tensor([len(x) for x in self.data])
		for batch in torch.chunk(torch.argsort(sizes), len(self)):  # 按块遍历
			yield batch.tolist()

Batch

batch_sampler 提供一批下标,取得一批数据后由 collate_fn 将这批数据整合:

if collate_fn is None:
	if self._auto_collation:
		collate_fn = _utils.collate.default_collate
	else:  # self.batch_sampler is None: (batch_size is None) and (batch_sampler is None)
		collate_fn = _utils.collate.default_convert

分两种情况:

  • automatic batching is disabled:调用 default_convert 函数简单地将 NumPy arrays 转化为 PyTorch Tensor;
  • automatic batching is enabled:调用 default_collate 函数,转化会变得复杂一点:
from torch.utils import data as tu_data
import collections

# %% Example with a batch of `int`s:
tu_data.default_collate([0, 1, 2, 3])
# tensor([0, 1, 2, 3])

# %% Example with a batch of `str`s:
tu_data.default_collate(['a', 'b', 'c'])
# ['a', 'b', 'c']

# %% Example with `Map` inside the batch:
tu_data.default_collate([
	{'A': 0, 'B': 1},
	{'A': 100, 'B': 100}
])
# {'A': tensor([0, 100]), 'B': tensor([1, 100])}, 同 key 的合并了

# %% Example with `NamedTuple` inside the batch:
Point = collections.namedtuple('Point', ['x', 'y'])
tu_data.default_collate([Point(0, 0), Point(1, 1)])
# Point(x=tensor([0, 1]), y=tensor([0, 1])), 同 name 的合并了, 大概和 dict 一样吧

# %% Example with `Tuple` inside the batch:
tu_data.default_collate([(0, 1), (2, 3)])
# [tensor([0, 2]), tensor([1, 3])], 对 list 内部执行 collate

# %% Example with `List` inside the batch:
tu_data.default_collate([[0, 1], [2, 3]])  # [tensor([0, 2]), tensor([1, 3])], 对 list 内部执行 collate, 并没有变成二维 tensor

Multi-process Data Loading

dataset, collate_fn, and worker_init_fn are passed to each worker,大概能说明 batch 是在子进程内部合成的。

有一个需要注意的地方是内存增长问题,当 __get_item__(self, key) 访问数据时,由于 Python 对象的 refcount 机制,数据会不断地复制,从而内存爆炸。但这里说解决 number of workers * size of parent process 问题,就不追究了,反正尽量用 numpy 或 pytorch tensor 吧。
iterable-style datasets 的随机性

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1455194.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

GitKraken Create Repository and Clone不可点击

问题 GitKraken Create Repository and Clone不可点击 详细问题 笔者第一次使用GitKraken,在创建仓库时,填写完成仓库初始化后。发现Create Repository and Clone不可点击。 解决方案 选择Where to clone to位置 产生原因 在创建仓库时&#xff0…

IO 作业 24/2/18

1> 使用fgets统计给定文件的行数 #include <stdio.h> #include <stdlib.h> #include <string.h> int main(int argc, const char *argv[]) {//定义文件指针FILE *fpNULL;//打开文件&#xff08;只读&#xff09;if((fpfopen("./test.txt",&quo…

机器人内部传感器-位置传感器-电位器式位置传感器

位置传感器 位置感觉是机器人最基本的感觉要求&#xff0c;可以通过多种传感器来实现。位置传感器包括位置和角度检测传感器。常用的机器人位置传感器有电位器式、光电式、电感式、电容式、霍尔元件式、磁栅式及机械式位置传感器等。机器人各关节和连杆的运动定位精度要求、重…

suse15 sp3-sp5离线安装中安装FIO

没有网络的情况下&#xff0c;离线安装相对比较困难一点&#xff0c;所有需要提前下载相应的RPM安装包 FIO 安装包链接如下&#xff1a; Install package benchmark / fio 正常安装的时候&#xff0c;会出现问题 如下&#xff1a; google下 https://opensuse.pkgs.org/15.5/…

WorkPlus Meet助力企业建立安全可靠的私有化视频会议平台

企业需要保护敏感信息和保证会议质量的同时&#xff0c;提高会议的效率和协作水平。作为一款私有化视频会议软件系统&#xff0c;WorkPlus Meet以其卓越的性能和高度私密的特性&#xff0c;助力企业打造安全可靠的私有化视频会议平台。 为何选择WorkPlus Meet作为私有化视频会议…

软件实例分享,酒店酒水寄存管理系统软件教程

软件实例分享&#xff0c;酒店酒水寄存管理系统软件教程 一、前言 以下软件教程以 佳易王酒水寄存管理系统软件V16.0为例说明 软件文件下载可以点击最下方官网卡片——软件下载——试用版软件下载 1、寄存的商品名称可以预先设置 2、寄存人可以使用手.机号识别 3、会员充值…

flowpilot Pxiel 6 redmi K30 Pro

Installation flowdriveai/flowpilot Wiki GitHub Flowpilot can be installed on: Android phone Non-rooted running Android 10Android 11Android 12Rooted running Android 13 requires rootDesktop pc with Ubuntu > 20.04. 安装Termux https://f-droid.org/repo…

Sentinel从入门到“精通”,从源码层面学习Sentinel

B站视频讲解 文章目录 一、安装1、原生使用2、dashboard整合2-1、非starter整合2-1-1、公共2-1-2、Filter2-1-3、AOP2-2、starter 整合 3、总结 二、常见的策略1、限流1-1、基于QPS 限流1-2、基于线程数限流 2、降级2-1、慢调用比例2-2、异常数&#xff08;限流异常不算&#x…

《游戏引擎架构》--学习

内存管理 优化动态内存分配 维持最低限度的堆分配&#xff0c;并且永不在紧凑循环中使用堆分配 容器 迭代器 未完待续。。。

深度学习与计算机视觉 | 实用CV开源项目汇总(有github代码链接,建议收藏!)

本文来源公众号“深度学习与计算机视觉”&#xff0c;仅用于学术分享&#xff0c;侵权删&#xff0c;干货满满。 原文链接&#xff1a;【建议收藏】实用CV开源项目汇总&#xff08;文末有彩蛋~&#xff09; 01 Trace.moe 图像反向搜索动漫场景&#xff0c;使用动漫截图搜索该…

债券专题二:可转债估值-二叉树模型

1. 模型背景 由于可转债自身的属性较多&#xff0c;因此对其定价的难度也会加大&#xff0c;在诸多影响因素中&#xff0c;未来的股价占比最高。由于股价的不可预测性&#xff0c;导致了可转债的定价在实际交易中作用非常有限。随着可转债发行数量和规模的增大&#xff0c;越…

PHP支持的伪协议

php.ini参数设置 在php.ini里有两个重要的参数allow_url_fopen、allow_url_include。 allow_url_fopen:默认值是ON。允许url里的封装协议访问文件&#xff1b; allow_url_include:默认值是OFF。不允许包含url里的封装协议包含文件&#xff1b; 各协议的利用条件和方法 php:/…

notepad++打开文本文件乱码的解决办法

目录 第一步 在编码菜单栏下选择GB2312中文。如果已经选了忽略这一步 第二步 点击编码&#xff0c;红框圈出来的一个个试。我切换到UTF-8编码就正常了。 乱码如图。下面分享我的解决办法 第一步 在编码菜单栏下选择GB2312中文。如果已经选了忽略这一步 第二步 点击编码&#…

伦敦金和现货黄金是一回事吗?

想进入黄金市场的朋友&#xff0c;在网上一搜相关的讯息&#xff0c;可能就懵了。这个市场中好像有几个品种&#xff0c;又是伦敦金又是现货黄金什么的。很多新手投资者想知道&#xff0c;这些伦敦金、现货黄金分别是指什么&#xff0c;下面我们就来讨论一下。 实际上&#xff…

Open CASCADE学习|曲线向曲面投影

在三维空间中&#xff0c;将曲线向曲面投影通常涉及复杂的几何计算。这个过程可以通过多种方法实现&#xff0c;但最常见的是使用数学和几何库&#xff0c;如OpenCASCADE&#xff0c;来处理这些计算。 在OpenCASCADE中&#xff0c;投影曲线到曲面通常涉及以下步骤&#xff1a;…

Vue项目启动过程全记录(node.js运行环境搭建)

一、安装node.js并配置环境变量 1、安装node.js 从Node.js官网下载安装包并安装。然后在安装后的目录&#xff08;如果是下载的压缩文件&#xff0c;则是解压缩的目录&#xff09;下新建node_global和node_cache这两个文件夹。 node_global&#xff1a;npm全局安装位置 node_…

Python 字符串格式化输出

前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。点击跳转到网站零基础入门的AI学习网站~。 前言 字符串格式化是编程中一个常见的需求&#xff0c;它可以们将不同类型的数据&#xff08;如数字、文本、日…

【ansible】认识ansible,了解常用的模块

目录 一、ansible是什么&#xff1f; 二、ansible的特点&#xff1f; 三、ansible与其他运维工具的对比 四、ansible的环境部署 第一步&#xff1a;配置主机清单 第二步&#xff1a;完成密钥对免密登录 五、ansible基于命令行完成常用的模块学习 模块1&#xff1a;comma…

Shiro反弹shell和权限绕过含工具包

★★免责声明★★ 文章中涉及的程序(方法)可能带有攻击性&#xff0c;仅供安全研究与学习之用&#xff0c;读者将信息做其他用途&#xff0c;由Ta承担全部法律及连带责任&#xff0c;文章作者不承担任何法律及连带责任。 1、前言 反序列化漏洞原理和Shiro反序列化漏洞原理请参…

AI绘画图生图怎么用?

AI绘画图生图是指利用人工智能技术&#xff0c;将一张已有的图片转化为另一张具有艺术风格的新图片的过程。这种技术可以应用于多个领域&#xff0c;如室内设计等。 在使用AI绘画图生图功能时&#xff0c;用户需要选择一张参考图片&#xff0c;然后设置生成图片的风格、尺寸、数…