PyTorch~自定义数据读取

news2025/1/11 21:07:35

这次是PyTorch的自定义数据读取pipeline模板和相关trciks以及如何优化数据读取的pipeline等。

因为有torch也放人工智能模块了~

从PyTorch的数据对象类Dataset开始。Dataset在PyTorch中的模块位于utils.data下。

from torch.utils.data import Dataset

围绕Dataset对象分别从原始模板、torchvision的transforms模块、使用pandas来辅助读取、torch内置数据划分功能和DataLoader来展开阐述。

Dataset原始模板

PyTorch官方为我们提供了自定义数据读取的标准化代码代码模块,作为一个读取框架,我们这里称之为原始模板。其代码结构如下:

from torch.utils.data import Dataset
class CustomDataset(Dataset):
    def __init__(self, ...):
        # stuff
        
    def __getitem__(self, index):
        # stuff
        return (img, label)
        
    def __len__(self):
        # return examples size
        return count

根据这个标准化的代码模板,我们只需要根据自己的数据读取任务,分别往__init__()、__getitem__()和__len__()三个方法里添加读取逻辑即可。作为PyTorch范式下的数据读取以及为了后续的data loader,三个方法缺一不可。其中:

  • __init__()函数用于初始化数据读取逻辑,比如读取包含标签和图片地址的csv文件、定义transform组合等。

  • __getitem__()函数用来返回数据和标签。目的上是为了能够被后续的dataloader所调用。

  • __len__()函数则用于返回样本数量。

现在我们往这个框架里填几行代码来形成一个简单的数字案例。创建一个从1到100的数字例子:

from torch.utils.data import Dataset
class CustomDataset(Dataset):
    def __init__(self):
        self.samples = list(range(1, 101))
    def __len__(self):
        return len(self.samples)
    def __getitem__(self, idx):
        return self.samples[idx]
        
if __name__ == '__main__':
    dataset = CustomDataset()
    print(len(dataset))
    print(dataset[50])
    print(dataset[1:100])

 添加torchvision.transforms

然后我们来看如何从内存中读取数据以及如何在读取过程中嵌入torchvision中的transforms功能。torchvision是一个独立于torch的关于数据、模型和一些图像增强操作的辅助库。主要包括datasets默认数据集模块、models经典模型模块、transforms图像增强模块以及utils模块等。在使用torch读取数据的时候,一般会搭配上transforms模块对数据进行一些处理和增强工作。

添加了tranforms之后的读取模块可以改写为:

from torch.utils.data import Dataset
from torchvision import transforms as T

class CustomDataset(Dataset):
    def __init__(self, ...):
        # stuff
        ...
        # compose the transforms methods
        self.transform = T.Compose([T.CenterCrop(100),
                                T.ToTensor()])
        
    def __getitem__(self, index):
        # stuff
        ...
        data = # Some data read from a file or image
        # execute the transform
        data = self.transform(data)  
        return (img, label)
        
    def __len__(self):
        # return examples size
        return count
        
if __name__ == '__main__':
    # Call the dataset
    custom_dataset = CustomDataset(...)

可以看到,我们使用了Compose方法来把各种数据处理方法聚合到一起进行定义数据转换方法。通常作为初始化方法放在__init__()函数下。我们以猫狗图像数据为例进行说明。

定义数据读取方法如下:

class DogCat(Dataset):    
    def __init__(self, root, transforms=None, train=True, val=False):
        """
        get images and execute transforms.
        """
        self.val = val
        imgs = [os.path.join(root, img) for img in os.listdir(root)]
        # train: Cats_Dogs/trainset/cat.1.jpg
        # val: Cats_Dogs/valset/cat.10004.jpg
        imgs = sorted(imgs, key=lambda x: x.split('.')[-2])
        self.imgs = imgs         
        if transforms is None:
            # normalize      
            normalize = T.Normalize(mean = [0.485, 0.456, 0.406],
                                     std = [0.229, 0.224, 0.225])
            # trainset and valset have different data transform 
            # trainset need data augmentation but valset don't.
            # valset

            if self.val:
                self.transforms = T.Compose([
                    T.Resize(224),
                    T.CenterCrop(224),
                    T.ToTensor(),
                    normalize
                ])
            # trainset
            else:
                self.transforms = T.Compose([
                    T.Resize(256),
                    T.RandomResizedCrop(224),
                    T.RandomHorizontalFlip(),
                    T.ToTensor(),
                    normalize
                ])
                       
    def __getitem__(self, index):
        """
        return data and label
        """
        img_path = self.imgs[index]
        label = 1 if 'dog' in img_path.split('/')[-1] else 0
        data = Image.open(img_path)
        data = self.transforms(data)
        return data, label
  
    def __len__(self):
        """
        return images size.
        """
        return len(self.imgs)

if __name__ == "__main__":
    train_dataset = DogCat('./Cats_Dogs/trainset/', train=True)
    print(len(train_dataset))
    print(train_dataset[0])

因为这个数据集已经分好了训练集和验证集,所以在读取和transforms的时候需要进行区分。运行示例如下:

与pandas一起使用

很多时候数据的目录地址和标签都是通过csv文件给出的。如下所示:

此时在数据读取的pipeline中我们需要在__init__()方法中利用pandas把csv文件中包含的图片地址和标签融合进去。相应的数据读取pipeline模板可以改写为:

class CustomDatasetFromCSV(Dataset):
    def __init__(self, csv_path):
        """
        Args:
            csv_path (string): path to csv file
            transform: pytorch transforms for transforms and tensor conversion
        """
        # Transforms
        self.to_tensor = transforms.ToTensor()
        # Read the csv file
        self.data_info = pd.read_csv(csv_path, header=None)
        # First column contains the image paths
        self.image_arr = np.asarray(self.data_info.iloc[:, 0])
        # Second column is the labels
        self.label_arr = np.asarray(self.data_info.iloc[:, 1])
        # Calculate len
        self.data_len = len(self.data_info.index)

    def __getitem__(self, index):
        # Get image name from the pandas df
        single_image_name = self.image_arr[index]
        # Open image
        img_as_img = Image.open(single_image_name)
        # Transform image to tensor
        img_as_tensor = self.to_tensor(img_as_img)
        # Get label of the image based on the cropped pandas column
        single_image_label = self.label_arr[index]
        return (img_as_tensor, single_image_label)

    def __len__(self):
        return self.data_len

if __name__ == "__main__":
    # Call dataset
    dataset =  CustomDatasetFromCSV('./labels.csv')

以mnist_label.csv文件为示例:

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms as T
from PIL import Image
import os
import numpy as np
import pandas as pd

class CustomDatasetFromCSV(Dataset):
    def __init__(self, csv_path):
        """
        Args:
            csv_path (string): path to csv file            
            transform: pytorch transforms for transforms and tensor conversion
        """
        # Transforms
        self.to_tensor = T.ToTensor()
        # Read the csv file
        self.data_info = pd.read_csv(csv_path, header=None)
        # First column contains the image paths
        self.image_arr = np.asarray(self.data_info.iloc[:, 0])
        # Second column is the labels
        self.label_arr = np.asarray(self.data_info.iloc[:, 1])
        # Third column is for an operation indicator
        self.operation_arr = np.asarray(self.data_info.iloc[:, 2])
        # Calculate len
        self.data_len = len(self.data_info.index)

    def __getitem__(self, index):
        # Get image name from the pandas df
        single_image_name = self.image_arr[index]
        # Open image
        img_as_img = Image.open(single_image_name)
        # Check if there is an operation
        some_operation = self.operation_arr[index]
        # If there is an operation
        if some_operation:
            # Do some operation on image
            # ...
            # ...
            pass

        # Transform image to tensor
        img_as_tensor = self.to_tensor(img_as_img)
        # Get label of the image based on the cropped pandas column
        single_image_label = self.label_arr[index]
        return (img_as_tensor, single_image_label)

    def __len__(self):
        return self.data_len

if __name__ == "__main__":
    transform = T.Compose([T.ToTensor()])
    dataset = CustomDatasetFromCSV('./mnist_labels.csv')
    print(len(dataset))
    print(dataset[5])

运行示例如下:

训练集验证集划分

一般来说,为了模型训练的稳定,我们需要对数据划分训练集和验证集。torch的Dataset对象也提供了random_split函数作为数据划分工具,且划分结果可直接供后续的DataLoader使用。

以kaggle的花朵数据为例:  whaosoft aiot http://143ai.com

from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms as T
from torch.utils.data import random_split

transform = T.Compose([
    T.Resize((224, 224)),
    T.RandomHorizontalFlip(),
    T.ToTensor()
 ])

dataset = ImageFolder('./flowers_photos', transform=transform)
print(dataset.class_to_idx)

trainset, valset = random_split(dataset, 
                [int(len(dataset)*0.7), len(dataset)-int(len(dataset)*0.7)])

trainloader = DataLoader(dataset=trainset, batch_size=32, shuffle=True, num_workers=1)
for i, (img, label) in enumerate(trainloader):
    img, label = img.numpy(), label.numpy()
    print(img, label)

valloader = DataLoader(dataset=valset, batch_size=32, shuffle=True, num_workers=1)
for i, (img, label) in enumerate(trainloader):
    img, label = img.numpy(), label.numpy()
    print(img.shape, label)

这里使用了ImageFolder模块,可以直接读取各标签对应的文件夹,部分运行示例如下: 

使用DataLoader

dataset方法写好之后,我们还需要使用DataLoader将其逐个喂给模型。上一节的数据划分我们已经用到了DataLoader函数。从本质上来讲,DataLoader只是调用了__getitem__()方法并按批次返回数据和标签。使用方法如下:

from torch.utils.data import DataLoader
from torchvision import transforms as T

if __name__ == "__main__":
    # Define transforms
    transformations = T.Compose([T.ToTensor()])
    # Define custom dataset
    dataset = CustomDatasetFromCSV('./labels.csv')
    # Define data loader
    data_loader = DataLoader(dataset=dataset, batch_size=10, shuffle=True)
    for images, labels in data_loader:
        # Feed the data to the model

以上就是PyTorch读取数据的Pipeline主要方法和流程。基于Dataset对象的基本框架不变,具体细节可自定义化调整。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/73045.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

前端入门必备基础

化繁为简 HTML5要的就是简单、避免不必要的复杂性。HTML5的口号是“简单至上,尽可能简化”。因此,HTML5做了以下改进: 以浏览器原生能力替代复杂的JavaScript代码。 新的简化的DOCTYPE。 新的简化的字符集声明。 简单而强大的HTML5API。…

[附源码]Python计算机毕业设计SSM基于云数据库的便民民宿租赁系统(程序+LW)

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

《Linux运维实战:MongoDB数据库全量逻辑备份恢复(方案一)》

一、备份与恢复方案 mongodump是MongoDB官方提供的备份工具,它可以从MongoDB数据库读取数据,并生成BSON文件,mongodump适合用于备份和恢复数据量较小的MongoDB数据库, 不适用于大数据量备份。 默认情况下mongodump不获取local数据库里面的内容。mongodump仅备份数据库中的文档&…

回溯算法(1)组合

文章目录回溯算法理论77. 组合216. 组合总和17. 电话号码的组合回溯算法理论 回溯算法其实就是递归,只不过递归又分为递去和归来,其中归来便就是回溯。 为什么要使用回溯? 有些问题我们通过暴力解法也很难解决,比如说我们接下来…

C语言学习之路(高级篇)—— 变量和内存分布(上)

说明:该篇博客是博主一字一码编写的,实属不易,请尊重原创,谢谢大家! 数据类型 1) 数据类型概念 什么是数据类型?为什么需要数据类型? 数据类型是为了更好进行内存的管理,让编译器能确定分配…

04 | 云硬盘的使用方法

前期环境: Ubuntu 0 云硬盘类型 云硬盘类型包括: 高性能云硬盘通用型 SSD 云硬盘SSD 云硬盘增强型 SSD 云硬盘极速型 SSD 云硬盘,仅支持随存储增强型云服务器一同购买,不支持单独购买 1 创建云硬盘 1.1 创建方式 1.1.1 单个…

第二证券|连拉20CM涨停!防疫新概念股火了!恒生科技指数涨逾5%

周四上午,“新十条”发布后,由于A股商场已反弹一段时刻,两市股指今天早盘接连震动走势,港股在地产、科技、消费等板块带动下,体现更为强势。 A股上证指数早盘在3200点附近持续震动,光伏、化肥、物流、港口等…

JavaScript内置对象(内置对象、查文档(MDN)、Math对象、日期对象、数组对象、字符串对象)

目录 JavaScript内置对象 内置对象 查文档 MDN Math对象 Math概述 案例一:封装自己的对象 随机数方法 random() 案例一:猜数字游戏 日期对象 Date 概述 Date()方法的使用 获取日期的总的毫秒形式 案例一:倒计时效果 数组对象 …

DoltLab本地部署实践

目录引言Dolt是什么?如何本地部署使用DoltLab具体安装步骤安装期间FAQ写在最后其他相关资料引言 自从搞深度学习训练模型以来,一直有个问题困扰着我:训练所用数据集的管理。为什么说这是一个问题呢? 在读研时,我们依据…

ELK日志分析系统概述及部署

文章目录一、ELK日志分析系统1、概念2、完整日志系统基本特征3、使用ELK的原因4、ELK 的工作原理二、ELK日志分析系统集群部署的操作步骤环境准备:1、 ELK Elasticsearch 集群部署(在Node1、Node2节点上操作)1.1、更改主机名、配置域名解析、…

剑指 Offer 53 - I. 在排序数组中查找数字 I

摘要 剑指 Offer 53 - I. 在排序数组中查找数字 I 一、二分查找 1.1 二分查找的分析 由于数组已经排序,因此整个数组是单调递增的,我们可以利用二分法来加速查找的过程。 考虑 target在数组中出现的次数,其实我们要找的就是数组中「第一…

汇编语言ch2_2 汇编语言中的debug

使用debug 可以完成以下功能: 可以查看 和改变 CPU 中,寄存器的内容;可以查看 和改变内存中的内容;可以将内存中的 机器指令 翻译成汇编指令使用汇编指令 在 内存中 存入 机器指令执行机器指令 首先,启动 Debug,在DO…

实现数智内控,数据分析创造价值——辽宁烟草智能风险体检系统

近两年,烟草行业部分单位围绕中心任务,结合实际,守正创新,开展了许多研究探索。比如,在财务大数据价值挖掘、会计共享中心建设、财务风险预警系统建设等方面做了大量卓有成效的工作。在这样的背景下,辽宁烟…

DSPE-MAL 磷脂改性马来酰亚胺简介CAS1360858-99-6

DSPE-MAL二硬脂酰磷脂酰乙醇胺改性马来酰亚胺 中文名称:二硬脂酰磷脂酰乙醇胺改性马来酰亚胺 英文名称:DSPE-MAL CAS:1235864-97-7 分子式:C48H86N2NaO11P 分子量:921.16700 外观:白色粉末 DSPE-MAL二…

2022icpc 济南站 持续补题

链接:Dashboard - 2022 International Collegiate Programming Contest, Jinan Site - Codeforces 签到题:k K. Stack Sort You are given a permutation with nn numbers, a1,a2,…,an(1≤ai≤n,ai≠aj when i≠j). You want to sort these numbers …

WY易盾cb、fp逆向分析

内容仅供参考学习 欢迎朋友们V一起交流: zcxl7_7 目标 网址:案例地址 这个好像还没改版,我看官网体验那边已经进行了混淆 分析 这个进行的请求很乱,我就不说怎么找的了,到时候越听越乱。一共有2个请求很重要 …

笔试题之编写SQL按要求查询用户阅读行为数据

紧张源于恐惧,恐惧源于未知。 文章目录前言一、SQL题目二、当时作答结果三、复盘(一)建表并自定义插入数据(二)正确解答(三)答错原因分析总结前言 分享本人一次失败的笔试经历,供各…

plink中的BGEN格式的数据如何用

这里,介绍一下BGEN格式的数据,他的文件格式是这样的:a.bgen,这是一个新的数据格式,目前应用不如plink的二进制文件:.bim,.bed,.fam。这里介绍一下如何相互转换。 1. bgen格式介绍 现代遗传关联研究通常使…

[附源码]计算机毕业设计JAVA中小企业人事管理系统

[附源码]计算机毕业设计JAVA中小企业人事管理系统 项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM my…

HMS Core 6.8.0版本发布公告

分析服务 ◆ 游戏行业新增“区服分析”埋点模板及分析报告,支持开发者分服务器查看用户付费、留存等指标,可进一步评估不同服务器的玩家质量; ◆ 新增营销活动报告,可查看广告任务带来的曝光、点击相关信息,让营销推…