深入浅出PyTorch数据读取机制

news2025/1/9 5:24:21

熟悉深度学习的小伙伴一定都知道:深度学习模型训练主要由数据、模型、损失函数、优化器以及迭代训练五个模块组成。如下图所示,Pytorch数据读取机制则是数据模块中的主要分支。

在这里插入图片描述

Pytorch数据读取是通过​​Dataset​​​+​​Dataloader​​的方式完成。其中,

  • DataSet:定义数据集。将原始数据样本及对应标签映射到Dataset,便于后续通过index读取数据。同时,还可以在Dataset中进行数据格式变换、数据增强等预处理操作。
  • DataLoader:迭代读取数据集。将数据样本进行分批次Batch、打乱顺序Shuffle等处理,便于训练时迭代读取数据。

Dataset

Dataset用于解决数据从哪里读取以及如何读取的问题。 Pytorch给定的Dataset是一个抽象类,所有自定义的数据集都要继承Dataset,并重写**init()、getitem()和__len__()**类方法,以供DataLoader类直接调用。

  • init:数据集初始化。
  • getitem:定义指定索引如何获取样本数据,最终返回index对应的样本对{样本数据x:标签y}。
  • len():数据集的样本数。

下面是笔者以cifar10数据集为例实现Dataset自定义数据集的代码样例。

from torch.utils.data import Dataset
from PIL import Image
import os

class Mydata(Dataset):
    """
    步骤一:继承 torch.utils.data.Dataset 类
    """
    def __init__(self,data_dir,label_dir):
        """
        步骤二:实现 __init__ 函数,初始化数据集,将样本和标签映射到列表中
        """
        self.data_dir = data_dir
        self.label_dir = label_dir
        # 用join把路径拼接一起可以避免一些因“/”引发的错误
        self.path = os.path.join(self.data_dir,self.label_dir)
        # 将该路径下的所有文件变成一个列表
        self.img_path = os.listdir(self.path)

    def __getitem__(self,idx)
        """
        步骤三:实现 __getitem__ 函数,定义指定 index 时如何获取数据,并返回单条数据(样本数据、对应的标签)
        """
        # 根据index(idx),从列表中取出图片
        # img_path列表里每个元素就是对应图片文件名
        img_name = self.img_path[idx]
        # 获得对应图片路径
        img_item_path = os.path.join(self.data_dir,self.label_dir,img_name)
        # 使用PIL库下Image工具,打开对应路径图片
        img = Image.open(img_item_path)
        label = self.label_dir
        # 返回图片和对应标签
        return img,label

    def __len__(self):
        """
        步骤四:实现 __len__ 函数,返回数据集的样本总数
        """
        return len(self.img_path)

# data_dir,label_dir可自定义数据集目录
train_custom_dataset = MyData(data_dir,label_dir)
test_custom_dataset = MyData(data_dir,label_dir)

DataLoader

在实际项目中,当数据量很大,考虑到内存有限、I/O速度等问题,训练中不可能一次性将所有数据加载到内存或者只用一个进行加载数据,此时就需要的是多进程、迭代加载,Dataloader便应运而生。

DataLoader是一个可迭代的数据装载器,组合了数据集和采样器,并在给定数据集上提供可迭代对象。可以完成对数据集中多个对象的集成。

Pytorch的数据读取机制中DataLoader模块包括Sampler和Dataset两个子模块,其中Sampler模块生成索引index;Dataset模块是根据索引读取数据。DataLoader读取数据流程如下图所示。

在这里插入图片描述

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-UZqgYimv-1684309723395)(imgs/230424183501.png)]

  • DataLoader:进入DataLoader模块。
  • DataloaderIter:进入__iter__函数判断是否采用多进程,并进入相应的读取机制。
  • Sampler:通过采样,挑选每个Batchsize该读取的数据,并返回这些数据的index。
  • index:一个batchsize数据的索引。
  • DatasetFetcher:获取index对应的数据。
  • Dataset:调用dataset[idx]获取相应数据,并拼接成list。
  • getitem:Dataset的核心,用索引获取数据。
  • img,label:读取到的数据。
  • collate_fn:将读取的数据从list转为batch形式。
  • Batch Data:batch形式数据,第一个元素是图像,第二个元素是标签。

Pytorch中DataLoader类定义如下:

class torch.utils.data.DataLoader(
     """
     构建可迭代的数据装载器,训练时,每一个for循环,每一次迭代,
     从DataLoader中获取一个batch_size大小的数据
     """
     dataset,
     batch_size=1,
     shuffle=False,
     sampler=None,
     batch_sampler=None,
     num_workers=0,
     collate_fn=None,
     pin_memory=False,
     drop_last=False,
)
  • dataset:需要加载的数据集,Dataset对象。
  • batch_size:每批次读取样本数。例如batch_size=16表示每批次读取16个样本。
  • shuffle:每个epoch是否乱序。shuffle=True表示在取数据时打乱样本顺序,以减少过拟合发生的可能。
  • sampler:索引index。
  • batch_sampler:将返回一个索引的sampler进行包装,按照设定的batch_size返回一组索引。
  • num_workers:同步/异步读取数据。num_workers=0表示数据加载是同步的,在主进程中完成。num_workers的值设为大于0时,即开启多进程方式异步加载数据,可提升数据读取速度。
  • pin_memory:是否将数据拷贝到拷贝到临时缓冲区。
  • collate_fn:将多个样本组合在一起变成一个mini-batch,不指定该函数的话会调用Pytorch内部默认的函数。
  • drop_last:丢弃不完整的批次样本,drop_last=True表示当数据集样本数不能被batch_size整除时,则丢弃最后一个不完整的batch样本。

补充说明

Epoch:所有训练样本都已输入到模型中,称为一个epoch
Iteration:一批样本(batch_size)输入到模型中,称为一个Iteration。
Batchsize:一批样本的大小,称为Batchsize。用于决定一个epoch有多少个Iteration。

代码实现示例如下。

import torch
import torch.utils.data as Data

BATCH_SIZE = 5

x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)

# 将数据集转换为torch可识别的类型
torch_dataset = Data.TensorDataset(x, y)

loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

for epoch in range(3):
    for step, (batch_x, batch_y) in enumerate(loader):
        print('epoch', epoch,
              '| step:', step,
              '| batch_x', batch_x.numpy(),
              '| batch_y:', batch_y.numpy())

在这里插入图片描述

通过上述方法即可初始化一个数据读取器loader,用于加载训练数据集torch_dataset。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/536581.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SOME/IP中间件通信流程

本文根据文章《CAPL如何实现SOME/IP协议通信:SomeIP_IL.dll函数篇(超两万字详解)》内容,抽取总结出SOME/IP通信流程,正不正确的另说,目的是为了加深对SOME/IP中间件的理解。 首先,不管是消费方consumer,还是提供方provider,都有自己的someip中间件。本质上,它是一个…

vue3 cesium datav 可视化大屏

目录 0. 预览效果 1. 代码库包 2. 技术点 3. 一些注意事项(配置参数) 4. 相关代码详情 0. 预览效果 包含的功能: ① 地球按照一定速度自转 ② 修改加载的geojson面样式 ③ 添加 文字 标注! 1. 代码库包 直接采用vue-cli5 创建…

MySql从入门到精通

MySql介绍 MySQL 是最流行的关系型数据库管理系统,在 WEB 应用方面 MySQL 是最好的 RDBMS(Relational Database Management System:关系数据库管理系统)应用软件之一。 什么是数据库 数据库(Database)是按照数据结构来组织、存储…

oracle 闪回恢复

oracle 闪回恢复 闪回恢复区主要通过3个初始化参数来设置和管理: db_recovery_file_dest:指定闪回恢复区的位置 db_recovery_file_dest_size:指定闪回恢复区的可用空间大小 db_flashback_retention_target:指定数据库可以回退的时…

年近30 ,无情被辞,想给划水的兄弟提个醒

前几天,一个认识了好几年在大厂工作的程序员朋友,年近30了,却被大厂以“人员优化”的名义无情被辞,据他说,有一个月散伙饭都吃了好几顿…… 在很多企业,都有KPI考核,然后在此基础上还会弄个“末…

讲的太好了!!!————————Idea中的VM Options、Program Arguments、Environment Variable全解析

参数使用方式示例代码获取方式VM Options必须以 -D 、 -X 、 -XX 开头,每个参数用空格隔开 ,使用最多的就是 -Dkeyvalue-Dvm.keyVmKey -Dvm.key2VmKey2String key System.getProperty(“vm.key”); Program Arguments为我们传入main方法的字符串数组arg…

10-03 单元化架构设计

设计原则 透明 对开发者透明 在做实现时,不依赖于单元划分和部署对组件透明 在组件运行时,不感知其承载单元对数据透明 数据库并不知道为哪个单元提供服务 业务可分片 系统业务复杂度足够高系统可以按照某一维度进行切分系统数据必须可以被区分 业务…

【网络】交换机基本原理与配置

目录 🍁交换机工作原理 🍁交换机接口的双工模式 🍁交换机命令行模式 🍁交换机常见命令 🧧帮助命令 🧧常用命令介绍 🍁交换机的基本配置 🧧配置接口的双工模式及速率 🦐博…

knife4j生产环境资源屏蔽

问题描述 knife4j是目前比较主流的自动API文档生成工具,在生产环境使用的过程中,我们一般会屏蔽或者去除Swagger的文档口径,防止接口信息泄露,保证系统安全。 但是最近在开发过程中使用knife4j-spring-boot-starter 3.0.2过程中&…

dolphinscheduler使用impala shell执行sql

目录 一、背景 二、方法 1.impala shell -f 文件名 2.impala shell -q sql 一、背景 因为dolphinscheduler工具sql组件不支持impala数据源,只能折衷方法通过shell来执行impala sql。 二、方法 1.impala shell -f 文件名 操作步骤: 1).【资源中心】…

受邀参加【第七届】中国客户服务节

在AI浪潮的推动下,客户服务“智能化”是企业高质量发展的重要途径之一,目前人工智能、大数据、云计算等技术已广泛应用于全行业的客户服务场景中,一个全面、完善、稳定的智能通讯服务平台可助力实现企业智能化应用转型和升级。 讯鸿网络作为国…

嘉立创EDA原理图封装画错了怎么办

摘要:本文以贴片电阻封装由1206修改为0805为例,介绍一下封装修改的一种方法。 1.问题描述 设计原理图的时候,误将封装设计成为1206了,现在想把它改为0805封装。 2.修改封装的步骤 首先在原理图中,修改对应的电阻器件…

Flutter 3.10 适配之单例 Window 弃用,一起来了解 View.of 和 PlatformDispatcher

Flutter 3.10 发布之后,大家可能注意到,在它的 release note 里提了一句: Window singleton 相关将被弃用,并且这个改动是为了支持未来多窗口的相关实现。 所以这是一个为了支持多窗口的相关改进,多窗口更多是在 PC 场…

统计学习方法:序贯概率比检验SPRT

Sequential Probability Ratio Test 应用:制造过程中的质量控制和医学试验中的异常检测 1.theory/principal 区别(vs固定样本检验):在固定样本检验中,一定数量的观察结果被用来从两个或多个备选方案中选择一个假设。而SPRT则是一次检查一个…

AI“应用商店”来了!OpenAI首批70个ChatGPT Plugin最全梳理

OpenAI放出大招,本周将向所有ChatGPT Plus用户开放联网功能和众多插件本周将向所有ChatGPT Plus用户开放联网功能和众多插件,允许ChatGPT访问互联网并使用70个第三方插件。 本批第三方插件能够全方位覆盖衣食住行、社交、工作以及学习等日常所需&#x…

Electron自定义窗口

Electron标题栏隐藏和自定义 Electron应用自定义标题栏样式 标题栏样式允许隐藏浏览器窗口的大部分色彩,同时保持系统原生窗口控件完整无损,并可以在 BrowserWindow 的构造器中使用 titleBarStyle 选项来配置。 应用 hidden 标题栏样式的结果是隐藏标…

无线充+台灯专用PD诱骗芯片LDR6328S

近几年,日常生活中到处可以看到消费者使用支持Type-c接口的电子产品,如手机,笔记本,筋膜枪,蓝牙音箱等等。例如,像筋膜枪,蓝牙音箱,无人机,小风扇。 无线充台灯方案&…

librosa语音信号处理

librosa是一个非常强大的python语音信号处理的第三方库,本文参考的是librosa的官方文档,本文主要总结了一些重要,对我来说非常常用的功能。学会librosa后再也不用python去实现那些复杂的算法了,只需要一句语句就能轻松实现。 先总…

数字化时代,初创公司如何建设业财一体化

业财一体的关键是构建“业务活动跟财务活动之间的线上化链接”,财务可以通过线上支撑业务,业务活动数据可以通过线上高时效触达财务;从业务数据到财务数据,除了需要运营系统的支撑还需要会计引擎的实现,会计引擎将业务…

优秀的开发者,如何借助免费低代码平台实现数据采集?

采集和管理数据, 从未如此简单自然 一款免费的零代码产品‘敲敲云’,可以帮助每个人轻松创建表单,自由收集问卷样本、活动参与者名单、客户数据,原本几天的工作在 1 个小时内轻松搞定。 表单编辑器,让你和数据专家一样…