【PyTorch】第七节:数据加载器

news2024/11/28 20:36:05

作者🕵️‍♂️:让机器理解语言か

专栏🎇:PyTorch

描述🎨:PyTorch 是一个基于 Torch 的 Python 开源机器学习库。

寄语💓:🐾没有白走的路,每一步都算数!🐾 

介绍💬

        数据是深度学习的基础,我们解决的大多数深度学习问题都是需要数据的。而每一种深度学习框架都对数据的格式有自己的要求,因此,本实验主要讲解了 PyTorch 对输入数据的格式要求,以及如何将现实中的数据处理成 PyTorch 能够识别的数据集合。

知识点🍉

  • 🍓数据的分批
  • 🍓手写字符数据的分批
  • 🍓葡萄酒数据的分批(自定义数据集的封装及分批

数据的分批

        由于深度学习中的数据量一般都是极大的,我们无法一次性将所有数据全部加载到内存中。因此,在模型训练之前,一般我们都会对训练集进行分批,将数据随机分成等量的几份,每次迭代只训练一份,如下面的伪代码所示:

# training loop
for epoch in range(num_epochs):
    # 遍历所有的批次
    for i in range(total_batches):
        batch_x, batch_y = ..
          

上面伪代码可以看出:

  • epoch 每增加一次,表示完成了所有数据的一次正向传播和反向传播
  • total_batches 表示分批后的数据集合
  • i 每增加一次,表示完成了一批数据的一次正向传播和反向传播

        那么如何对数据集合进行分批呢?我们可以自己尝试编写数据分批的代码。当然,PyTorch 中也为我们提供了相应的接口,我们可以很容易的实现数据分批的过程。PyTorch 为我们提供了 torch.utils.data.DataLoader 加载器,该加载器可以自动的将传入的数据进行打乱分批 DataLoader() 的加载参数如下:

  • dataset:需要打乱的数据集。
  • batch_size 每一批的数据条数
  • shuffle:True 或者 False,表示是否将数据打乱后再分批。

        当然利用该加载器并不仅仅是对数据进行打乱和分批的操作,该加载器还可以对数据进行格式化处理,使其能够被放入后面的神经网络模型中。接下来,让我们首先利用该加载器对 PyTorch 中自带的数据集合 MNIST 进行分批操作。 

MNIST 的分批

        首先让我们加载 PyTorch 中的自带数据集合,该数据集合存在于 torchvision.datasets 中,可以直接利用 torchvision.datasets.MNIST 获得:

import torch
import torchvision
# 将数据集合下载到指定目录下
train_dataset = torchvision.datasets.MNIST(root='./data',
                                           train=True,
                                           transform=torchvision.transforms.ToTensor(),
                                           download=True)
train_dataset

        从结果可以看出该数据集合共有 60000 条数据。由于这些数据都是图片,因此分批传入内存是非常必要的。

        MNIST 数据集合是一个手写字符集合,该数据集合中存储了大量的手写字符图像。我们可以加载一张图片,观察一下:

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
a_data, a_label = train_dataset[0]
print("\n打印第一个batch的第一个图像,对应标签数字为{}".format(a_label))
# 显示第一个 batch 中的第一个图像
# 原始数据是归一化后的数据,因此这里需要反归一化
img = np.array(a_data+1)*127.5
img = np.reshape(img, [28, 28]).astype(np.uint8)
plt.imshow(img, 'gray')
plt.show()

         该数据集合的分批步骤很简单,只需要将我们得到的数据集合传入 DataLoader 中即可

from torch.utils.data import DataLoader

train_loader = DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)
num_epochs = 2  # 迭代次数
for epoch in range(num_epochs):
    for i, (inputs, labels) in enumerate(train_loader):
        # 每 10 个批次展示一次
        if (i+1) % 100 == 0:
            print(
                f'Epoch: {epoch+1}/{num_epochs},Step {i+1}/{len(train_dataset)/100}| Inputs {inputs.shape} | Labels {labels.shape}')

          如上所示,我们对 MNIST 数据集合进行分批,且利用循环对其进行了遍历。在模型训练时,我们将内循环中的代码换成模型训练的代码即可

        上面我们展示的是 PyTorch 中自带的数据集合的分批步骤,但是该库并没有包含全世界所有的数据集合。如果我们需要将自己的数据集合进行分批,应该怎么做呢?

        在 DataLoader 传入的参数值,我们需要注意的是 dataset 参数。该参数是一个 Dataset 类,即只有继承了 PyTorch 中的 Dataset 接口的类,才能够被传入 DataLoader 中。因此,针对于自己的数据集合,我们往往需要对原始的数据进行处理将其封装成一个继承了 Dataset 的 Python 类这样,我们才能够将其传入 PyTorch 之中,进行数据的分批和模型的训练。

葡萄酒数据的分批

        接下来,让我们以葡萄酒的种类预测为例,详细的阐述自定义数据应当如何进行封装和分批。 

        首先,我们将对葡萄酒数据进行分批操作,让我们先来加载数据集合:

import pandas as pd
df = pd.read_csv(
    "https://labfile.oss.aliyuncs.com/courses/2316/wine.csv", header=None)
df

         如上表所示,第一列表示该条数据属于哪一种葡萄酒(0,1,2)。而后面 13 列的数据表示的就是葡萄酒的每种化学成分的浓度。这些化学成分分别为:酒精 、苹果酸 、灰分 、灰分的碱度、镁 、总酚、 黄酮类化合物 、非类黄酮酚 、原花色素 、颜色强度 、色相 、稀释酒的 OD280/OD315 和脯氨酸。

🌿数据封装 

        我们需要对这些数据进行分批,那么我们就需要将该数据转为 PyTorch 认识的数据集合。我们可以建立一个类 WineDataset 去继承 Dataset

如果继承了 Dataset 类,我们就必须实现下面三个函数:

  • __init__(self) :用于初始化类中所需要的一些变量。
  • __len__(self): 返回数据集合的长度,即数据量大小。
  • __getitem__(self, index)返回第 index 条数据
from torch.utils.data import Dataset


class WineDataset(Dataset):
    # 建立一个数据集合继承  Dataset 即可
    def __init__(self):
        # I初始化数据
        # 以pandas的形式读入数据
        xy = pd.read_csv(
            "https://labfile.oss.aliyuncs.com/courses/2316/wine.csv", header=None)
        self.n_samples = xy.shape[0]    # shape[0]:样本数,shape[1]:特征数

        # 将 pandas 类型的数据转换成 numpy 类型
        self.x_data = torch.from_numpy(xy.values[:, 1:]) # size [n_samples, n_features]
        self.y_data = torch.from_numpy(xy.values[:, [0]])  # size [n_samples, 1]

    # 返回 dataset[index]
    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    # 返回数据长度
    def __len__(self):
        return self.n_samples


# 测试
# 创造 dataset
dataset = WineDataset()
dataset[0]

🌿数据分批 

         至此,我们将葡萄酒数据也装上了一个“外壳”使 PyTorch 能够识别出该数据集合。接下来,我们只需要利用 DataLoader 加载该数据集合即可:

import math
# 传入加载器
train_loader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)
# 分批训练
# 迭代次数
num_epochs = 2
total_samples = len(dataset)
# 批次
n_iterations = math.ceil(total_samples/4)
print("该数据集合共有{}条数据,被分成了{}个批次".format(total_samples, n_iterations))
for epoch in range(num_epochs):
    for i, (inputs, labels) in enumerate(train_loader):

        # 178 个样本, batch_size = 4, n_iters=178/4=44.5 -> 45 个批次
        if (i+1) % 5 == 0:
            print(
                f'Epoch: {epoch+1}/{num_epochs}, Step {i+1}/{n_iterations}| Inputs {inputs.shape} | Labels {labels.shape}')

         从结果可以看出,我们按照每个批次 batch_size 个,将数据集合分成了 45 个批次,并对这些批次进行了 2 次迭代。由于数据总量不是 batch_size 的整数倍,因此最后一个批次的数据量为 total_samples 除以 batc_size 的余数个。

实验总结🔑

        本实验详细的阐述了数据分批的重要性。并且以手写字符数据集和葡萄酒数据集为例,阐述了如何将一个数据集合转成 PyTorch 能够使用的数据集。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/423150.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

GitHub收藏夹分类列表

前言 GitHub是一个基于Git的国际代码托管网站,其内容质量较高,用户在浏览时经常会收藏他人的项目,当收藏的项目越来越多后,用户再想查找之前特定的收藏项目会非常困难。 因此我们希望分类管理GitHub收藏夹,值得注意的…

Golang每日一练(leetDay0034) 二叉树专题(3)

目录 100. 相同的树 Same Tree 🌟 101. 对称二叉树 Symmetric Tree 🌟 102. 二叉树的层序遍历 Binary Tree Level-order Traversal 🌟🌟 🌟 每日一练刷题专栏 🌟 Golang每日一练 专栏 Python每日一…

pdf怎么转换ppt格式,两个方法转换

PDF作为一种常用的文件格式,被大众所熟悉。虽然PDF具备的稳定性,安全性,以及很强的兼容性可以让我们更方便顺畅的阅读PDF文件,但若是有需要展示PDF文件内容的时候,其优点就没有那么凸显了,这时还是将pdf转换…

数据结构与算法基础(王卓)(25)线性表的查找(1):顺序查找(线性查找)

基本基础概念: 看这就不用去翻PPT了 查找: 根据给定的某个值,在查找表中确定一个与其关键字等于给定值的数据元素(或记录) 关键字: 用来表示一个数据元素(或记录)的某个数据项的值 主…

系统安全与应用【下】

文章目录1.开关机安全控制1.1 GRUB限制1.2 实例:GRUB 菜单设置密码2.终端登录安全控制2.1 限制root只在安全终端登录2.2 禁止普通用户登录3.弱口令检测3.1 Joth the Ripper,JR4.网络端口扫描4.1 nmap命令1.开关机安全控制 1.1 GRUB限制 限制更改GRUB引导参数 通常情…

读懂MAC地址

MAC地址是一种用于标识计算机网络设备的唯一地址。它是由48个二进制数字组成的,通常表示为12个十六进制数字,每两个数字之间用冒号或连字符分隔开。MAC地址由设备制造商在生产过程中分配,以确保网络上每个设备都有唯一的标识符。 MAC地址的规…

投影仪怎么连接电脑?快来看看这3种方法!

案例:如何连接电脑和投影仪? 【想看电影,但是电脑屏幕太小,我想把电脑上的内容通过投影仪投到大屏幕上。有小伙伴知道如何连接电脑和投影仪吗?谢谢大家!】 使用投影仪可以将电脑或其他设备上的内容投放到…

Java——矩形覆盖

题目链接 牛客在线oj题——矩形覆盖 题目描述 我们可以用 21 的小矩形横着或者竖着去覆盖更大的矩形。请问用 n 个 21 的小矩形无重叠地覆盖一个 2*n 的大矩形,从同一个方向看总共有多少种不同的方法? 数据范围:0≤n≤38 进阶&#xff1…

九龙证券|啤酒龙头一季度净利暴增逾70倍!钛白粉开启年内第三次涨价

业内人士遍及认为,随着下流需求逐渐改进,钛白粉职业景气量有望恢复。 燕京啤酒昨日晚间发布一季度成绩预告,公司2023年一季度预计完成归母净利润6200万至6600万元,同比增加7076.76%至7539.77%。对于成绩变动的原因,燕京…

K8s为什么要放弃Docker

公司定期分享整理的资料 放弃始由 https://github.com/kubernetes/kubernetes/blob/master/CHANGELOG/CHANGELOG-1.20.md#deprecation 2020 年,k8s 1.20 终于正式向 Docker “宣战”:kubelet将弃用 Docker 支持,并将在未来的版本中完全移除。…

Linux高性能服务器编程|阅读笔记:第1章 - TCP/IP协议族

简介 Hello! 非常感谢您阅读海轰的文章,倘若文中有错误的地方,欢迎您指出~ ଘ(੭ˊᵕˋ)੭ 昵称:海轰 标签:程序猿|C++选手|学生 简介:因C语言结识编程,随后转入计算机专业,获得过国家奖学金,有幸在竞赛中拿过一些国奖、省奖…已保研 学习经验:扎实基础 + 多做笔…

CLIMS:弱监督语义分割的跨语言图像匹配

文章目录CLIMS: Cross Language Image Matching for Weakly Supervised Semantic Segmentation摘要方法语言图像匹配框架实验结果CLIMS: Cross Language Image Matching for Weakly Supervised Semantic Segmentation 摘要 存在的问题 CAM(类激活图)通常只激活有区别的对象区…

第二章 数据类型与变量

文章目录1. 字面常量2. 数据类型3. 变量3.1 变量概念3.2 语法格式3.3 整形变量3.3.1 int整型变量(4 个字节)3.3.2 long长整型变量(8 个字节)3.3.3 short短整型变量(2 个字节)3.3.4 byte字节型变量(1个字节)3.4 浮点型变量3.4.1 do…

图片的四个角怎么做成圆弧形,2种方法分享

图片的四个角怎么做成圆弧形?今天来说一说直角和圆弧,小伙伴们可能会疑惑了,怎么今天讲的内容和之前不一样啊,不要着急干货在后面,听我娓娓道来。相信很多小伙伴们都使用过iphone手机,每年的产品发布会都会…

用高中生的思维写一篇MATLAB入门

文章目录一、简介二、MATLAB的工作界面三、基本语句1、if语句2、switch语句3、try语句4、for语句和while语句5、break语句和continus语句四、数值运算1、基本算术运算2、format命令3、关系运算4、逻辑运算5、特殊变量和常数6、数学函数五、二维平面绘图1、关于颜色和数据标记点…

微积分——用积分定义自然对数的动机

第6章 对数函数&#xff0c;指数函数和反三角函数 目录 第6章 对数函数&#xff0c;指数函数和反三角函数 6.1 引言 6.2 用积分定义自然对数的动机 内容来源&#xff1a;<> Tom M. Apostol 6.1 引言 每当有人将他的注意力集中到数量关系的时候&#xff0c;他要么是…

NIFI大数据进阶_实时同步MySql的数据到Hive中去_可增量同步_实时监控MySql数据库变化_操作方法说明_02---大数据之Nifi工作笔记0034

然后我们继续来看,如果需要同步,当然需要先开启mysqlbin log日志了 可以看到开启操作 在windows和linux上开启binlog日志 然后看一下 在windows上开启mysql的binlog的方法

ES X-Pack密码认证与用户管理

用户数据的安全性一直被人诟病且默认没有密码认证&#xff0c;Elasticsearch在6.8之前官方的X-pack安全认证功能都是收费的&#xff0c;所以很多人都采用Search Guard或者ReadOnly REST这些免费的安全插件对Elasticsearch进行安全认证。从Elasticsearch 6.8开始&#xff0c;Sec…

Java集成工作流经典案例(多个项目优化精华版)

前言 activiti工作流引擎项目&#xff0c;企业erp、oa、hr、crm等企事业办公系统轻松落地&#xff0c;请假审批demo从流程绘制到审批结束实例。 一、项目形式 springbootvueactiviti集成了activiti在线编辑器&#xff0c;流行的前后端分离部署开发模式&#xff0c;快速开发平…

爆款来袭!刷屏的Auto-GPT与ChatGPT区别,GPT成为AI领域最受关注的技术,你还在等什么?(狂飙 啊。。。Github 80k star了)

最近全网火爆刷屏的热门词auto-gpt&#xff0c;在全网站频频出现: "ChatGPT 过时了&#xff0c;Auto-GPT才是未来" "它所具备的能力主打的就是一个“自主”&#xff0c;完全不用人类插手的那种&#xff01;" 到底什么是auto-gpt? 1、Auto-GPT和ChatGP…