Pytorch深度学习-----DataLoader的用法

news2025/1/8 5:51:42

系列文章目录

PyTorch深度学习——Anaconda和PyTorch安装
Pytorch深度学习-----数据模块Dataset类
Pytorch深度学习------TensorBoard的使用
Pytorch深度学习------Torchvision中Transforms的使用(ToTensor,Normalize,Resize ,Compose,RandomCrop)
Pytorch深度学习------torchvision中dataset数据集的使用(CIFAR10)


本文目录

  • 系列文章目录
  • 一、DataLoader是什么?
  • 二、使用步骤
    • 1.相关参数
    • 2.引入库
    • 3.创建数据(使用CIFAR10为例)
    • 4.创建DataLoader实例
    • 5.在Tensorboard中显示即完整代码如下


一、DataLoader是什么?

DataLoader是Pytorch中用来处理模型输入数据的一个工具类。组合了数据集(dataset) + 采样器(sampler),如果把Dataset比作一副扑克牌,则DataLoader就是每次手中处理的某一批扑克牌,然后每一批取多少张,总共能取多少批,用不用打乱顺序等,都可以在创建DataLoader时从参数自行设定。

二、使用步骤

1.相关参数

class DataLoader(Generic[T_co]):
    r"""
    Data loader. Combines a dataset and a sampler, and provides an iterable over
    the given dataset.

    The :class:`~torch.utils.data.DataLoader` supports both map-style and
    iterable-style datasets with single- or multi-process loading, customizing
    loading order and optional automatic batching (collation) and memory pinning.

    See :py:mod:`torch.utils.data` documentation page for more details.

    Args:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: ``1``).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: ``False``).
        sampler (Sampler or Iterable, optional): defines the strategy to draw
            samples from the dataset. Can be any ``Iterable`` with ``__len__``
            implemented. If specified, :attr:`shuffle` must not be specified.
        batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but
            returns a batch of indices at a time. Mutually exclusive with
            :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
            and :attr:`drop_last`.
        num_workers (int, optional): how many subprocesses to use for data
            loading. ``0`` means that the data will be loaded in the main process.
            (default: ``0``)
        collate_fn (Callable, optional): merges a list of samples to form a
            mini-batch of Tensor(s).  Used when using batched loading from a
            map-style dataset.
        pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
            into device/CUDA pinned memory before returning them.  If your data elements
            are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
            see the example below.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: ``False``)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: ``0``)
        worker_init_fn (Callable, optional): If not ``None``, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: ``None``)
        generator (torch.Generator, optional): If not ``None``, this RNG will be used
            by RandomSampler to generate random indexes and multiprocessing to generate
            `base_seed` for workers. (default: ``None``)
        prefetch_factor (int, optional, keyword-only arg): Number of batches loaded
            in advance by each worker. ``2`` means there will be a total of
            2 * num_workers batches prefetched across all workers. (default value depends
            on the set value for num_workers. If value of num_workers=0 default is ``None``.
            Otherwise if value of num_workers>0 default is ``2``).
        persistent_workers (bool, optional): If ``True``, the data loader will not shutdown
            the worker processes after a dataset has been consumed once. This allows to
            maintain the workers `Dataset` instances alive. (default: ``False``)
        pin_memory_device (str, optional): the data loader will copy Tensors
            into device pinned memory before returning them if pin_memory is set to true.

在上述中共有15个参数,我们常用的有如下5个参数

dataset (Dataset)– 表示要读取的数据集

batch_size (python:int, optional) – 表示每次从数据集中取多少个数据

shuffle (bool, optional) –表示是否为乱序取出

num_workers (python:int, optional) – 表示是否多进程读取数据(默认为0);

drop_last (bool, optional) – 表示当样本数不能被batchsize整除时(即总数据集/batch_size 不能除尽,有余数时),最后一批数据(余数)是否舍弃(default:
False)

pin_memory(bool, optional) - 如果为True会将数据放置到GPU上去(默认为false)

2.引入库

from torch.utils.data import DataLoader

3.创建数据(使用CIFAR10为例)

创建CIFAR10的测试集

test_set = torchvision.datasets.CIFAR10("dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)

4.创建DataLoader实例

# 创建DataLoader实例
test_loader = DataLoader(
    dataset=test_set, # 引入数据集
    batch_size=4, # 每次取4个数据
    shuffle=True, # 打乱顺序
    num_workers=0, # 非多进程
    drop_last=False # 最后数据(余数)不舍弃
)

几点解释
以此次一批数据为4为例
一个批次dataloader[0]就是
img0,target0 = dateset[0] . . . img3,target3 = dateset[3]
总共4个数据
故,
dataloader会将上面的img0……img3进行打包成imgs
target0……target3进行打包成target
如下小土堆的图所示
在这里插入图片描述

5.在Tensorboard中显示即完整代码如下

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# 准备测试集
test_set = torchvision.datasets.CIFAR10("dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)

# 创建test_loader实例
test_loader = DataLoader(
    dataset=test_set, # 引入数据集
    batch_size=4, # 每次取4个数据
    shuffle=True, # 打乱顺序
    num_workers=0, # 非多进程
    drop_last=False # 最后数据(余数)不舍弃
)

img,index = test_set[0]
print(img.shape) # 查看图片大小 torch.Size([3, 32, 32]) C h w,即三通道 32*32
print(index) # 查看图片标签
# 遍历test_loader
for data in test_loader:
    img,target = data
    print(img.shape) # 查看图片信息torch.Size([4, 3, 32, 32])表示一次4张图片,图片为3通道RGB,大小为32*32
    print(target)  # tensor([4, 9, 8, 8])表示4张图片的target
# 在tensorboard 中显示
writer = SummaryWriter("logs")
step = 0
for data in test_loader:
    img, target = data
    writer.add_images("test_loader",img,step)
    step = step+1
writer.close()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/796781.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Golang】基于OAuth2.0微信扫码实现客户端用户登录(原理+代码实现+视频讲解)

前言: 细心汇总,包括原理+配置+代码详细实现 文章目录 原理讲解什么是OAuth2.0解决方案授权码模式讲解认证流程Go语言实现微信扫码登录1. 内网穿透配置2. 微信测试账号申请3. 验证和微信服务器连接二维码生成回调地址测试原理讲解 什么是OAuth2.0 OAuth 2.0是一种授权协议,…

JavaScript学习 --消息摘要算法

消息摘要算法(也称哈希算法)是一种将任意大小的数据转换为一个固定大小的数据序列的算法。在JavaScript中,常见的消息摘要算法包括MD5、SHA-1、SHA-256等。它们适用于安全传输敏感数据、防篡改数据等场景。在本篇博客中,我们将介绍…

slurm/sbatch/srun 多步骤串行运行多个依赖性任务

在slurm系统下,有时候需要按步骤运行A、B、C三个任务,但是直接写在脚本里会同时提交,所以需要建立依赖关系。 错误做法: 搜索网上做法及slurm串行教程,做法多为如下,使用bash或python来按顺序/循环内来串…

顺序表详解

💓博主个人主页:不是笨小孩👀 ⏩专栏分类:数据结构与算法👀 🚚代码仓库:笨小孩的代码库👀 ⏩社区:不是笨小孩👀 🌹欢迎大家三连关注,一起学习,一起进步&#…

NetApp FAS2750 和 FAS2820:适用于分布式企业和从远程到核心的 FAS

NetApp FAS2750 和 FAS2820:适用于分布式企业和从远程到核心的 FAS 拥有分布式企业和多个办公位置的客户希望使用这些系统进行虚拟化,以及为大型 FAS 和 AFF 系统提供简单且经济高效的备份和灾难恢复。 为什么要从 NetApp FAS 系列中选择一个型号&…

LLM / Python - json 使用详解

目录 一.引言 二.json 方法 1.json.dumps 2.json.dump 3.json.loads 4.json.load 三.json 参数 1.ensure_ascii 2.allow_nan 3.indent 4.sortKeys 5.Other 四.LLM 数据构建 1.json 数据构建 2.Train.py 五.总结 一.引言 上文中我们介绍了 LLama2-Chinese 的简…

ipad手写笔有必要买原装吗?质量好苹果平板平替笔推荐

因为iPad平板的强大,使得很多人群都用上了iPad,而且还在不断的普及。不管是用于绘画或者学习记笔记,都非常好用,但要是用来看电视剧玩游戏就没那么有价值了。如果你不打算购买昂贵的苹果电容笔,或者只是为了记录&#…

“数字中华 点亮未来”中华线上客户节 盛大开幕

2023年是中华保险数字化转型落地之年,峥嵘37载,中华保险在数字化转型上已经涌现了一批彰显辨识度、具有影响力的应用成果。7月15日,中华保险围绕数字化转型之路开展以“数字中华 点亮未来”为主题的37周年线上客户节活动,倾力打造…

直播平台源码开发提高直播质量的关键:视频编码和解码技术

在互联网日益发展的今天,直播平台成为人们互联网生活的主力军,直播平台功能的多样化与智能化使我们的生活有了极大地改变,比如短视频功能,它让我们既可以随时随地去发布自己所拍摄到的东西让世界各地的用户看到,也能让…

融合正余弦和折射反向学习的北方苍鹰优化算法,与金鹰/蜣螂/白鲸/霜冰算法对比...

今天的主角是:融合正余弦和折射反向学习的北方苍鹰优化算法(SCNGO),算法由作者自行改进,目前应该没有文献这样做。 改进策略参照的上一期改进的麻雀优化算法,改进点如下: ①采用折射反向学习策略初始化北方苍鹰算法个体…

【字节跳动青训营】后端笔记整理-3 | Go语言工程实践之测试

**本文由博主本人整理自第六届字节跳动青训营(后端组),首发于稀土掘金:🔗Go语言工程实践之测试 | 青训营 目录 一、概述 1、回归测试 2、集成测试 3、单元测试 二、单元测试 1、流程 2、规则 3、单元测试的例…

AQS抽象同步队列核心原理

CLH自旋锁 JUC中显式锁基于AQS抽象队列同步器,而AQS是CLH锁的一个变种。队列头结点可以获得锁,其他节点排队等候。 在争夺锁激烈的情况下,为了减少CAS空自旋(CAS需要CPU进行内部通信保证缓存一致性造成流量过大引起总线风暴&…

【代码随想录day21】二叉搜索树中的众数

题目 给你一个含重复值的二叉搜索树(BST)的根节点 root ,找出并返回 BST 中的所有 众数(即,出现频率最高的元素)。 如果树中有不止一个众数,可以按 任意顺序 返回。 假定 BST 满足如下定义&am…

Git移除commit过的大文件

前言:在提交推送本地更改至仓库时,误将大文件给提交了,导致push时报错文件过大,因此需要将已经commit的大文件移除后再push 若已知要删除的文件或文件夹路径,则可以从第4步开始 1.对仓库进行gc操作 $ git gc 2.查询…

ThinkPHP 一对多关联

用一对多关联的前提 多的一方的数据库表有一的一方数据库表的外键。 举例,用户获取自己的所有文章 数据表结构如下 // 用户表 useruser_id - integer // 用户主键name - varchar // 用户名称// 文章表 articlearticle_id - integer // 文章主键title - varchar …

WSL2安装google chrome浏览器

一. 环境: Windows 11 Ubuntu-22.04 二. 安装google-chrome步骤(官方文档): 1. 创建文件夹:mkdir chrome 2. 进入目录:cd chrome/ 3. 下载chrome压缩包:sudo wget https://dl.google.com/linux/direct/go…

学习 NestJs 的第一步

安装 NestJS 的先决条件和安装 NestJS NodeJS 的版本需要大于等于 16。 安装 NestJS 的命令是&#xff1a;npm i -g nestjs/cli。 使用命令创建项目 使用 nest new <项目名称> 来创建项目&#xff0c;假如要开启 TS 的严格语法功能的话&#xff0c;可以把--strict 标…

【雕爷学编程】Arduino动手做(93)--- 0.96寸OLED液晶屏模块15

37款传感器与执行器的提法&#xff0c;在网络上广泛流传&#xff0c;其实Arduino能够兼容的传感器模块肯定是不止这37种的。鉴于本人手头积累了一些传感器和执行器模块&#xff0c;依照实践出真知&#xff08;一定要动手做&#xff09;的理念&#xff0c;以学习和交流为目的&am…

OSPF的拓展配置

OSPF的拓展配置 1.手工认证 --- 在OSPF数据包交互中&#xff0c;邻居之间的数据报中将携带认证口令&#xff0c;两边认证口令相同&#xff0c;则意味着身份合法 OSPF的手工认证总共分为三种&#xff1a; 1.接口认证 [r5-GigabitEthernet0/0/0]ospf authenticati…

GB/T 25000.51解读——软件产品的性能效率怎么测?

GB/T 25000.51-2016《软件产品质量要求和测试细则》是申请软件检测CNAS认可一定会用到的一部国家标准。在前面的文章中&#xff0c;我们为大家整体介绍了GB/T 25000.51-2016《软件产品质量要求和测试细则》国家标准的结构和所涵盖的内容以及对软件产品的八大质量特性中的功能性…