使用PyTorch加载数据集:简单指南

news2025/1/24 8:47:57

💗💗💗欢迎来到我的博客,你将找到有关如何使用技术解决问题的文章,也会找到某个技术的学习路线。无论你是何种职业,我都希望我的博客对你有所帮助。最后不要忘记订阅我的博客以获取最新文章,也欢迎在文章下方留下你的评论和反馈。我期待着与你分享知识、互相学习和建立一个积极的社区。谢谢你的光临,让我们一起踏上这个知识之旅!
请添加图片描述

文章目录

  • 🥦引言
  • 🥦前期的准备
  • 🥦基本的步骤说明
  • 🥦代码讲解+实现

🥦引言

在机器学习和深度学习中,数据集的加载和处理是一个至关重要的步骤。PyTorch是一种流行的深度学习框架,它提供了强大的工具来加载、转换和管理数据集。在本篇博客中,我们将探讨如何使用PyTorch加载数据集,以便于后续的模型训练和评估。

🥦前期的准备

在实战前,我们需要了解三个名词,Epoch、Batch-Size、Iteration
下面针对上面,我展开进行说明

  • Epoch(周期):
    定义:Epoch是指整个训练数据集被完整地前向传播和反向传播通过神经网络的一次循环。在一个Epoch内,模型将看到训练集中的每个样本一次,无论是一次完整的前向传播和反向传播,还是批量的。
    作用:一个Epoch代表了一次完整的训练周期。在每个Epoch结束后,模型参数都会被更新一次。Epoch的数量通常是一个超参数,可以控制模型的训练时间和效果。

  • Batch Size(批大小):
    定义:Batch Size是指每次迭代时用于训练模型的样本数量。在每个迭代中,模型将根据批大小从训练数据中选择一小批样本来执行前向传播和反向传播,然后更新模型参数。
    作用:Batch Size控制了每次参数更新的规模。较大的批大小可以加速训练,但可能需要更多内存。较小的批大小可以增加模型的泛化能力,但训练时间可能更长。

  • Iterations(迭代):
    定义:Iteration是指一次完整的前向传播、反向传播和参数更新。一个Iteration中,模型会处理一个Batch Size的样本。
    联系:Iterations通常用于描述在一个Epoch内,模型参数更新的次数。一个Epoch内的Iterations数量等于训练数据集的大小除以Batch Size。例如,如果你有1000个训练样本,批大小为100,那么一个Epoch包含10个Iterations(1000 / 100 = 10)。

总结一下: 一个Epoch包含多个Iterations,每个Iteration包含一个Batch Size的样本。
Batch Size决定了每次参数更新的规模,而Epoch表示整个数据集的一个完整训练周期。
训练时通常迭代多个Epochs,其中每个Epoch由多个Iterations组成,以逐渐优化模型的参数。
超参数的选择,如Epoch数量和Batch Size,会影响训练的速度和模型的性能,需要根据具体问题进行调整和优化。


在DataLoader中有一个参数是shuffle,这个参数是一个bool值的参数,如果设置为TRUE的话,表示打乱数据集在这里插入图片描述

🥦基本的步骤说明

  1. 导入必要的库
  2. 定义数据预处理转换
  3. 下载和准备数据集
  4. 创建数据加载器
  5. 数据迭代

这里介绍一下DataLoader的参数

  • dataset:这是你要加载的数据集的实例,通常是继承自torch.utils.data.Dataset的自定义数据集类或内置数据集类(如MNIST)。

  • batch_size:指定每个批次(batch)中包含的样本数。这是一个重要参数,影响了训练和推理过程中的计算效率和模型的性能。通常,你需要根据你的硬件资源和数据集大小来选择适当的批大小。

  • shuffle:布尔值,控制是否在每个Epoch开始时打乱数据集的顺序。通常,设置为True可以帮助模型更好地学习,因为它增加了数据的随机性,避免模型对数据的顺序产生过度依赖。在训练时,通常建议将其设置为True。

  • num_workers:指定用于数据加载的子进程数量。这允许在数据加载过程中并行加载数据,以提高数据加载的效率。通常,设置为大于0的值可以加速数据加载。但要注意,过高的值可能会占用过多系统资源,因此需要权衡。

  • pin_memory:如果为True,则数据加载器会将批次数据置于GPU的锁页内存中,以提高数据传输的效率。通常,在GPU上训练时,建议将其设置为True。

  • drop_last:如果为True,当数据集的大小不能被批大小整除时,将丢弃最后一个批次。通常,将其设置为True以确保每个批次都具有相同大小,这在某些情况下有助于训练的稳定性。

  • timeout:指定数据加载超时的时间(单位秒)。如果数据加载器无法在指定时间内加载数据,它将引发超时异常。这可用于避免数据加载过程中的死锁。

🥦代码讲解+实现

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32) 
        self.len = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:, :-1])
        self.y_data = torch.from_numpy(xy[:, [-1]])
    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]
    def __len__(self): 
        return self.len
dataset = DiabetesDataset('diabetes.csv.gz') 
train_loader = DataLoader(dataset=dataset, 
batch_size=32, 
shuffle=True, 
num_workers=2)
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6) 
        self.linear2 = torch.nn.Linear(6, 4) 
        self.linear3 = torch.nn.Linear(4, 1) 
        self.sigmoid = torch.nn.Sigmoid()
    def forward(self, x):
        x = self.sigmoid(self.linear1(x)) 
        x = self.sigmoid(self.linear2(x)) 
        x = self.sigmoid(self.linear3(x)) 
        return x
model = Model()
criterion = torch.nn.BCELoss(size_average=True)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for epoch in range(100):
    for i, data in enumerate(train_loader, 0):
        # 1. Prepare data
        inputs, labels = data
        # 2. Forward
        y_pred = model(inputs)
        loss = criterion(y_pred, labels) 
        print(epoch, i, loss.item())
        # 3. Backward
        optimizer.zero_grad()
        loss.backward()
        # 4. Update
        optimizer.step()
  • 首先,导入所需的库,包括NumPy和PyTorch。这些库用于处理数据和创建深度学习模型。

  • 创建一个自定义的数据集类DiabetesDataset,用于加载和处理数据。该类继承自torch.utils.data.Dataset类,并包含以下方法:
    init:加载数据文件(假定是CSV格式),将数据分为特征(x_data)和标签(y_data),并存储数据集的长度(len)。
    getitem:用于获取数据集中特定索引位置的样本。
    len:返回数据集的总长度。

  • 创建数据集实例dataset,并使用DataLoader创建数据加载器train_loader。数据加载器用于批量加载数据,batch_size参数设置每个批次的样本数,shuffle参数表示是否随机打乱数据集顺序,num_workers参数表示并行加载数据的进程数。

  • 定义神经网络模型Model,该模型继承自torch.nn.Module。模型包含三个线性层和Sigmoid激活函数。在__init__方法中,定义了模型的层结构,而forward方法描述了数据在模型中的传递过程。

  • 创建模型实例model。

  • 定义损失函数criterion和优化器optimizer。在这里,使用了二元交叉熵损失函数(torch.nn.BCELoss)和随机梯度下降优化器(torch.optim.SGD)。

  • 进行训练循环,循环次数为100次(由for epoch in range(100)控制)。
    在内部循环中,使用enumerate(train_loader, 0)来迭代数据加载器。
    准备数据:获取输入数据和标签。
    前向传播:将输入数据传递给模型,获得预测值。
    计算损失:使用损失函数计算预测值与实际标签之间的损失。
    打印损失值:输出当前训练批次的损失值。
    反向传播:通过优化器的backward()方法计算梯度。
    参数更新:使用优化器的step()方法来更新模型参数。

这段代码演示了一个基本的二分类问题的训练过程,其中神经网络模型用于预测糖尿病患者的标签(0表示非糖尿病,1表示糖尿病)。模型的训练是通过反向传播算法来更新模型参数以减小损失。在训练循环中,你可以观察损失值的变化,以了解模型的训练进展。

在这里插入图片描述

挑战与创造都是很痛苦的,但是很充实。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1094587.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

网络链接失败怀疑是服务器处于非正常状态?如何用本地电脑查看服务器是否正常?

网络链接失败怀疑是服务器处于非正常状态?如何用本地电脑查看服务器是否正常? 网页会出现链接失败,可以实时用cdm大法,cdm可以更好的排查字节数据的返回,可以让我们更好的要检查服务器是否处于正常状态,接下…

纯文本邮件发送:java

1.打开jdk的conf下的security文件的.security,找到并删除&#xff0c;权限问题建议复制文件修改后替换 jdk.tls.disabledAlgorithmsSSLv3, TLSv1, TLSv1.1, RC4, DES, MD5withRSA, \ DH keySize < 1024, EC keySize < 224, 3DES_EDE_CBC, anon, NULL 删除后的内容 然…

Linux服务器实验总结以及回顾(全)

Linux 一、搭建简单的论坛1、准备工作2、实现步骤2.1 挂载光盘2.2 搭建yum安装环境2.2.1 网络源&#xff1a;2.2.2 本地源 2.3 安装http服务2.4 启动http服务并配置开机自启动2.5 安装软件包:mariadb-server,php,php-mysql[php*]2.6 下载并解压论坛源码包Discuz2.7 设置selinux…

Infuse Mac视频播放器 中文

Infus是一款非常好用的播放器软件&#xff0c;它具有广泛的格式支持和强大的解码能力&#xff0c;可以播放各种视频和音频文件。同时&#xff0c;它还支持杜比视界和杜比音效&#xff0c;可以提供高品质的视听体验。此外&#xff0c;Infus还具有直观易用的用户界面和频繁的软件…

Stable Diffusion绘图,lora选择

best quality, ultra high res, (photorealistic:1.4), 1girl, off-shoulder white shirt, black tight skirt, black choker, (faded ash gray hair:1), looking at viewer, closeup <lora:koreandolllikeness_v20:0.66> 最佳品质&#xff0c;超高分辨率&#xff0c;&am…

Java调用FFmpeg

Java调用FFmpeg 1、FFmepg基础知识1.1 下载 FFmpeg1.2 FFmpeg 工具使用 2、Java使用2.1 FFmpeg源码编译2.2 Java集成FFmpeg2.2.1 JNI2.2.2 Java调用执行 FFmpeg 工具 命令 1、FFmepg基础知识 About FFmpeg ffmpeg(计算机程序) - 百度百科 FFmpeg/FFmpeg - GitHub CSDN&#xf…

接口测试如何测?最全的接口测试总结,资深测试老鸟整理...

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 1、接口测试的流程…

软件工程与计算总结(十三)详细设计中的模块化与信息隐藏

一.模块化与信息隐藏思想 1.设计质量 好的设计要着重满足以下3方面&#xff1a;可管理性、灵活性、可理解性好的设计需要侧重于间接性和可观察性——简洁性使得系统模块易于管理&#xff08;理解和分解&#xff09;、开发&#xff08;修改与调试&#xff09;和复用。实践者都…

基于适应度相关优化的BP神经网络(分类应用) - 附代码

基于适应度相关优化的BP神经网络&#xff08;分类应用&#xff09; - 附代码 文章目录 基于适应度相关优化的BP神经网络&#xff08;分类应用&#xff09; - 附代码1.鸢尾花iris数据介绍2.数据集整理3.适应度相关优化BP神经网络3.1 BP神经网络参数设置3.2 适应度相关算法应用 4…

【软考】9.2 串/数组/矩阵/广义表/树

《字符串》 一种特殊的线性表&#xff0c;数据元素都为字符模式匹配&#xff1a;寻找子串第一次在主串出现的位置 模式匹配算法 1. 暴力破解法&#xff08;布鲁特-福斯算法&#xff09; 主串与子串一个个匹配效率低 2. KMP算法 主串后缀和子串前缀能否找到一样的元素&#xf…

轻量化Backbone | ShuffleNet+ViT结合让ViT也能有ShuffleNet轻量化的优秀能力

视觉Transformer&#xff08;ViTs&#xff09;在各种计算机视觉任务中表现出卓越的性能。然而&#xff0c;高计算复杂性阻碍了ViTs在内存和计算资源有限的设备上的适用性。尽管某些研究已经深入探讨了卷积层与自注意力机制的融合&#xff0c;以增强ViTs的效率&#xff0c;但在纯…

KOSMOS系列

Overview 总览摘要1 引言2 KOSMOS-2.52.1 Model Architecture2.1 Image and Text Representations2.3 Pre-training Data2.4 Data Processing2.5 Filtering and Quality Control 3 Experiments3.1 Evaluation3.2 Implementation Details3.3 Results3.4 Discussion 4 Related Wo…

车载多源融合定位

终端硬件由两部分组成&#xff0c;组合导航处理板和地磁导航处理板。 组合导航处理板负责采集加速度计、陀螺、GNSS和轮速计等数据进行组合导航解算&#xff0c;差分数据通过6Q主板获取到后通过串口发送至组合导航处理板。地磁导航处理板负责地磁数据采集&#xff0c;保存至数…

嵌入式实时操作系统的设计与开发 (启动过程学习)

b Reset; b Undef; b SWI; b PreAbort; b DataAbort; b . ;保留 b IRQ; b FIQ;建立异常向量表的过程&#xff0c;其中第一个指令通常都是存放在主存的零地址的。 异常向量表存放的全是汇编跳转指令&#xff0c;这些指令从主存的零地址&#xff08;0x0&#xff09;开始连续存储在…

Ubuntu下vscode dotNet downloading的问题(Cmake代码高亮)

问题描述&#xff1a;使用Cmake Language Support插件需要安装dotnet的支持库&#xff0c;我原本已经使用apt的方式安装了&#xff0c;但是进入vscode依旧要我下载。尝试按网上的方法修改为我指定的路径&#xff1a; "dotnetAcquisitionExtension.existingDotnetPath&quo…

Vsftp安装配置(超详细版)

目录 1 FTP、Vsftp介绍 1.1 FTP介绍 1.2 Vsftp介绍 1.3 Vsftp的登录类型 2 Vsftp安装配置 2.1 更换源 2.2 安装epel源 2.3 安装Vsftpd及相关依赖 2.4 vsftpd配置文件说明 2.5 vsftpd 配置详解 2.6 备份配置文件 3 vsftpd 配置匿名用户 3.1 编辑配置文件 3.2 常用的匿名FTP配置…

传输层 | UDP协议、TCP协议

之前讲过的http与https都是应用层协议&#xff0c;当应用层协议将报文构建好之后就要将报文往下层传输层进行传递&#xff0c;而传输层就是负责将数据能够从发送端传到接收端。 再谈端口号 端口号(port)标识了一个主机上进行通信的不同的应用程序&#xff0c;在TCP/IP协议中&…

让你的服务器变成游戏世界:打造游戏化在线社区的“秘诀”

引言 假如我有一台服务器&#xff0c;我希望打造一款游戏化的在线社区。那么&#xff0c;如何打造这样一个社区程序并成功运营呢&#xff1f;让我们一起来畅想吧&#xff01; 一、确定社区的主题和目标群体 打造一款游戏化的在线社区&#xff0c;首先&#xff0c;我们需要明确…

【Android 性能优化:内存篇】——WebView 内存泄露治理

背景&#xff1a;笔者在公司项目中优化内存泄露时发现WebView 相关的内存泄露问题非常经典&#xff0c;一个 Fragment 页面使用的 WebView 有多条泄露路径&#xff0c;故记录下。 Fragment、Activity 使用WebView不释放 项目中一个Fragment 使用 Webview&#xff0c;在 Fragm…

区块链(12):java区块链项目之集群部署

选择3台服务器进行区块链项目部署 1 nginx部署页面 1.1 部署静态页面 1.2 nginx 反向代理的配置 修改nginx.conf文件 nginx 默认端口是http 80或者https443 将80代理到8080 location /blockchain {proxy_pass http://localhost:8080/blockchain;proxy_redirect default; …