08 Dataset and DataLoader (加载数据集)

news2024/9/19 10:52:58

本节内容是学习 刘二大人的《PyTorch深度学习实践》的 08 Dataset and DataLoader (加载数据集)。

上节内容:07 Multiple Dimension Input (处理多维特征的输入)-CSDN博客

这节内容:

目录

一、Epoch,Batch-Size,Iterations

二、DataLoader:batch_size=2, shuffle=True

三、How to define your Dataset ?

四、Diabetes Dataset

五、Using DataLoader

六、outputs

七、torchvision.datasets


一、Epoch,Batch-Size,Iterations

先看这两行代码:

# Training cycle
for epoch in range(training_epochs):
    # Loop over all batches
    for i in range(total_batch):

这是一个嵌套循环,每一次循环就是一个 epoch ,每一次 epoch 中,我们再执行一次循环,这次循环,每一次迭代,我们都会执行一次 Mini-Batch 。

Epoch:One forward pass and one backward pass of all the training examples. 也就是说,当所有的训练样本,都进行了一次正向的前向传播和反馈传播。

Bach-Size:The number of training examples in one forward backfard pass. 也即是说,我们每次训练的时候所用的样本数量。

Iterations:Number of passes, each pass using [ batch-size ] number of examples. 也就是说,内层循环了多少次。

假设你有10000个样本,一个 batch-size 内部有1000个样本,那么 Iterations 就是10.

二、DataLoader:batch_size=2, shuffle=True

接下来取个例子,我们让下面的这些数据 batch_size=2, 在此之前我们要将其打乱,也就是说进行 shuffle 。

三、How to define your Dataset ?

import torch
from torch.utils.data import Dataset,Dataloader  # 从utils工具箱里面引入Dataset,Dataloader 
                                                 # Dataset是个抽象的类,是不能实例化的
                                                 # Dataloader用来加载数据,可以实例化

class DiabetesDataset(Dataset):  # 自定义一个类DiabetesDataset,继承Dataset的一些基本功能
    def __init__(self):  # 初始化
        pass
    def __getitem__(self, index):  # 可以通过一个索引进行操作
        pass
    def __len__(self):
        pass

dataset = DiabetesDataset()  # 用自定义的类DiabetesDataset,把它实例化成一个数据对象
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=2)
               # DataLoader加载器,一般情况下在初始化的时候传递四个量,第一个就是数据dataset,第二个是一个小批量batch_zise的容量,第三个是是否打乱数据,一般训练的时候,我们选择shuffle=True打乱,第四个是读取mini-batch的是时候需要几个并行进程读取数据        

extra: num_workers in Windows

如果我们直接用下面的这样的代码,也就是说直接用 DataLoader 去训练,那么将会报错。

报错原因:

(稍微了解一下原因就行了叭,只要知道正确的写法应该就可以了叭)

解决方法:

我们需要对其进行封装,可以用 if 语句或者函数进行封装,如下:

四、Diabetes Dataset

我们看一下糖尿病这个数据集的实现:

class DiabetesDataset(Dataset):
    def __init__(self, filepath):  # filepath文件地址
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)  #将文件加载进来,内部数据用逗号作为分隔符,读取一个32位的浮点数
 
        self.len = xy.shape[0]  # xy是个n行9列的矩阵,n是数据样本的数量。那么它的尺寸形状就是(N,9),所以xy.shape[0]是指的取出第0元素,也就是是说把N的值取出来

        self.x_data = torch.from_numpy(xy[:, :-1])  # 取前八列
        self.y_data = torch.from_numpy(xy[:, [-1]])  # 取最后一列

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len

dataset = DiabetesDataset('diabetes.csv.gz')
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=2)

五、Using DataLoader

for epoch in range(100):
    for i, data in enumerate(train_loader, 0):  # enumerate是为得到第几次迭代,(x,y)放在data中
        # 1.Prepare data
        inputs, labels = data  # inputs就是x,labels就是y
        # 2.Forward
        y_pred = model(inputs)
        loss = criterion(y_pred, labels)
        print(eopoch, i, loss.item())
        # 3.Backward
        optimizer.zero_grad()  # 优化器权重清零
        loss.backward()
        # 4.Update
        optimizer.step()  # 优化更新

六、outputs

import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader


class DiabetesDataset(Dataset):
    def __init__(self, filepath):  # filepath文件地址
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)  #将文件加载进来,内部数据用逗号作为分隔符,读取一个32位的浮点数
 
        self.len = xy.shape[0]  # xy是个n行9列的矩阵,n是数据样本的数量。那么它的尺寸形状就是(N,9),所以xy.shape[0]是指的取出第0元素,也就是是说把N的值取出来

        self.x_data = torch.from_numpy(xy[:, :-1])  # 取前八列
        self.y_data = torch.from_numpy(xy[:, [-1]])  # 取最后一列

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len

dataset = DiabetesDataset('D:\\STUDY1\\LIUER\\08加载数据集\\diabetes08.csv')
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=0)


class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x


model = Model()

# 损失函数
criterion = torch.nn.MSELoss(reduction='mean')
# 优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for epoch in range(100):
    for i, data in enumerate(train_loader, 0):  # enumerate是为得到第几次迭代,(x,y)放在data中
        # 1.Prepare data
        inputs, labels = data  # inputs就是x,labels就是y
        # 2.Forward
        y_pred = model(inputs)
        loss = criterion(y_pred, labels)
        print(epoch, i, loss.item())
        # 3.Backward
        optimizer.zero_grad()  # 优化器权重清零
        loss.backward()
        # 4.Update
        optimizer.step()  # 优化更新

输出结果:

注意:在这个完整的代码里面,num_workers=0, 是单线程。

但我们这节课教的是多线程,所以记得在训练前加上 if __name__='main':

七、torchvision.datasets

torchvision内置了好多数据集(MNIST、Fashion-MNIST、EMNIST、COCO、LSUN、ImageFolder、DatasetFolder、Imageenet-12、CIFAR、STL10,都可以用下面这种结构,以 ImageFolder 为例:

example: MNIST


下节内容:09多分类问题

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1809167.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【quarks系列】基于Dockerfile构建native镜像

目录 Dockerfile构建代码测试 Dockerfile FROM quay.io/quarkus/ubi-quarkus-native-image:22.3-java11 AS buildWORKDIR /workspace COPY . .RUN ./mvnw -DskipTeststrue clean package -Dnative -U# Stage 2: Create the minimal runtime image FROM registry.access.redhat…

linux的持续性学习

安装php 第一步:配置yum源 第二步:下载php。 yum install php php-gd php-fpm php-mysql -y 第三步:启动php。 systemctl start php-fpm 第四步:检查php是否启动 lsof -i :9000 计划任务 作用&am…

ollama webui 11434 connection refused

报错:host.docker.internal:11434 ssl:default [Connection refused] 将/etc/systemd/system/ollama.service中加上如下红框两行 然后 systemctl daemon-reload systemctl restart ollama然后删掉之前的container。 最后 sudo docker run -d -p 4000:8080 --add-…

初识volatile

volatile:可见性、不能保证原子性(数据不安全)、禁止指令重排 可见性:多线程修改共享内存的变量的时候,修改后会通知其他线程修改后的值,此时其他线程可以读取到修改后变量的值。 指令重排:源代码的代码顺序与编译后字…

十大排序

本文将以「 通俗易懂」的方式来描述排序的基本实现。 🧑‍💻阅读本文前,需要一点点编程基础和一点点数据结构知识 本文的所有代码以cpp实现 文章目录 排序的定义 插入排序 ⭐ 🧐算法描述 💖具体实现 &#x1f…

一文了解SpringBoot

1 springboot介绍 1)springboot是什么? Spring Boot是一个用于简化Java应用程序开发的框架。它基于Spring框架,继承了Spring框架原有的优秀特性,比如IOC、AOP等, 他并不是用来代替Spring的解决方案,而是和Spring框架紧密结合,进一步简化了Spring应用的整个搭建和开发过程…

操作系统真象还原:内存管理系统

第8章-内存管理系统 这是一个网站有所有小节的代码实现,同时也包含了Bochs等文件 8.1 Makefile简介 8.1.1 Makefile是什么 8.1.2 makefile基本语法 make 给咱们提供了方法,可以在命令之前加个字符’@’,这样就不会输出命令本身…

网络分析(ArcPy)

一.前言 GIS中的网络分析最重要的便是纠正拓扑关系,建立矫正好的网络数据集,再进行网络分析,一般大家都是鼠标在arcgis上点点点,今天说一下Arcpy来解决的方案,对python的要求并不高,具体api参数查询arcgis帮助文档即可…

Java_Map集合

认识Map集合 Map集合称为双列集合,格式:{key1value,key2value2,key3value3,…},一次需要存一对数据作为一个元素。 Map集合的每个元素“Keyvalue” 称为一个键值对/键值对对象/一个Entry对象,Map集合也被叫做“键值对集合” Map集…

Simscape Multibody与RigidBodyTree:机器人建模

RigidBodyTree:主要用于表示机器人刚体结构的动力学模型,重点关注机器人的几何结构、质量和力矩,以及它们如何随时间变化。它通常用于计算机器人的运动和受力情况。Simscape Multibody:作为Simscape的一个子模块,专门用…

10.2 Go Channel

💝💝💝欢迎莅临我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:「stormsha的主页」…

虚拟机调用摄像头设备一直 select timeout问题的解决

在VMware里面调用v4l2-ctl捕获图像,或者opencv的VideoCapture(0)捕获图像,或者直接调用v4l2的函数,在streamon后,调用select读取数据,均会一直提示select timeout的问题,大概率是由于USB版本的兼容性造成的…

【氵】Archlinux+KDE Plasma 6+Wayland 安装nvidia驱动 / 开启HDR

参考: NVIDIA - Arch Linux 中文维基 (其实就是把 wiki 简化了一下 注:本教程适用 GeForce 930 起、10 系至 20 系、 Quadro / Tesla / Tegra K-系列以及更新的显卡(NV110 以及更新的显卡家族),此处以 RTX3060 为例 …

PHP 寿光蔬菜大棚宣传平台-计算机毕业设计源码88288

摘 要 随着科学技术的飞速发展,各行各业都在努力与现代先进技术接轨,通过科技手段提高自身的优势;对于寿光蔬菜大棚宣传平台当然也不能排除在外,随着网络技术的不断成熟,带动了寿光蔬菜大棚宣传平台,它彻底…

连续状态方程的离散化例子

连续状态方程的离散化 在控制系统中,连续状态方程的离散化是一个重要的步骤,用于将连续时间系统转换为离散时间系统,以便在数字控制器中实现。这通常涉及将连续时间的微分方程转换为离散时间的差分方程。常用的离散化方法 前向欧拉法(Forward Euler)简单易实现,但精度较…

详解python中的pandas.read_csv()函数

😎 作者介绍:我是程序员洲洲,一个热爱写作的非著名程序员。CSDN全栈优质领域创作者、华为云博客社区云享专家、阿里云博客社区专家博主。 🤓 同时欢迎大家关注其他专栏,我将分享Web前后端开发、人工智能、机器学习、深…

OpenGL绘制简单图形

绘制了一个紫色矩形和一个三角形&#xff0c;代码如下&#xff1a; #include <Windows.h> #include <gl/glut.h> void display(void) {glClearColor(0.0f, 0.0f, 0.0f, 1.0f); //设置清屏颜色glClear(GL_COLOR_BUFFER_BIT); //刷新颜色缓冲区&#xff1b;glColor3f…

“程序员职业素养全解析:技能、态度与价值观的融合“

文章目录 每日一句正能量前言专业精神专业精神的重要性技术执着追求的故事结论 沟通能力沟通能力的重要性团队合作意识实际工作中的沟通案例结论 持续学习持续学习的重要性学习方法进步经验结论 后记 每日一句正能量 梦不是为想象&#xff0c;而是让我们继续前往。 前言 在数字…

Policy-Based Reinforcement Learning(1)

之前提到过Discount Return&#xff1a; Action-value Function &#xff1a; State-value Function: &#xff08;这里将action A积分掉&#xff09;这里如果策略函数很好&#xff0c;就会很大&#xff1b;反之策略函数不好&#xff0c;就会很小。 对于离散类型&#xff1a; …

Qt中解决编译中文乱码和编译失败的问题

解决方法 1.使用#pragma execution_character_set(“utf-8”) QT5中在cpp中使用#pragma execution_character_set(“utf-8”)解决中文乱码&#xff0c;不过这里要求该源代码必须保存成带Bom的utf-8格式&#xff0c;这也是有些在网上下载的代码&#xff0c;加上这句源代码后还…