【PyTorch】教程:学习基础知识-(3) Datasets-DataLoader

news2024/11/17 19:41:31

Dataset & DataLoader

PyTorch 提供了两个数据处理的基本方法:torch.utils.data.DataLoader torch.utils.data.Dataset 允许使用预加载的数据集以及自己的数据。 Dataset 存储样本及其对应的标签, DataLoaderDataset 基础上封装了一个可迭代的对象,以方便访问样本。

PyTorch 提供了许多预加载的数据集(如 FashionMNIST ) 这些数据集继承了 torch.utils.data.Dataset 类,并实现了特定数据的函数。它们可以用来创建模型原型和基准测试。Image Datasets, Text Datasets, 和 Audio Datasets

Loading a Dataset (加载数据集)

下面是一个加载 FashionMNIST 数据集的例子。 FashionMNIST 数据集包含了 60000 个训练样本和 10000 个测试样本,每一个样本是 28*28 的灰度图像和对应标签(一共 10 个类别)。

import torch 
from torch.utils.data import DataLoader 
from torchvision import datasets 
from torchvision.transforms import ToTensor 
import matplotlib.pyplot as plt 

training_data = datasets.FashionMNIST(
    root="../../data",   # 存放数据的路径
    train=True,          # 是训练数据集还是测试数据集
    download=True,       # 如果存储的路径里没有数据集的话,就从网络下载数据集
    transform=ToTensor() # 数据转换
)

test_data = datasets.FashionMNIST(
    root = "../../data", # 存放数据的路径
    train=False,         # 是训练数据集还是测试数据集
    download=True,       # 如果存储的路径里没有数据集的话,就从网络下载数据集
    transform=ToTensor() # 数据转换
)

Iterating and Visualizing the Dataset (迭代和可视化数据集)

我们可以像索引列表一样对数据集进行索引,如 training_data[index], 使用 matplotlib 对数据进行可视化。

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}

figure = plt.figure(figsize=(10, 10))
cols, rows = 3, 3
for i in range(1, rows * cols + 1):
    sample_idx = torch.randint(0, len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.numpy().reshape(28, 28), cmap="gray")
plt.show()

在这里插入图片描述

Creating a Custom Dataset for your files (用自己的文件定制数据集)

一个定制的数据集需要实现 3 个函数: __init__, __len__, __getitem__FashionMNIST 图片存储在 img_dir 里,它们的标签存储在 CSV 标注文件里。

下一个章节中,我们会详细分析这些函数。

import os 
import numpy as np
import pandas as pd 
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)
  
    def __getitem(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

__init__

当实例化 Dataset 对象时,__init__ 函数执行一次,需要包括包含图片和标注文件的路径,以及它们是否需要转换。

labels.csv 文件结构如下:

tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9

__len__

__len__ 函数返回数据中样本数量。

__getitem__

__getitem__函数从给定索引 idx 处的数据集中加载并返回一个样本。基于索引,它识别图像在磁盘上的位置,使用 read_image 将其转换为一个 tensor ,从 csv 数据中提取对应的标签,调用它们上的变换函数(如果适用),并在元组中返回 tensor 图像和相应的标签。

Preparing your data for training with DataLoaders

Dataset 每次从一个样本中提取我们数据集中的特征样本和标签用于训练模型,通常情况下,我们希望在每个 epoch 传递多个样本 “minibatches” ,打乱顺序从而可以减少模型的过拟合,并且可以加速数据的提取。 DataLoader 是一个迭代器,将这些复杂的事情抽象成一个非常简单的 API

from torch.utils.data import DataLoader 

train_dataloader = DataLoader(training_data, batch_size=4, num_workers=0)
test_dataloader = DataLoader(test_data, batch_size=4, num_workers=0)

Iterate through the DataLoader ( 通过 DataLoader 迭代取数据 )

我们已经将 dataset 加载进 DataLoader 里了,并且可以通过 dataset 迭代取数据。每个迭代返回一批 train_featurestrain_labels。 见 Samplers。

train_features, train_lables = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.shape}")
print(f"Label batch shape: {train_lables.shape}")
img = train_features[0].squeeze()
label = train_lables[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

在这里插入图片描述

Feature batch shape: torch.Size([4, 1, 28, 28])
Label batch shape: torch.Size([4])
Label: 9

【参考】

Datasets & DataLoaders — PyTorch Tutorials 1.13.1+cu117 documentation

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/164105.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python文本颜色设置

Python文本颜色设置实现过程:书写格式:数值表示的参数含义:常见开头格式:实例:实现过程: 终端的字符颜色是用转义序列控制的,是文本模式下的系统显示功能,和具体的语言无关。 转义序…

Acwing4699. 如此编码

某次测验后,顿顿老师在黑板上留下了一串数字 23333 便飘然而去。 凝望着这个神秘数字,小 P 同学不禁陷入了沉思…… 已知某次测验包含 nn 道单项选择题,其中第 i 题(1≤i≤n)有 ai 个选项,正确选项为 bi&…

CAS And Atomic

CAS(Compare And Swap 比较并交换),通常指的是这样一种原子操作:针对一个变量,首先比较它的内存值与某个期望值是否相同,如果相同,就给它赋一个新值,底层是能保证cas是原子性的CAS的应用 在Java 中,CAS 操作…

Android开发-AS学习(三)(布局)

相关文章链接:Android开发-AS学习(一)(控件)Android开发-AS学习(二)(控件)Android开发应用案例——简易计算器(附完整源码)二、布局2.1 Linearyout常见属性说…

测试NGINX和uwsgi.ini设置

1.uwsgi修改测试 将服务器升级到16核16G配置后,我将uwsgi.ini中的部分参数调整如下: processes 32 threads 16 结果是导致内存暴满,然后直接服务器都无法连接,导致服务器卡死。之前有博客说processes处理器*2,结果…

【阶段三】Python机器学习26篇:机器学习项目实战:LightGBM回归模型

本篇的思维导图: 项目实战(LightGBM回归模型) 项目背景 为促进产品的销售,厂商经常会通过多个渠道投放广告。本案例将根据某公司在电视、广播和报纸上的广告投放数据预测广告收益,作为公司制定广告策略的重要参考依据。 本项目应用LightGBM回归算法进行项目实战,整…

Nginx入门

介绍: 下载和安装: 安装过程: 1、因为nginx是由c语言编写的,所以需要下载gcc进行编译 yum -y install gcc pcre-devel zlib-devel openssl openssl-devel 2、下载nginx安装包 wget https://nginx.org/download/nginx-1.16.1.ta…

【Java基础知识 1】第一个Java程序(Java的第一步)

本文已收录专栏 🌲《Java进阶之路》🌲 编写一个Java程序 第一个Java程序非常简单,代码如下: public class Java01_HelloWorld {public static void main(String[] agrs){System.out.println("欢迎来到Java进阶之路&#x…

蓝桥杯 stm32 按键点灯 CubeMX

注:我们使用的是 HAL 库 文章目录前言一、按键 原理图:二、按键CubeMX配置:三、代码讲解1. 读按键:( 三行代码)2.按键消抖:3,按键点灯:总结实验效果:前言 一、按键 原理…

基于yolov5-v7.0开发构建银行卡号实例分割检测识别分析系统

在之前的文章中我们已经做了很多基于yolov5完成实例分割的项目,感兴趣的话可以自行移步阅读:《基于YOLOv5-v7.0的药片污染、缺损裂痕实例分割检测识别分析系统》《基于yolov5-v7.0开发构建裸土实例分割检测识别模型》《基于yolov5-v7.0开发实践实例分割模…

CSDN第24期周赛(记录一下)

▶ 爱要坦荡荡 (163.com) 每一道题都有思路,可是只有一道Accepted100%,一道很简单的50%,一道输出错误10%并报错Segmentation Fault,一道字符串有思路但是没时间了 一,蛇皮矩阵 Accepted 10% 报错Segmentation Fault…

2、矩阵介绍

目录 一、矩阵的构造 二、矩阵大小及结构的改变 三、矩阵下标的引用 1.矩阵下标访问单个矩阵元 2.线性引用矩阵元 3.访问多个矩阵元素 四、矩阵信息的提取 1.矩阵结构 2.矩阵大小 3.矩阵的数据类型 一、矩阵的构造 矩阵的构建方式有两种,一种与单元数组相…

【寒假每日一题】DAY.7 有序序列判断

牛客网例题:点我做题 【❤️温馨提示】先做题,再看讲解效果更佳哟 描述 输入一个整数序列,判断是否是有序序列,有序,指序列中的整数从小到大排序 或者从大到小排序(相同元素也视为有序)。输入描述: 第一行…

【C++】stack、queue的模拟实现及deque介绍

一、stack 1. 以vector作为底层容器 从栈的接口中可以看出,栈实际是一种特殊的vector,因此使用vector完全可以模拟实现stack。 由于stack的所有工作是底层容器完成的,而这种具有“修改某物接口,形成另一种风貌”的性质&#xf…

Dubbo学习

文章目录1.概念1.1 Dubbo特性1.2 设计架构2.快速开始2.1需求假设2.2.工程架构2.3 创建模块2.3.1 gmall-interface—公共接口层2.3.2 gmall-user—用户模块2.3.3 gmall-order-web—订单模块2.3.4 测试结果2.3.5 使用Dubbo改造2.3.6 注解版3.监控中心4.整合SpringBoot5.Dubbo配置…

高性能网络设计专栏-网络编程

以下是在零声教育的听课记录。 如有侵权,请联系我删除。 链接:零声教育官网 一、网络io与select,poll。epoll 网络IO ,会涉及到两个系统对象 一个是 用户空间 调用 IO 的进程或者线程,另一个是 内核空间的 内核系统&a…

K_A11_007 基于STM32等单片机驱动K型热电偶( MAX6675) 串口与OLED0.96双显示

[TOC](K_A11_007 基于STM32等单片机驱动K型热电偶( MAX6675) 串口与OLED0.96双显示) 一、资源说明 单片机型号测试条件模块名称代码功能STC89C52RC晶振11.0592MK型热电偶( MAX6675) 模块STC89C52RC驱动K型热电偶( MAX6675)模块串口与OLED0.96双显示STM32F103C8T6晶振8M/系统时…

后端架构学习

心理预期 1. 什么是后端服务的架构?怎么去理解后端架构这个词? 学习架构的目的:可以更高效的解决复杂的业务问题和技术问题。对架构设计的一知半解会导致,设计不足或者多度设计的现象。架构师思考问题的角度 按出发点划分 从系统…

Linux虚拟机忘记密码

Linux虚拟机忘记密码 使用虚拟机过程中,我们有时会忘记root的登录密码,我们需要进入救援模式去命令passwd更改新的密码。 编辑模式(e) 在首页,按住e,进入编辑模式,找到LANGzh_CN.UTF-8,在末尾加上 init/bin/sh 挂载…

学长教你学C-day10-C语言数组

“同学们,我们前面讲过了变量和数据类型,我们来复习一下,用C语言变量存储数据1~10,然后再输出。小明小红你们上黑板来写,其他人写纸上就可以。” 小明和小红走向讲台拿起粉笔写下: 小红: #incl…