Pytorch深度学习实战3-7:详解数据加载DataLoader与模型处理

news2024/12/28 21:53:00

目录

  • 1 数据集Dataset
  • 2 数据加载DataLoader
  • 3 常用预处理方法
  • 4 模型处理
  • 5 实例:MNIST数据集处理

1 数据集Dataset

Dataset类是Pytorch中图像数据集操作的核心类,Pytorch中所有数据集加载类都继承自Dataset父类。当我们自定义数据集处理时,必须实现Dataset类中的三个接口:

  • 初始化
    def __init__(self)
    
    构造函数,定义一些数据集的公有属性,如数据集下载地址、名称等
  • 数据集大小
    def __len__(self)
    
    返回数据集大小,不同的数据集有不同的衡量数据量的方式
  • 数据集索引
    def __getitem__(self, index):
    
    支持数据集索引功能,以实现形如dataset[i]得到数据集中的第i + 1个数据的功能。__getitem__是后期迭代数据时执行的具体函数,其返回值决定了循环变量,例如
    class data(Dataset)
    	...
        def __getitem__(self, idx: int):
            if self.transforms:
                img = self.transforms(img)
            return img, label			# 返回的值即为后续迭代的循环变量
    
    for images, labels in dataLoader:
    	...
    

2 数据加载DataLoader

为什么有了数据集Dataset还需要数据加载器DataLoader呢?原因在于神经网络需要进一步借助DataLoader对数据进行划分,也就是我们常说的batch,此外DataLoader还实现了打乱数据集、多线程等操作。

DataLoader本质是一个可迭代对象,可以使用形如

for inputs, labels in dataloaders

进行可迭代对象的访问。

我们一般不需要去实现DataLoader的接口,只需要在构造函数中指定相应的参数即可,比如常见的batch_sizeshuffle等参数。

下面这张图非常好地说明了DatasetDataLoader的关系

在这里插入图片描述

接下来总结数据构造的三步法

  1. 继承Dataset对象,并实现__len__()__getitem__()魔法方法,该步骤的主要目的在于将文件形式的数据集处理为模型可用的标准数据格式,并加载到内存中;
  2. DataLoader对象封装Dataset,使其成为可迭代对象;
  3. 遍历DataLoader对象以将数据加载到模型中进行训练。

3 常用预处理方法

在数据集Dataset__getitem__()中利用torchvision.transforms进行数据预处理与变换

常见的数据预处理变换方法总结如下表

序号变换含义
1RandomCrop(size, ...)对输入图像依据给定size随机裁剪
2CenterCrop(size, ...)对输入图像依据给定size从中心裁剪
3RandomResizedCrop(size, ...)对输入图像随机长宽比裁剪,再放缩到给定size
4FiveCrop(size, ...)对输入图像进行上下左右及中心裁剪,返回五张图像(size)组成的四维张量
5TenCrop(size, vertical_flip=False)对输入图像进行上下左右及中心裁剪,再全部翻转(水平或垂直),返回十张图像(size)组成的四维张量
6RandomHorizontalFlip(p=0.5)对输入图像按概率p随机进行水平翻转
7RandomVerticalFlip(p=0.5)对输入图像按概率p随机进行垂直翻转
8RandomRotation(degree, ...)对输入图像在degree内随机旋转某角度
9Resize(size, ...)对输入图像重置分辨率
10Normalize(mean, std)对输入图像各通道进行标准化
11ToTensor()将输入图像或ndarray 转换为tensor并归一化
12Pad(padding, fill=0, padding_mode=‘constant’)对输入图像进行填充
13ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)对输入图像修改亮度、对比度、饱和度、色度等
14Grayscale(num_output_channels=1)对输入图像转灰度
15LinearTransformation(matrix)对输入图像进行线性变换
16RandomAffine(...)对输入图像进行仿射变换
17RandomGrayscale(p=0.1)对输入图像按概率p随机转灰度
18ToPILImage(mode=None)对输入图像转PIL格式图像
19RandomOrder()随机打乱transforms操作顺序

4 模型处理

考虑以下场景:

网络的部分层级结构已经收敛、无需调整;大型复杂网络需要微调(Fine-tune)某些结构或参数;希望基于已训练好的模型进行改善或其他研究工作。

这些场景下重新通过数据集训练整个神经网络并无必要,甚至会使模型不稳定,因此引入预训练(pretrained)。Pytorch允许用户保存已训练好的模型,或加载其他模型,避免往复的无谓重训练,其中模型参数文件以.pth为后缀

# 保存已训练模型
torch.save(model.state_dict(), path)
# 加载预训练模型
model.load_state_dict(torch.load(path), device)

通过设置模型某些层可学习参数的requires_grad属性为False即可固定这部分参数不被后续学习过程影响。深度学习框架应用优势之一在于预设了对GPU的支持,大大提高模型处理与训练的效率。Pytorch中通过mode.to(device)方法将模型部署到指定设备上(CPU/GPU),范式如下:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

工程上也常使用torch.nn.DataParallel(model, devices)来处理多GPU并行运算,其原理是:首先将模型加载到主GPU上,再将模型从主GPU产生若干副本到其余GPU,随后将一个batch中的数据按维度划分为不同的子任务给各GPU进行前向传播,得到的损失会被累积到主GPU上并由主GPU反向传播更新参数,最后将更新参数拷贝到其余GPU以开始下一轮训练。

5 实例:MNIST数据集处理

下面给出了处理MNIST手写数据集的完整代码,可以用于加深对数据处理流程的理解

from abc import abstractmethod
import numpy as np
from torchvision.datasets import mnist
from torch.utils.data import Dataset
from PIL import Image

class mnistData(Dataset):
    '''
    * @breif: MNIST数据集抽象接口
    * @param[in]: dataPath -> 数据集存放路径
    * @param[in]: transforms -> 数据集变换
    '''    
    def __init__(self, dataPath: str, transforms=None) -> None:
        super().__init__()
        self.dataPath = dataPath
        self.transforms = transforms
        self.data, self.label = [], []

    def __len__(self) -> int:
        return len(self.label)

    def __getitem__(self, idx: int):
        img = self.data[idx]
        if self.transforms:
            img = self.transforms(img)
        return img, self.label[idx]

    @abstractmethod
    def plot(self, index: int) -> None:
        pass

    @abstractmethod
    def load(self) -> list:
        pass

    def plotData(self, index: int, info: str=None) -> None:
        '''
        * @breif: 可视化训练数据
        * @param[in]: index -> 数据集索引
        * @param[in]: info -> 备注信息
        * @retval: None
        '''
        print(info, " --index:", index, "--label:", self.label[index])  if info else \
        print(" --index:", index, "--label:", self.label[index])          
        img = Image.fromarray(np.uint8(self.data[index]))
        img.show()

    def loadData(self, train: bool) -> list:
        '''
        * @breif: 下载与加载数据集
        * @param[in]: train -> 是否为训练集
        * @retval: 数据与标签列表
        '''    
        # 如果指定目录下不存在数据集则下载
        dataSet   = mnist.MNIST(self.dataPath, train=train, download=True)
        # 初始化数据与标签
        data  = [ i[0] for i in dataSet ]
        label = [ i[1] for i in dataSet ]
        return data, label

class mnistTrainData(mnistData):
    '''
    * @breif: MNIST训练集
    * @param[in]: dataPath -> 数据集存放路径
    * @param[in]: transforms -> 数据集变换
    '''    
    def __init__(self, dataPath: str, transforms=None) -> None:
        super().__init__(dataPath, transforms=transforms)
        self.data, self.label = self.load()

    def plot(self, index: int) -> None:
        self.plotData(index, "trainSet data")

    def load(self) -> list:
        return self.loadData(train=True)


class mnistTestData(mnistData):
    '''
    * @breif: MNIST测试集
    * @param[in]: dataPath -> 数据集存放路径
    * @param[in]: transforms -> 数据集变换
    '''    
    def __init__(self, dataPath: str, transforms=None) -> None:
        super().__init__(dataPath, transforms=transforms)
        self.data, self.label = self.load()

    def plot(self, index: int) -> None:
        self.plotData(index, "testSet data")

    def load(self) -> list:
        return self.loadData(train=False)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/424595.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

从TOP25榜单,看半导体之变

据SIA报告显示,2022年全球半导体销售额创历史新高达到5740亿美元。尽管2022年下半年,半导体市场出现了周期性的低迷,但其全年的销售额相较2021年增长了3.3%。 近日,市调机构Gartner发布了全球以及中国大陆TOP25名半导体厂商的排名…

js数组API的时间复杂度大全

一句话总结: 数组为连续且有序的数据结构, 所以若根据下标查找则很快,index[i]一步到位就可实现查询,若遍历查找则很慢(相对而言)。而插入和删除,除了数组末尾的增删很快,其它处则很慢,因为若数组某处要插入…

【服务器数据恢复】 重装系统导致xfs文件系统分区丢失的数据恢复案例

服务器数据恢复环境: EMC某型号存储,20块磁盘组建raid5磁盘阵列,划分2个lun。 服务器故障: 管理员执行重装系统操作后发现分区发生改变,原先的sdc3分区丢失,该分区采用xfs文件系统,存储了公司重…

开放式耳机好用吗,推荐几款不错的开放式耳机

​开放式耳机是一种新型的耳机,相比于传统的耳机,开放式耳机听歌时不需要将耳朵堵上,不会因为长时间佩戴而对听力造成损害。它不需要入耳也能听到声音,在户外运动时能够及时听到环境音,避免安全隐患。现在在骨传导市面…

nodejs+vue 沃健身房管理系统

3)系统分析 本章主要是对系统可行性、系统性能、还有系统功能需求进行分析。 (4)系统设计 对系统系统功能和数据库等进行详细讲解。 (5)系统的实现 主要对个人中心、课程分类管理、用户管理、健身器材管理、健身教练管理、预约教练管理、健身课程管理、课程订单管理、健身视频管…

ESP32学习笔记08-adc单通道数据采集

8. adc单通道数据采集 8.1RTC SAR ADC 控制器 8.2ADC相关的api 8.2.1 配置adc的位宽 esp_err_t adc1_config_width(adc_bits_width_t width_bit);width_bit :位宽 返回值 ESP_OK 配置成功 ESP_ERR_INVALID_ARG 参数错误 esp32最大的宽度的12位typedef enum {

STM8S208MB -> 寄存器方式实现对Flash的连续读写操作(IAR)

代码 File: STM8S208MB_Flash_Op.c /*file: STM8S208MB_Flash_Op.cbrief: 读写Flashdata: 2023-04-14author: ArcherQAQ */#include "STM8S208MB_Flash_Op.h" #include "stdio.h"u8 dataBuf[] {0xFF, 0xFF}; // 写入Flash的数据 u8 Rec_Buf[100] {0x00…

天猫数据分析:2023年速食品(方便面)市场数据分析

我国的方便面市场是一个比较活跃的市场,其市场规模也比较庞大。近年来,随着中国经济的发展,消费者对方便面的需求量和要求也在不断变化,因此,我国方便面市场的规模和消费者的需求环境也正在不断改变。 根据鲸参谋电商数…

Excel技能之排名,小函数很强大

你还在熬夜加班搞Excel吗? 你还在用手指,指着电脑屏幕,一行一行核对数据吗? 你还在害怕被笑而不敢问同事吗? 赶紧来学Excel,收藏加关注,偷偷地进步!日积月累,必成大器&am…

12-python内存地址

1.查看内存地址 a1 print(id(a)) # 24319294835042.数据类型 (1)不可变数据类型:数值、字符串、布尔值、元组 数据存储在计算机中的某个位置,不管赋值给谁,内存地址都相同 a"jack" b"jack" prin…

常见分布式锁3:Redis setNx

Redis实现分布式锁的核心便在于SETNX命令,它是SET if Not eXists的缩写,如果键不存在,则将键设置为给定值,在这种情况下,它等于SET;当键已存在时,不执行任何操作;成功时返回1&#x…

【python游戏】努力制造阳光,让植物有力量对抗僵尸吧~

前言 大家早好、午好、晚好吖 ❤ ~欢迎光临本文章 晃着脑袋生产阳光的向日葵,突突突吐着子弹的豌豆射手!​ 行动迟缓种类丰富的僵尸…… 印象最深的是“僵尸吃掉了你的脑子!” 还有疯狂的戴夫,无一不唤醒着我们的童年记忆​ 山…

Hive 拉链表的两种实现方式

目录 1.什么是拉链表 2.拉链表的产生背景 2.1数据同步 2.1.1全量同步 2.1.2增量同步 2.2增量同步和拉链表 3.拉链表的实现方式 3.1数据准备 3.2思路1 3.3思路2 1.什么是拉链表 我们首先要知道,拉链表是一个逻辑上的概念。 拉链表记录的是增量数据&#x…

(链表专题) 328. 奇偶链表 ——【Leetcode每日一题】

328. 奇偶链表 给定单链表的头节点 head ,将所有索引为奇数的节点和索引为偶数的节点分别组合在一起,然后返回重新排序的列表。 第一个 节点的索引被认为是 奇数 , 第二个 节点的索引为 偶数 ,以此类推。 请注意,偶…

在 RISC-V Linux 内核中添加模块

在 RISC-V Linux 内核中添加模块 flyfish 本例以添加helloworld字符设备为例 一 源码配置 1 源码 源码文件helloworld.c拷贝到 drivers/char 目录中 源码主要是输出Hello world init 2 Kconfig 打开drivers/char 目录下的Kconfig文件 在endmenu之前加上 config HELLO…

统信UOS专业版系统安装教程 - 全盘安装UOS系统

全文导读:本文介绍了UOS系统安装(全盘安装)的过程,如果没有特殊要求,推荐安装UOS系统都采用全盘安装。 准备环境 制作好统信UOS专业版启动U盘 一台CPU频率≥2GHz、内存≥4GB、硬盘≥64GB的电脑 安装步骤 一、制作…

MySQL复合查询

文章目录一、多表查询二、自连接三、子查询1.单行子查询2.多行子查询3.多列子查询4.在 from 子句中使用子查询5.合并查询一、多表查询 在实际开发中,数据往往来自不同的表,所以需要多表查询。 对多张表做笛卡尔积,实际上就是多张表的所有记…

js 特殊对象 - 数组

1.概述 数组也是对象的一种,数组是一种用于表达有顺序关系的值的集合的语言结构,也就是同类数据元素的有序集合。 数组的存储性能比普通对象要好,在开发中我们经常使用数组来存储一些数据。但是在JavaScript中是支持数组可以是不同的元素&…

使用CH9102F平替ESP32系列下载电路中的CP2102

乐鑫官方ESP32开发板的外围电路主要包含: USB-UART电路自动下载电路RC延迟电路重启按键下载按键电源降压芯片LDO下面简单介绍一下这些电路的功能。 ESP32的USB-UART电路部分,核心芯片CP2102。其作用是将USB接口传入的D、D-信号转换为串口信号RX、TX以及…

如何与 MACOM 建立 EDI 连接?

项目背景 MACOM提供高性能射频,微波和毫米波器件,其产品广泛应用于通信,航空航天,国防和工业市场。近年来MACOM在中国地区的业务一直高速增长。 为了提高其供应链的效率和准确性,MACOM使用EDI(电子数据交…