Pytorch中utils.data 与torchvision简介

news2025/1/24 10:51:23

Pytorch中utils.data 与torchvision简介

  • 1 数据处理工具概述
  • 2 utils.data简介
  • 3 torchvision简介
    • 3.1 transforms
    • 3.2 ImageFolder

1 数据处理工具概述

Pytorch涉及数据处理(数据装载、数据预处理、数据增强等)主要工具包及相互关系如下图所示,主要使用torch.utils.data 与 torchvision:
在这里插入图片描述

torch.utils.data工具包,它包括以下三个类:
(1)Dataset:是一个抽象类,其它数据集需要继承这个类,并且覆写其中的两个方法(getitemlen)。
(2)DataLoader:定义一个新的迭代器,实现批量(batch)读取,打乱数据(shuffle)并提供并行加速等功能。
(3)random_split:把数据集随机拆分为给定长度的非重叠新数据集。
(4)*sampler:多种采样函数。

可视化处理工具torchvision:Pytorch的一个视觉处理工具包,独立于Pytorch,需要另外安装,使用pip或conda安装即可,包含四个类:

(1)datasets:提供常用的数据集加载,设计上都是继承torch.utils.data.Dataset,主要包括MMIST、CIFAR10/100、ImageNet、COCO等。
(2)models:提供深度学习中各种经典的网络结构以及训练好的模型(如果选择pretrained=True),包括AlexNet, VGG系列、ResNet系列、Inception系列等。
(3)transforms:常用的数据预处理操作,主要包括对Tensor及PIL Image对象的操作。
(4)utils:含两个函数,一个是make_grid,它能将多张图片拼接在一个网格中;另一个是save_img,它能将Tensor保存成图片。

2 utils.data简介

utils.data包括DatasetDataLoader

  • torch.utils.data.Dataset:为抽象类。自定义数据集需要继承这个类,并实现两个函数。一个是__len__,另一个是__getitem__,前者提供数据的大小(size),后者通过给定索引获取数据和标签。

  • 由于__getitem__一次只能获取一个数据,所以通过torch.utils.data.DataLoader来定义一个新的迭代器,实现batch读取。

下面通过举例,来比较Dataset 和DataLoader

  • 1,导入相关模块
import torch
from torch.utils import data
import numpy as np
  • 2,定义获取数据集的类,该类继承基类Dataset,自定义一个数据集及对应标签。
class TestDataset(data.Dataset):#继承Dataset
    def __init__(self):
        self.Data=np.asarray([[1,2],[3,4],[2,1],[3,4],[4,5]])#一些由2维向量表示的数据集
        self.Label=np.asarray([0,1,0,1,2])#这是数据集对应的标签
 
    def __getitem__(self, index):
        #把numpy转换为Tensor
        txt=torch.from_numpy(self.Data[index])
        label=torch.tensor(self.Label[index])
        return txt,label 
 
    def __len__(self):
        return len(self.Data)
  • 3,获取数据集中数据
Test=TestDataset()
print(Test[2])  #相当于调用__getitem__(2)
print(Test.__len__())
 
#輸出:
#(tensor([2, 1]), tensor(0))
#5

上面使用Dataset的方式,每次只返回一个样本。如果希望批处理,同时还要shuffle和并行加速等操作,可选择DataLoader

data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    sampler=None,
    batch_sampler=None,
    num_workers=0,
    collate_fn=,
    pin_memory=False,
    drop_last=False,
    timeout=0,
    worker_init_fn=None,
)

主要参数说明:
 dataset: 加载的数据集;
 batch_size: 批大小;
 shuffle:是否将数据打乱;
 sampler:样本抽样
 num_workers:使用多进程加载的进程数,0代表不使用多进程;
 collate_fn:如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可;
 pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些;
 drop_last:dataset 中的数据个数可能不是 batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃。

test_loader = data.DataLoader(Test,batch_size=2,shuffle=False,num_workers=2)
for i,traindata in enumerate(test_loader):
    print('i:',i)
    Data,Label=traindata
    print('data:',Data)
    print('Label:',Label)

从这个结果可以看出,这是批量读取。我们可以像使用迭代器一样使用它,如对它进行循环操作。不过它不是迭代器,我们可以通过iter命令转换为迭代器。

一般用data.Dataset处理同一个目录下的数据。如果数据在不同目录下,不同目录代表不同类别(这种情况比较普遍),使用data.Dataset来处理就不很方便

不过,可以使用Pytorch另一种可视化数据处理工具(即torchvision)就非常方便,不但可以自动获取标签,还提供很多数据预处理、数据增强等转换函数。

3 torchvision简介

torchvision有4个功能模块,

  • model:
  • datasets:
  • transforms:如何使用transforms对源数据进行预处理、增强等
  • utils:

3.1 transforms

transforms提供了对PIL Image对象和Tensor对象的常用操作

(1)对PIL Image的常见操作如下:
 Scale/Resize: 调整尺寸,长宽比保持不变;
 CenterCrop、RandomCrop、RandomSizedCrop:裁剪图片,CenterCrop和RandomCrop在crop时是固定size,RandomResizedCrop则是random size的crop;
 Pad: 填充;
 ToTensor: 把一个取值范围是[0,255]的PIL.Image 转换成 Tensor。形状为(H,W,C)的numpy.ndarray,转换成形状为[C,H,W],取值范围是[0,1.0]的torch.FloatTensor。
 RandomHorizontalFlip:图像随机水平翻转,翻转概率为0.5;
 RandomVerticalFlip: 图像随机垂直翻转;
 ColorJitter: 修改亮度、对比度和饱和度。

(2)对Tensor的常见操作如下:
 Normalize: 标准化,即减均值,除以标准差;
 ToPILImage:将Tensor转为PIL Image。

如果要对数据集进行多个操作,可通过Compose将这些操作像管道一样拼接起来,类似于nn.Sequential。以下为示例代码

transforms.Compose([
    #将给定的 PIL.Image 进行中心切割,得到给定的 size,
    #size 可以是 tuple,(target_height, target_width)。
    #size 也可以是一个 Integer,在这种情况下,切出来的图片形状是正方形。            
    transforms.CenterCrop(10),
    #切割中心点的位置随机选取
    transforms.RandomCrop(20, padding=0),
    #把一个取值范围是 [0, 255] 的 PIL.Image 或者 shape 为 (H, W, C) 的 numpy.ndarray,
    #转换为形状为 (C, H, W),取值范围是 [0, 1] 的 torch.FloatTensor
    transforms.ToTensor(),
    #规范化到[-1,1]
    transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))
])

3.2 ImageFolder

当文件依据标签处于不同文件下时,如:
在这里插入图片描述

可以利用 torchvision.datasets.ImageFolder 来直接构造出 dataset,代码如下:

loader = datasets.ImageFolder(path)
loader = data.DataLoader(dataset)

ImageFolder 会将目录中的文件夹名自动转化成序列,那么DataLoader载入时,标签自动就是整数序列了。
下面我们利用ImageFolder读取不同目录下图片数据,然后使用transorms进行图像预处理,预处理有多个,我们用compose把这些操作拼接在一起。然后使用DataLoader加载。

对处理后的数据用torchvision.utils中的save_image保存为一个png格式文件,然后用Image.open打开该png文件,详细代码如下:

from torchvision import transforms, utils
from torchvision import datasets
import torch
import matplotlib.pyplot as plt 
%matplotlib inline
 
my_trans=transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])
train_data = datasets.ImageFolder('./data/torchvision_data', transform=my_trans)
train_loader = data.DataLoader(train_data,batch_size=8,shuffle=True,)
                                            
for i_batch, img in enumerate(train_loader):
    if i_batch == 0:
        print(img[1])
        fig = plt.figure()
        grid = utils.make_grid(img[0])
        plt.imshow(grid.numpy().transpose((1, 2, 0)))
        plt.show()
        utils.save_image(grid,'test01.png')
    break

其他功能模块待更新!
参考:python深度学习-基于pytorch

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/396898.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

文献阅读(48)—— 长序列time-series预测【Informer】

文献阅读(48)—— 长序列time-series预测【Informer】 文章目录文献阅读(48)—— 长序列time-series预测【Informer】先验知识/知识拓展文章结构文章方法1. 文章核心网络结构(1) 传统意义上的transformer应…

数据结构4——线性表3:线性表的链式结构

基本概念 ​ 链式存储结构用一组物理位置任意的存储单元来存放线性表的数据元素。 ​ 这组存储单元既可以是连续的又可以是不连续的甚至是零散分布在任意位置上的。所以链表中元素的逻辑次序和物理次序不一定相同。而正是因为这一点,所以我们要利用别的方法将这些…

Kafka消息中间件(Kafka与MQTT区别)

文章目录KafkaKafka重要原理Topic 主题Partition 分区Producer 生产者Consumer 消费者Broker 中间件Offset 偏移量Kafka与mqtt区别Kafka Kafka是一个分布式流处理平台,它可以快速地处理大量的数据流。Kafka的核心原理是基于发布/订阅模式的消息队列。Kafka允许多个…

C++基础——C++面向对象之重载与多态基础总结(函数重载、运算符重载、多态的使用)

【系列专栏】:博主结合工作实践输出的,解决实际问题的专栏,朋友们看过来! 《QT开发实战》 《嵌入式通用开发实战》 《从0到1学习嵌入式Linux开发》 《Android开发实战》 《实用硬件方案设计》 长期持续带来更多案例与技术文章分享…

MySQL8.0.16存储过程比5.7.22性能大幅下降

MySQL8.0.16存储过程比5.7.22性能大幅下降 1、背景 从5.7.22迁移数据库到8.0.16,发现存储过程执行性能大幅下降。原来在5版本上执行只需要3-5秒,到8版本上居然要达到上万秒。 5版本: call Calculation_Week() OK 时间: 3.122s 8版本&#x…

移动通信(16)信号检测

常见的信号检测算法一般包括以下几类检测算法:最优、线性和非线性。最优检测算法:最大似然算法线性检测算法:迫零检测算法和最小均方误差检测算法非线性检测算法:串行干扰消除检测算法球形译码检测算法属于一种次优检测算法&#…

凤凰游攻略

凤凰游攻略1 装备📦1.1 证件1.2 日常用品1.3 药品1.4 衣物1.5 洗漱用品2 交通🚗3 住宿🏠4 美食🍕5 拍照📷5.1 租苗族服5.1.1 单租服装5.1.2 服装化妆5.2 一条龙旅拍6 路线🗺️景点🏙️7 注意⚠️…

计算机网络的166个概念你知道几个 第十二部分

计算机网络安全安全通信的四大要素:机密性、保温完整性、端点鉴别和运行安全性。机密性:报文需要在一定程度上进行加密,用来防止窃听者截取报文。报文完整性:在报文传输过程中,需要确保报文的内容不会发生改变。端点鉴…

java StringBuilder 和 StringBuffer 万字详解(深度讲解)

StringBuffer类介绍和溯源StringBuffer类常用构造器和常用方法StringBuffer类 VS String类(重要)二者的本质区别(含内存图解)二者的相互转化StringBuilder类介绍和溯源StringBuilder类常用构造器和常用方法String类,St…

0308java基础-注解,反射

一,注解 1.什么是注解: Annotation是从jdk5.0开始引入的新技术作用: 不是程序本身,可以对程序作出解释可以被其他程序读取格式: 以注释名在代码中存在,还可以添加一些参数值SuppressWarnings(value"…

0103 MySQL06

1.事务 1.一个事务其实就是一个完整的业务逻辑 如:转账,从A账户向B账户转账10000,将A账户的钱减去10000(update),将B账户的钱加上10000(update),这就是一个完整的业务逻…

【Mybatis】| 如何创建MyBatis的工具类

目录🌟更多专栏请点击👇一、前言二、实现过程1. 创建一个ThreadLocal对象2. 初始化SqlSessionFactory3. 获取并存储sqlSession对象4. 关闭sqlSession对象三、 总代码🌟更多专栏请点击👇 专栏名字🔥Elasticsearch专栏e…

向2022年度商界木兰上榜女性致敬!

目录 信息来源: 2022年度商界木兰名单 简介 评选标准 动态 榜单 为你心中的2023商界女神投上一票 信息来源: 2022年度商界木兰榜公布 华为孟晚舟获商界木兰最高分 - 脉脉 【最具影响力女性】历届商界木兰榜单 中国最具影响力的30位商界女性名单…

基于Vue+Vue-cli+webpack搭建渐进式高可维护性前端实战项目

本文是专栏《手把手带你做一套毕业设计毕业设计》的实战第一篇,将从Vue脚手架安装开始,逐步带你搭建起一套管理系统所需的架构。当然,在默认安装完成之后,会对文件目录进行初步的细化拆分,以便后续功能迭代和维护所用。…

经典100道mysql的面试题

100道mysql的面试题 目录100道mysql的面试题1. MySQL 索引使用有哪些注意事项呢?索引哪些情况会失效索引不适合哪些场景索引的一些潜规则2. MySQL 遇到过死锁问题吗,你是如何解决的?3. 日常工作中你是怎么优化SQL的?4. 说说分库与…

字体反爬慢慢总结破解方式

什么是字体反爬 网页开发者自己创造一种字体,因为在字体中每个汉字都有其代号,那么以后再网页中不会直接显示这个文字的效果。而是显示其代号,因此即使获取了网页的文本内容。也只是获取到文字的代号,而不是文字本身。 简单来说&…

逻辑优化基础-shannon decomposition

1. 简介 在逻辑综合中,香农分解(Shannon decomposition)是一种常用的布尔函数分解方法。它将一个布尔函数分解为两个子函数的和,其中每个子函数包含一个布尔变量的取反和非取反的部分。 具体来说,假设对于一个布尔函…

Mysql 索引特点

承接上文Mysql Server原理简介聚簇索引、二级索引、联合索引分别具备什么样的特点?聚簇索引数据跟索引放在一起的叫聚簇索引;数据和索引分开存储的叫非聚簇索引;innodb存储引擎,数据和文件都放在ibd文件中,实际的数据是…

在教学中常被问到的几个vue3.x与typescript的问题,统一解答

在教学当中,学生在学习vue3.x时,常常会问到typescript和vue3.x之间的关系,感觉这两个技术总是绑在一起的,下面老赵来统一解答一下: 那学vue3.x,为什么要求也要掌握typescript Vue 3.x是一个使用TypeScript编…

「ML 实践篇」机器学习项目落地

文章目录1. 项目分析1. 框架问题2. 性能指标2. 获取数据1. 准备工作区2. 下载数据3. 查看数据4. 创建测试集3. 数据探索1. 地理位置可视化2. 寻找相关性3. 组合属性4. 数据准备1. 数据清理2. Scikit-Learn 的设计3. 处理文本、分类属性4. 自定义转换器5. 特征缩放6. 流水线5. 选…