神经网络和深度学习-加载数据集DataLoader

news2025/1/12 13:27:09

加载数据集DataLoader

Dataloader的概念

dataloader的主要目标是拿出Mini-Batch这一组数据来进行训练

在处理多维特征输入这一文章中,使用diabetes这一数据集,在训练时我们使用的是所有的输入x,在梯度计算采用的是随机梯度下降(SDG),每次选用一个样本来进行梯度计算,但存在缺点,优化时间过长

而在Mini-Batch中我们选择小批量中的所有样本,可以最大化的利用向量的优势,来提升计算速度

在使用Mini-Batch我们要了解三个概念

  • Epoch

  • Batch-Size

  • Iterations

首先我们来看一下Epoch,我们采用Mini-Batch之后要使用一个嵌套循环,内循环是每一次迭代都执行一个Mini-Batch,这两个循环相当于把所有的Mini-Batch都跑了一遍

在这里插入图片描述

Epoch的定义就是:所有训练样本都进行一次前向传播和反向传播的过程

Batch-Size的定义是:进行一次前馈和反馈的训练样本数量

Iterations的定义是:所有的样本/Batch-Size

Dataloader的作用

我们要做小批量的训练时,要确定一些重要的参数

  • batch-size

  • shuffle:打乱顺序,为了提高数据样本的随机性可以选择对数据集进行shuffle

  • num_workers :并行操作的数量

  • [i]:支持索引

  • len:长度

在这里插入图片描述

定义Dataset和DataLoader

我们来看一下代码中是如何定义dataset的,在torch.utils.data工具包中包含了这两个类

其中dataset是一个抽象的类,不能实例化,只能被其他的子类继承,我们想要使用的时候必须定义一个自己 的类来继承使用

dataloader是用来帮助加载数据的,我们可以实例化一个dataloader

例如下面自定义一个DiabetesDataset的类

在这里插入图片描述

getitem这个方法是一个模板方法,是为了实例化这个对象之后能够支持下标操作,通过索引来取出数据

len这个方法同样是模板方法,为了返回数据集中的数据条数

接下来就可以用自定义的DiabetesDataset类来实例化dataset对象

我们在构造数据集的时候一般有两种选择

  • 把所有数据在init中加载进来,放入内存中,再用getitem根据索引传出数据,适用于数据集本身的容量不大

  • 类似于图像、语音这种非结构的大数据集,不能一次性加载到内存中时,定义一个列表,数据集里面得每一条数据的文件名放入相应的列表中

📌我们在windows中使用num_workers进行训练会报错,原因是在windows下和Linux下的进程库是不一样的。所以用spawn替代了fork,所以其中处理的方式不同,会出现RuntimeError

📌解决方法:将要训练的代码train_loader进行封装起来(if语句或者是函数中)

在这里插入图片描述

我们在代码中进行改动

在这里插入图片描述

数据集的实现

在构造函数中我们需要一个filepath:描述文件来自什么地方,其次需要通过self.len来获取数据集的长度

在这里插入图片描述

DataLoader的使用

使用enumerate可以获得当前迭代的次数,train_loader中拿出来的元组(x,y)放入data中,所以在训练之前把inputs(x_data)和labels(y_data)从data中取出,此时这两个数据都是Tensor。

也可以一开始就在for循环中使用i,(x,y),就可以省去下面那句

在这里插入图片描述

完整代码

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


# prepare dataset
class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0]  # shape(多少行,多少列)
        self.x_data = torch.from_numpy(xy[:, :-1])
        self.y_data = torch.from_numpy(xy[:, [-1]])

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len


dataset = DiabetesDataset('diabetes.csv.gz')
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=2)  # num_workers 多线程


# design model using class
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x


model = Model()


# construct loss and optimizer
criterion = torch.nn.BCELoss(size_average=True)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)


# training cycle forward, backward, update
if __name__ == '__main__':
    for epoch in range(100):
        for i, data in enumerate(train_loader, 0):  # train_loader 是先shuffle后mini_batch
            # 1. prepare data
            x_data, y_data = data
            # 2. Forward
            y_pred = model(x_data)
            loss = criterion(y_pred, y_data)
            print(epoch, i, loss.item())
            # 3. Backward
            optimizer.zero_grad()
            loss.backward()
            # 4. Update
            optimizer.step()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/45457.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

本地启动springboot项目失败端口问题

异常关键字: Cannot assign requested address: bind 排查结果 配置了环境变量【SERVER_ADDRESS】 网上搜了有的回答是端口占用,是不对的,端口占用的异常是这个【Web server failed to start. Port 8282 was already in use.】 排查结果…

麦芽糖-聚乙二醇-阿霉素maltose-Doxorubicin

麦芽糖-聚乙二醇-阿霉素maltose-Doxorubicin 中文名称:麦芽糖-阿霉素 英文名称:maltose-Doxorubicin 别称:阿霉素修饰麦芽糖,阿霉素-麦芽糖 还可以提供PEG接枝修饰麦芽糖,麦芽糖-聚乙二醇-阿霉素,Doxorubicin-PEG-…

无线传感器网络:定位、安全与同步

文章目录LocalizationRanging TechniquesReceived Signal Strength (RSS)Time of Arrival (ToA)Time Difference of Arrival (TDoA)Angle of Arrival (AoA)Range-Based Localization ProtocolsTriangulationTrilaterationIterative and Collaborative MultilaterationSecurityC…

sipp: bind_local;watchdog timer trip

文章目录作为服务端时,source ip 随机的问题命令示例bind_localwatchdog_minor_maxtriggers作为服务端时,source ip 随机的问题 https://sipp.sourceforge.net/doc/reference.html https://github.com/SIPp/sipp/issues/83 https://github.com/SIPp/sip…

GC2是什么工具

GC2是一款功能强大的命令控制应用工具,该工具将允许广大安全研究人员或渗透测试人员使用Google Sheet来在目标设备上执行远程控制命令,并使用Google Drive来提取目标设备中的敏感数据。 值得一提的是,该工具可以直接提供命令控制服务&#x…

[附源码]计算机毕业设计springboot高校社团管理系统

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

功率放大器的类型和特点是什么(功率放大器使用注意事项有哪些)

放大器一般泛指放大某物的装置,应用在音频、电子等领域。在包括音频功率放大器在内的电路里,各种信号都可以是电信号来进行交换。功率放大器主要是放大电路中流动信号,可以放大输入的电压或者电流。它的作用是放大各种传感器输出电信号&#…

架设好传奇登录器显示无法连接服务器,完美登录器使用常见问题解决办法

中国传奇网已经更新好了完美登陆器,已经可以下载了,完美登陆器是一款完全免费无限制的登陆器。在这里站长也推荐大家使用这个传奇登陆器。毕竟是免费的。 一.登录器域名绑定,生成登录器等问题解决办法! 问1.完美登录器绑定域名还是绑定服务器等&#xf…

数字孪生技术栈的应用场景的优点

技术栈是一个IT术语,本意是指某项工作需要掌握的一系列技能组合的统称。那么对于如今炙手可热的数字孪生技术而言,数字孪生技术栈都会包括哪些底层技能?它又是如何构成和运行的呢? 北京智汇云舟科技有限公司成立于2012年&#xff…

太卷了,这份Java性能调优手册仅上线1小时,竟被恶意封杀下架

在各大厂的面试中,性能优化的问题肯定不会缺席,这足以说明其重要性。今天给大家带来的便是由资深程序员葛一鸣老师写的《Java程序性能优化实战》,同样是没有开源版本,我会将领取方式放在文末 Java程序性能优化实战 我看过几篇讲…

[附源码]Python计算机毕业设计Django仓库管理系统

项目运行 环境配置: Pychram社区版 python3.7.7 Mysql5.7 HBuilderXlist pipNavicat11Djangonodejs。 项目技术: django python Vue 等等组成,B/S模式 pychram管理等等。 环境需要 1.运行环境:最好是python3.7.7,…

Rockwell EDI 855 采购订单确认报文详解

罗克韦尔自动化与国内12 家授权分销商,124 家认可的系统集成商,30多家亚太区的Encompass战略合作伙伴和全球战略联盟,共同为制造业企业提供广泛的世界一流的产品、解决方案与服务支持。 近期我们帮助客户成功与罗克韦尔Rockwell建立EDI连接&a…

自建云服务计费系统

自从Laxcus分布式操作系统正式开源两个月以来,可能是它一站式云计算平台属性和超大规模计算能力,给用户带来极大的便利,下载量一直持续增加,最近网站后台总是有用户在问,在Laxcus分布式操作系统的社区版本基础上&#…

地级市-空气流动系数数据-更新至2019(含10米风速、边界高度等)

1、数据来源:参考论文计算,详情请见指标说明 2、时间跨度:2000-2019年 3、区域范围:全国所有地级市 4、指标说明: 空气流动系数数据为环境经济学常用工具变量! 数据为复旦大学陈诗一和陈登科教授&…

vite+ts-4-ORM框架sequelize实现mysql操作

random recording 随心记录 What seems to us as bitter trials are often blessings in disguise. 看起来对我们痛苦的试炼,常常是伪装起来的好运。 使用ORM框架sequelize完成Mysql数据库操作 使用ts实现mysql配置/泛型重载 配置接口实现 创建src/config/DbConf…

【SQL Server + MySQL三】数据库设计【ER模型+UML模型+范式】 + 数据库安全性

极其感动!!!当时学数据库的时候,没白学!! 时隔很长时间回去看数据库的笔记都能看懂,每次都靠这份笔记巩固真的是语雀分享要花钱,要不一定把笔记给贴出来(;༎ຶД༎ຶ) ,除…

SFTP的基本定义、用途以及基本优势有哪些

文件传输协议允许用户通过Internet在远程系统之间传输数据。SFTP 就是这样一种协议,它为用户提供了一种安全的方式来发送和接收文件和文件夹,目前少数虚拟主机提供商会提供这项服务。本文将介绍SFTP的基本定义、用途和数据安全方面的优势。 SFTP(Secure …

[附源码]Python计算机毕业设计Django春晓学堂管理系统

项目运行 环境配置: Pychram社区版 python3.7.7 Mysql5.7 HBuilderXlist pipNavicat11Djangonodejs。 项目技术: django python Vue 等等组成,B/S模式 pychram管理等等。 环境需要 1.运行环境:最好是python3.7.7,…

[附源码]计算机毕业设计springboot海南琼旅旅游网

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

基于SSM的医院医疗管理系统的设计与实现

项目描述 临近学期结束,还是毕业设计,你还在做java程序网络编程,期末作业,老师的作业要求觉得大了吗?不知道毕业设计该怎么办?网页功能的数量是否太多?没有合适的类型或系统?等等。这里根据疫情当下,你想解决的问…