【PyTorch深度学习实践】07_Dataset和Dataloader

news2024/11/17 21:37:36

文章目录

    • 1. Epoch,Iteration,Batch-Size
    • 2. Dataset 和 Dataloader
      • 2.1 Dataset
      • 2.2 Dataloader
        • 2.2.1 例子
        • 2.2.2 enumerate函数
    • 3. 完整代码

1. Epoch,Iteration,Batch-Size

参考博客

在这里插入图片描述

2. Dataset 和 Dataloader

参考博客

功能概览

在这里插入图片描述

2.1 Dataset

torch.utils.data.Dataset是一个抽象类,不可以实例化,但是可以通过构建这个抽象类的子类来创建数据集。

重要方法(且必须改写):

getitem__():传入指定的索引index后,该方法能够根据索引返回对应的单个样本及其对应的标签(以元组形式)
__len__():返回整个数据集的大小
此外,因为 Dataset 类中提供了 __add__() 方法,所以继承之后我们的数据集也会拥有此方法,从而合并数据集只需使用 + 运算即可。

代码接口

class MyDataset(Dataset):
    def __init__(self):   
        # 初始化数据集的存储路径
        # 载入数据集(转化为tensor格式)
        # ...
    
    def __getitem__(self, index):
        # 返回单个样本及其标签,后续batch由什么组成也是取决于这个是怎么设置的
        pass
    
    def __len__(self):
        # 返回整个数据集的大小
        pass

读取数据时有两个选择,一是把所有数据都加载进来(数据量较小时),另一个是定义一个列表,存放文件名,再用文件名去读文件内容,第二种留待以后实现(数据量较大时)

举例可以看参考博客。

2.2 Dataloader

绝大多数时候需要以 batch 的形式访问数据集。Dataloader 这个接口提供了这样的功能,它能够基于我们自定义的数据集将其转换成一个可迭代对象以便我们批量访问。

重要参数

在这里插入图片描述
代码示例:

train_loader = DataLoader(dataset= dataset, batch_size=32, shuffle=True, num_workers=2,drop_last=False)

这段代码可以创建一个可迭代对象

2.2.1 例子

例:数据集内容如下:

1 -14 -15
1 -1 -15
1 -11 -14
1 0 -2
0 -4 2
1 7 -2
1 -7 -17
0 9 12
0 5 -14
1 -13 13

dataloader设置:

dataloader = DataLoader(data, batch_size=3, shuffle=False, drop_last=False)

将创建的可迭代对象列表化:

list(dataloader)
# [[tensor([[-14., -15.],
#           [ -1., -15.],
#           [-11., -14.]], dtype=torch.float64),
#   tensor([1., 1., 1.], dtype=torch.float64)],
#  [tensor([[ 0., -2.],
#           [-4.,  2.],
#           [ 7., -2.]], dtype=torch.float64),
#   tensor([1., 0., 1.], dtype=torch.float64)],
#  [tensor([[ -7., -17.],
#           [  9.,  12.],
#           [  5., -14.]], dtype=torch.float64),
#   tensor([1., 0., 0.], dtype=torch.float64)],
#  [tensor([[-13.,  13.]], dtype=torch.float64),
#   tensor([1.], dtype=torch.float64)]]

可以看出,列表化后,每一个 batch 均以列表的形式存储。这说明我们可以通过 for 循环来遍历所有的 batch,具体做法如下:

for inputs, labels in dataloader:
    print(inputs, labels)
# tensor([[-14., -15.],
#         [ -1., -15.],
#         [-11., -14.]], dtype=torch.float64) tensor([1., 1., 1.], dtype=torch.float64)
# tensor([[ 0., -2.],
#         [-4.,  2.],
#         [ 7., -2.]], dtype=torch.float64) tensor([1., 0., 1.], dtype=torch.float64)
# tensor([[ -7., -17.],
#         [  9.,  12.],
#         [  5., -14.]], dtype=torch.float64) tensor([1., 0., 0.], dtype=torch.float64)
# tensor([[-13.,  13.]], dtype=torch.float64) tensor([1.], dtype=torch.float64)

2.2.2 enumerate函数

参考博客1
参考博客2

用于记录每个batch的索引(即 iteration)
在这里插入图片描述实例:(这里为了方便展示将 batch_size 设为了1):

dataloader = DataLoader(data, batch_size=1, shuffle=True, drop_last=True)
for batch_idx, (inputs, labels) in enumerate(dataloader):
    print(batch_idx, end=' ')
    print(inputs, labels)
# 0 tensor([[-4.,  2.]], dtype=torch.float64) tensor([0.], dtype=torch.float64)
# 1 tensor([[ -1., -15.]], dtype=torch.float64) tensor([1.], dtype=torch.float64)
# 2 tensor([[ 0., -2.]], dtype=torch.float64) tensor([1.], dtype=torch.float64)
# 3 tensor([[ 7., -2.]], dtype=torch.float64) tensor([1.], dtype=torch.float64)
# 4 tensor([[ 9., 12.]], dtype=torch.float64) tensor([0.], dtype=torch.float64)
# 5 tensor([[  5., -14.]], dtype=torch.float64) tensor([0.], dtype=torch.float64)
# 6 tensor([[-11., -14.]], dtype=torch.float64) tensor([1.], dtype=torch.float64)
# 7 tensor([[-14., -15.]], dtype=torch.float64) tensor([1.], dtype=torch.float64)
# 8 tensor([[ -7., -17.]], dtype=torch.float64) tensor([1.], dtype=torch.float64)
# 9 tensor([[-13.,  13.]], dtype=torch.float64) tensor([1.], dtype=torch.float64)

3. 完整代码

import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader

# 1. Dataset和Dataloader 准备数据集
class DiabetesDataset(Dataset):
    def __init__(self, filepath):   # 也可以删去filepath,把真正路径放到下面第一行代码的第一个参数
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0]      # 例如数据集是N行(N个样本),8+1列(8个特征,1个输出),shape就是一个元组,为(N,9),shape[0]就是N
        self.x_data = torch.from_numpy(xy[:, :-1])
        self.y_data = torch.from_numpy(xy[:, [-1]])

    def __getitem__(self, index):
       return self.x_data[index], self.y_data[index]

    def __len__(self):
        # 也可以 return self.len(x_data)
        return self.len


dataset = DiabetesDataset('diabetes.csv.gz')
train_loader = DataLoader(dataset= dataset, batch_size=32, shuffle=True, num_workers=2)

# 2.设计模型
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # 三个线性模型
        self.linear1 = torch.nn.Linear(8,6)
        self.linear2 = torch.nn.Linear(6,4)
        self.linear3 = torch.nn.Linear(4,1)
        self.sigmoid = torch.nn.Sigmoid()   # 可以构造Sigmoid,nn下的Sigmoid是一个模块,不是单纯的函数

    def forward(self,x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x

model = Model()

# 3.损失函数和优化器
# 还是二分类,直接用BCE损失即可
criterion = torch.nn.BCELoss(size_average=True)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# 4.训练过程
if __name__ == '__main__':
    for epoch in range(10):
         for i, data in enumerate(train_loader,0):
        #  1.准备数据
            inputs, labels = data
        #  2. 前馈
            y_pred = model(inputs)
            loss = criterion(y_pred, labels)
            print(epoch, i, loss.item())
        #  3. 反馈
            optimizer.zero_grad()
            loss.backward()
        # 4. 更新
            optimizer.step()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/160274.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2023年浙江建筑八大员(标准员)精选真题题库及答案

百分百题库提供建筑八大员(标准员)考试试题、建筑八大员(标准员)考试真题、建筑八大员(标准员)证考试题库等,提供在线做题刷题,在线模拟考试,助你考试轻松过关。 14.根据《施工现场临…

Electron Vue之间的通讯 自定义标题栏实现最小化全屏关闭功能

方便以后定制化使用,学习记录一下。 话不多说,先看看效果吧。 效果 版本 electron ^13.0.0 知识点 Vue 相互通讯 Electron 标题栏主要逻辑代码 新建public\preload.js文件,用于前端全局发送和监听消息。 const { contextBridge, ipcRen…

【Linux】Linux权限的理解

文章目录🎪 Linux权限的理解🚀1.shell命令及其运行原理🚀2.Linux权限概念⭐2.1 用户与root身份切换⭐2.2 用户与用户身份切换⭐2.3 单条指令提权🚀3.Linux文件权限⭐3.1 文件属性(第一个字符)⭐3.2 文件角色划分与文件属性⭐3.3 文…

线性代数第四章 向量组的线性相关性

向量组及其线性组合一.向量、向量组1.向量n个有次序的数a1,a2,...,an所组成的数组称为n维向量,这n个数称为该向量的n个分量,第i个数ai称为第i个分量n维向量可以写成一行,也可以写成一列,在没有指明是行向量还是列向量时&#xff0…

Authing 入选长城战略咨询《2022中国潜在独角兽企业研究报告》

12 月 23 日,长城战略咨询(GEI)发布《2022 中国潜在独角兽企业研究报告》(下称《报告》)。作为身份云行业领先的代表企业, Authing 凭借着过硬的技术实力和突出的创新能力,首次入选中国潜在独角…

软件测试工程师为什么要写测试用例?

软件测试工程师为什么要写测试用例?相信从事软件测试行业的从业者来讲,测试用例并不陌生。因为测试用例不仅仅是一组简单的文档,它包含前提条件、输入、执行条件和预期结果等等重要内容,并且能够完成一定的测试目的和需求。下面本…

深度学习(20)—— ConvNext 使用

深度学习(20)—— ConvNext 使用 本篇主要使用convnext做分类任务,其中使用convnext-tiny,其主要有5块 stage0stage1stage2stage3head 文章目录深度学习(20)—— ConvNext 使用Part 1 ModelPart 2 Traini…

【数据结构】一篇博客带你实现双向带头循环链表!!!(零基础小白也可以看懂)

目录 0.前言 1. 简述双向带头链表 2.双向带头循环链表的实现 2.1 设计双向带头循环链表结构体 2.2双向带头循环链表的初始化 2.3双向带头循环链表的尾插 2.4双向带头循环链表的尾删 2.5双向带头循环链表的头插 2.6双向带头循环链表的头删 2.7双向带头循环链表的插入 …

【面试题】notify() 和 notifyAll()方法的使用和区别

【面试题】notify() 和 notifyAll()方法的使用和区别 Java中notify和notifyAll的区别 何时在Java中使用notify和notifyAll? 【问】为什么wait()一定要放在循环中? Java中通知和notifyAll方法的示例 Java中通知和notify方法的示例 Java中notify和no…

22年我在CSDN做到了名利兼收

写在前面 hi朋友,我是几何心凉,感谢你能够点开这篇文章,看到这里我觉得我们是有缘分的,因着这份缘分,我希望你能够看完我的分享,因为下面的分享就是要汇报给你听的,这篇文章是在 2022 年 12 月 …

从0到1完成一个Vue后台管理项目(二十三、初代项目完成、已开源)

开源地址 项目地址 项目还在优化,会增加很多新功能,UI也会重新设计,已经在修改啦! 最近打算加一些组件、顺便分享一些好用的开源项目 现在正在做迁移到vue3TS的版本、预计年后会完事,然后迁移到vite、遇到的问题和报…

docker安装prometheus和grafana

docker安装prometheus和grafana docker安装prometheus和grafana 概念简述安装prometheus 第一步:确保安装有docker第二步:拉取镜像第三步:准备相关挂载目录及文件第四步:启动容器第五步:访问测试 安装grafana 第一步&…

分享66个ASP源码,总有一款适合您

ASP源码 分享66个ASP源码,总有一款适合您 66个ASP源码下载链接:https://pan.baidu.com/s/1Jf78pfAPaFo6QhHWWHEq0A?pwdwvtg 提取码:wvtg 下面是文件的名字,我放了一些图片,文章里不是所有的图主要是放不下...&…

Docker容器与镜像命令

文章目录帮助命令镜像命令容器命令其它命令命令总结帮助命令 显示 Docker 版本信息 docker version显示 Docker 系统信息,包括镜像和容器数 docker info 帮助 docker --help 镜像命令 列出本地主机上的镜像 docker images运行结果 REPOSITORY TAG …

Python采集彼岸4K高清壁纸

前言 嗨喽,大家好呀~这里是爱看美女的茜茜呐 又到了学Python时刻~ 环境使用: Python 3.8 解释器 Pycharm 编辑器 模块 import re import requests >>> pip install requests ( 更多资料、教程、文档点击此处跳转跳转文末名片加入君羊,找…

【Leetcode面试常见题目题解】5. 最长公共前缀

题目描述 本文是LC第14题&#xff0c;最长公共前缀&#xff0c;题目描述如下&#xff1a; 编写一个函数来查找字符串数组中的最长公共前缀。 如果不存在公共前缀&#xff0c;返回空字符串 “”。 限制 1 < strs.length < 200 0 < strs[i].length < 200 strs[i] 仅…

数据库 MySQL-window安装和卸载

安装 官网&#xff1a; MySQL :: Download MySQL Community Server 或 MySQL :: Download MySQL Community Server (Archived Versions) 文件目录简述 bin存放了可执行文件&#xff0c;docs是文档&#xff0c;include放的是c语言相关的.h文件&#xff0c;lib是c语言的库文件…

wmv是什么格式?如何录制wmv格式的视频?图文教学

很多小伙伴在使用文件的时候&#xff0c;经常会发现自己的一些文件后缀名是wmv。或者说在工作、学习的过程中&#xff0c;有过被要求使用wmv格式的文件。wmv是什么格式&#xff1f;如何录制wmv格式的视频&#xff1f;今天小编就来详细的跟大家说说。 一、wmv是什么格式&#xf…

SpringBoot复习(一)

底层注解 Configuration 自定义配置类 Bean: 可以通过Bean注解将方法的返回值交给ioc容器来管理 组件id为方法名&#xff0c;组件的类型就是方法的返回类型。 默认组件是单例的 Configuration: 告诉springboot这是一个配置类之前的配置文件 配置类本身也是组件&#xff0c;由s…

【Linux】Makefile/make - 快速理解入门

目录 一、概念理解 1、基本概念 2、举例说明 二、编写 Makefile 1、依赖关系和依赖方法 2、文件清理 3、扩展内容 一、概念理解 1、基本概念 在我们学习 Linux 的过程中&#xff0c;我们可以直接使用 gcc 指令对程序的文本文件逐个进行编译处理&#xff0c;这是因为我…