《PyTorch深度学习实践》第八讲 加载数据集

news2024/12/26 0:23:01

b站刘二大人《PyTorch深度学习实践》课程第八讲加载数据集笔记与代码:https://www.bilibili.com/video/BV1Y7411d7Ys?p=8&vd_source=b17f113d28933824d753a0915d5e3a90


Dataset用于构造数据集,该数据集能够支持索引

DataLoader用于从数据集中拿出一个mini-batch来用于训练


术语:

  • epoch:训练轮数
    • 所有的训练样本都进行了前向和反向传播的一个过程
    • 所有训练样本都进行了训练
  • Batch-Size:每轮训练进行mini-batch的次数
    • 每次训练的时候所用的样本数量
  • Iterations:batch分了多少个
    • 内层的batch一共执行了多少次
image-20230701153731997

外层表示训练周期,内层是对batch进行迭代

例如有1万个样本,batch是1千个,即batch-size = 1000,iterations=10


DataLoader

  • batch_size:指定batch大小
  • shuffle:打乱数据,增强随机性

数据集要能够支持索引,即DataLoader要能够访问到里面的每一个元素,同时要能够提供长度信息,以便于DataLoader对Dataset自动进行小批量的数据集省出

首先是随机打乱数据(Shuffle),接下去Loader会对打乱后的数据进行分组,做成可迭代的Loader

image-20230701154100760

代码实现Dataset和DataLoader

  • Dataset是一个抽象类,不能实例化对象,只能继承
  • DataLoader用于帮助我们加载数据,可以实例化
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


# DiabetesDataset类继承自Dataset
class DiabetesDataset(Dataset):
    def __init__(self):
        pass

    def __getitem__(self, index):   # 为了实例化之后能够支持下标操作
        pass

    def __len__(self):  # 获取数据条数
        pass


# 实例化DiabetesDataset类对象
dataset = DiabetesDataset()

# 初始化loader,传入数据集dataset,设置batch_size,是否需要打乱数据,num_worker用于读取的时候是否要用多线程(进程数)
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=2)

在windows系统下使用num_worker直接去训练会有一些问题

image-20230701160512323 image-20230701160528621

使用DataLoader

image-20230701162158613
# 训练过程
# 外层表示训练周期,例如epoch取50表示所有的数据要跑50次
for epoch in range(100):
    # 内层直接对train_loader进行迭代
    # 用enumerate是为了获取当前迭代次数i,data存储train_loader的数据x和标签y,元组形式
    for i, data in enumerate(train_loader, 0):
        # 1. prepare data
        inputs, labels = data   # inputs(x)和labels(y)都是张量
        # 2. forward
        y_pred = model(inputs)  # y_hat
        loss = criterion(y_pred, labels)
        print(epoch, loss.item())
        # 3. backward
        optimizer.zero_grad()  # 在反向传播开始将上一轮的梯度归零
        loss.backward()  # 反向传播(计算梯度)
        # 4. backward
        optimizer.step()        # 更新权重w和偏置b

完整代码

image-20230701162521947
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader


# DiabetesDataset类继承自Dataset
class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0]  # 取行数,获取数据集个数
        self.x_data = torch.from_numpy(xy[:, :-1])
        self.y_data = torch.from_numpy(xy[:, [-1]])

    def __getitem__(self, index):   # 为了实例化之后能够支持下标操作
        return self.x_data[index], self.y_data[index]   # 返回索引

    def __len__(self):  # 获取数据条数
        return self.len


# 实例化DiabetesDataset类对象
dataset = DiabetesDataset('dataset/diabetes.csv.gz')
# 初始化loader,传入数据集dataset,设置batch_size,是否需要打乱数据
train_loader = DataLoader(dataset=dataset,
                          batch_size=32,
                          shuffle=True,
                          num_workers=2)  # num_worker:读取的时候是否要用多线程(进程数)


# 定义模型
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x


model = Model()

# criterion = torch.nn.MSELoss(size_average=True) pytorch更新后被弃用了
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)


# 训练过程
# 外层表示训练周期,例如epoch取50表示所有的数据要跑50次
for epoch in range(100):
    # 内层直接对train_loader进行迭代
    # 用enumerate是为了获取当前迭代次数i,data存储train_loader的数据x和标签y,元组形式
    for i, data in enumerate(train_loader, 0):
        # 1. prepare data
        inputs, labels = data   # inputs(x)和labels(y)都是张量
        # 2. forward
        y_pred = model(inputs)  # y_hat
        loss = criterion(y_pred, labels)
        print(epoch, loss.item())
        # 3. backward
        optimizer.zero_grad()  # 在反向传播开始将上一轮的梯度归零
        loss.backward()  # 反向传播(计算梯度)
        # 4. backward
        optimizer.step()        # 更新权重w和偏置b

image-20230701162905118

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/707841.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

深入理解计算机系统(3)_计算机指令

深入理解计算机系统系列文章目录 第一章 计算机的基本组成 1. 内容概述 2. 计算机基本组成 第二章 计算机的指令和运算 3. 计算机指令 4. 程序的机器级表示 5. 计算机运算 6. 信息表示与处理 第三章 处理器设计 7. CPU 8. 其他处理器 第四章 存储器和IO系统 9. 存储器的层次…

金融基础知识(三):期权

1.认购期权与认沽期权 认购期权和认沽期权都是交易所常见的期权合约。 认购期权(Call Option)是一种给予持有人以在未来某个时间或特定事件发生时购买底层标的资产的权利。认购期权的持有人在行权日(Expiration Date)可以按照期…

B/S架构的C#云检验系统源码 实验室信息管理系统源码

科技的飞速发展为实验室信息管理带来了新机遇,云计算技术的应用更是为实验室信息管理打开了新的大门。云 LIS 实验室信息管理系统,作为一种新型的信息化管理方案,已经在多个实验室的信息化管理中得到应用,并且具有广阔的应用前景。…

Python3 命名空间和作用域 | 菜鸟教程(十七)

目录 一、命名空间 (一)简介 1、命名空间(Namespace)是从名称到对象的映射,大部分的命名空间都是通过 Python 字典来实现的。 2、命名空间提供了在项目中避免名字冲突的一种方法。 3、各个命名空间是独立的,没有任何关系的&a…

访问者模式(Vistor)

定义 访问者是一种行为设计模式,它能将算法与其所作用的对象隔离开来。 前言 1. 问题 假如你的团队开发了一款能够使用巨型图像中地理信息的应用程序。图像中的每个节点既能代表复杂实体(例如一座城市), 也能代表更精细的对象…

Nginx【Docker(安装Nginx、Nginx服务启停控制、全局块、events块、HTTP块)】(二)-全面详解(学习总结---从入门到深化)

目录 Docker安装Nginx Nginx服务启停控制 Nginx配置指令详解_全局块 Nginx配置指令详解_events块 Nginx配置指令详解_HTTP块 Docker安装Nginx 拉取官方的Nginx镜像 [rootlocalhost ~]# docker pull nginx 以下命令使用 Nginx 默认的配置来启动一个 Nginx 容器实例&#xf…

小驰私房菜_28_Qcom Camx相关名词

(Qcom 7325平台) CSID = Camera Serial Interface Decoder module IPE = Image Processing Engine IFE (x3) = Image Front End IFE_lite (x2) BPS = Bayer processing segment (for Snapshot) IPE = Image Processing Engine VPU = Video Processing Unit (CODEC) DP…

matplotlib布局模式

栅格布局 import matplotlib.pyplot as plt import numpy as np plt.figure("OBJ")x np.linspace(-np.pi, np.pi, 1000) cosy np.cos(x) siny np.sin(x) y x * 0.5 timesy x ** 2 # 创建九宫格 gs plt.GridSpec(3, 3) # 第0-1行,第2列 plt.subplot…

Eclipse中有用的快捷键

Eclipse中有的快捷键自己记不清楚,但用起来又很方便,遇到了就放在这边备忘。 【CtrlO】快速定位某个类中的属性、方法 有时候,一个类中的属性、方法比较多,想用快捷键快速查找,提升效率。 举例:我想查找…

MYSQL-聚合函数及分组查询

常用聚合函数 COUNT() 求有多少行 SUM() 求和 AVG() 求平均值 MIN() 求最小值 MAX() 求最大值 举个栗子 SELECT AVG(price) FROM products WHERE price_id > 10; 这行代码就是在求id大于10的价格的平均值 AVG(price)表示求price列的平均值 执行逻辑为 先由WHERE…

Mock在接口测试中的实际应用

关于Mock测试 01、含义和目的 1、 什么是mock测试? Mock 测试就是在测试过程中,对于某些不容易构造(如 HttpServletRequest 必须在Servlet 容器中才能构造出来)或者不容易获取的比较复杂的对象(如 JDBC 中的ResultSe…

chatgpt赋能python:下载完Python,如何进入编辑器

下载完Python,如何进入编辑器 Python是一门高级编程语言,具有简单易懂、易于学习、可拓展性强等特点,被广泛应用于Web应用、桌面应用、科学计算、人工智能等众多领域。如果你已经下载并安装了Python,那么接下来如何进入编辑器呢&…

uniapp智慧停车场系统微信小程序h5、APP源码 智能停车系统源码 安装搭建部署教程

【APP】: flutter(原生混合框架,不是web封装,原生应用,一套代码直接生成原生Android和ios应用),既不损失性能,也能降低开发成本 【小程序/h5/公众号】:uni-app(底层框架Vue) 【后台管理】:vue-e…

DeepSpeed-Chat 打造类ChatGPT全流程 笔记一

这篇文章主要是对DeepSpeed Chat的功能做了一些了解,然后翻译了几个主要的教程了解了一些使用细节。最后在手动复现opt-13b做actor模型,opt-350m做reward模型进行的chatbot全流程训练时,踩了一些坑也分享出来了。最后使用训练后的模型做servi…

计算机组成原理(课堂测验3次)

3、同步通信与异步通信的主要区别是什么,说明通信双方如何联络。 同步通信和异步通信的主要区别是:前者有公共时钟线,所有设备按统一的时序、同一的传输周期进行信息传输,通信双方按约定好的时序联络;后者没有公共时钟…

探秘直链网盘:高效传输、便捷分享的存储利器!

什么是直链网盘? 直链网盘是一种用于存储和共享文件的在线服务。它为用户提供了一个方便的方式来存储和访问他们的文件,而无需依赖本地存储设备。直链网盘的主要特点是它们可以生成直接下载链接,允许用户快速下载文件,而不需要进…

使用 Sigstore 签名的 Elastic Stack 容器镜像!

作者:Maxime Greau 软件供应链攻击不断增加。 这就是为什么这个主题是安全领导者的首要任务。 在这方面,这篇博文重点介绍了使用 Sigstore 对 Elastic Stack 容器镜像进行签名的新功能,以便: 保护 Elastic 软件供应链工作流程为…

java面试Day14

1.如何使用 Redis 实现一个排行榜? Redis实现排行榜是Redis中一个很常见的场景,主要使用的是ZSet进行实现,下面是为什么选用ZSet: 有序性:排行榜肯定需要实现一个排序的功能,在Redis中有序的数据结构有List…

Tauri:跨平台探索之旅

一、简介 Tauri 是一个跨平台 GUI 框架,与 Electron 的思想基本类似。都是属于跨平台技术的解决方案 优缺点快速分析 我们一般会把tauri作为 Electron 的替代方案,electron优点咱们不看,这里就提两个electron比较明显的问题: 安装…

高考志愿填报的个人看法,希望能对你有所启发

各省高考成绩已出,又到一年高考季。张雪峰提到:“普通家庭不要光谈理想,也要谈落地。”志愿怎样填报、选专业还是选学校、什么专业好就业、高考志愿主要看什么? 作为一名过来人,今天就站在小部分群体的角度来聊聊&…