Pytorch深度学习笔记(九)加载数据集

news2024/11/20 0:49:01

目录

1.名词解释

2. 数据集加载器Dataloader

3.完整代码


推荐课程:08.加载数据集_哔哩哔哩_bilibili

1.名词解释

名词解释:Epoch,Batch,Batch-Size,Iterations

Epoch(周期):指所有的训练样本都进行一次前向和反向传播

Batch-Size(批量大小):batch进行一次前向和反向传播的样本数量

Iterations(迭代):完成一次epoch中batch的次数

2. 数据集加载器Dataloader

DataLoader是pytorch定义的数据集加载器
通过DataLoader设置mini_batch,设置DataLoader相关参数即可

DataLoader参数:

dataset:数据集

num_workers:需要几个并行的进程读取数据

batch_size:一次batch所需的样本数量

shuffle:是否打乱数据集顺序,打乱数据集有利于模型克服“鞍点问题”

只要数据集能支持索引和提供数据集长度,DataLoader就能对数据集生产batch。

处理数据集两种方法:1.数据集不够大,直接读进内存。2.数据集所占空间比较大,像图片、语音的数据集,将文件名读进内存,根据文件名加载问价。

代码实现(使用数据集加载器Dataloader加载diabetes数据集):

# 自定义数据集类
class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        # 返回nbarray多维数组
        xy = np.loadtxt(filepath, delimiter=',',dtype=np.float32)
        # shape函数读取矩阵的长度,比如shape[0]就是读取矩阵第一维度的长度。
        self.len = xy.shape[0]
        # 数组切片,x[start:stop:step],step为负值表示为逆序。from_numpy函数,numpy转torch
        self.x_data = torch.from_numpy(xy[:,:-1])
        # [-1] 最后得到的是个矩阵
        self.y_data = torch.from_numpy(xy[:, [-1]])

    # 通过索引拿数据
    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    # 返回数据集长度
    def __len__(self):
        return self.len


dataset = DiabetesDataset('dataset/diabetes.csv')
# DataLoader是pytorch定义的数据集加载器
# dataset数据集,batch_size小批量所需的数据量,shuffle是否要打乱数据集,num_workers需要几个并行的进程读取数据
# 通过DataLoader设置mini_batch
train_loader = DataLoader(dataset=dataset,
                          batch_size=32,
                          shuffle=True,
                          num_workers=2)

3.完整代码

import numpy as np
import torch
# Dataset是一个抽象类,我们必须自定义数据集类并继承这个类
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

#……1.准备数据集……………………………………………………………………………………………………………………………#
# 自定义数据集类
class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        # 返回nbarray多维数组
        xy = np.loadtxt(filepath, delimiter=',',dtype=np.float32)
        # shape函数读取矩阵的长度,比如shape[0]就是读取矩阵第一维度的长度。
        self.len = xy.shape[0]
        # 数组切片,x[start:stop:step],step为负值表示为逆序。from_numpy函数,numpy转torch
        self.x_data = torch.from_numpy(xy[:,:-1])
        # [-1] 最后得到的是个矩阵
        self.y_data = torch.from_numpy(xy[:, [-1]])

    # 通过索引拿数据
    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    # 返回数据集长度
    def __len__(self):
        return self.len


dataset = DiabetesDataset('dataset/diabetes.csv')
# DataLoader是pytorch定义的数据集加载器
# dataset数据集,batch_size小批量所需的数据量,shuffle是否要打乱数据集,num_workers需要几个并行的进程读取数据
# 通过DataLoader设置mini_batch
train_loader = DataLoader(dataset=dataset,
                          batch_size=32,
                          shuffle=True,
                          num_workers=2)

#…2.设计模型………………………………………………………………………………………………………………………………………#
# 继承torch.nn.Module,定义自己的计算模块,neural network
class Model(torch.nn.Module):
    # 构造函数
    def __init__(self):
        # 调用父类构造
        super(Model, self).__init__()
        # 从8维降到6维再降到4维再降到1维
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()

    # 前馈函数
    def forward(self, x):
        # 调用self.sigmoid,并linear
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x

#……3.构造损失函数和优化器………………………………………………………………………………………………………#
# 实例化自定义模型,返回做logistic变化(也叫sigmoid)的预测值
model = Model()
# 实例化损失函数,返回损失值
criterion = torch.nn.BCELoss(size_average=True)
# 实例化优化器,优化权重w
# model.parameters(),取出模型中的参数,lr为学习率
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)


#……训练周期……………………………………………………………………………………………………………………………………………#
if __name__ == '__main__':
    for epoch in range(100):
        # 迭代train_loader
        # 根据自定义数据集类返回的data包含(x_data,y_data),enumerate能够获取是第几次迭代
        for i, data in enumerate(train_loader, 0):
            # 1.准备数据
            inputs, labels = data
            # 2.正向传播
            y_pred = model(inputs)
            loss = criterion(y_pred, labels)
            print(epoch, i, loss.item())
            # 3.反向传播
            optimizer.zero_grad()
            loss.backward()
            # 4.更新权重w
            optimizer.step()

练习:构造一个分类模型使用titanic数据集

Titanic - Machine Learning from Disaster | Kaggle

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/454785.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

职工管理系统(C++)

职工管理系统有以下8个功能: - 增加职工信息:实现批量添加职工功能,将信息录入到文件中,职工信息为:职工编号、姓名、部门编号 - 显示职工信息:显示公司内部所有职工的信息 - 删除离职职工:按照…

java防止重复提交的方法

为了防止重复提交,可以采用以下几种方法: 1. 令牌机制(Token) 在表单中添加一个隐藏字段,用于存放一个随机生成的令牌(Token)。当用户提交表单时,将令牌一起提交到服务器。服务器接收…

Win10系统重装过程(一键装机)

相信不少小伙伴都有刷机重装系统的过程,那种镜像,up盘,压缩包等多个复杂过程也折磨的大伙不堪重负,因此本期带来简易版一键装机相应操作。 下载地址: 小心点击下方链接,点击即下载(3.66GB&…

SAM:图像分割的里程碑

Facebook的新模型称为SAM或Segment Anything Model,具有在计算机视觉行业中引起积极变革的潜力。这个突破性模型不同于以前使用的任何其他图像分割模型。 传统上,会为不同类型的图像,如人或汽车,分别训练不同的模型,但…

成功上岸国防科大!

Datawhale干货 作者:王洲烽,太原理工大学,Datawhale成员 写在前面 相比较于一般的经验贴,我更想在这里讲述一下自己的故事。我一开始报考的是北理工,但很遗憾9月份北理改考408了,无缘京爷,所以…

路径规划 | 图解概率路图PRM原理及其参数分析

目录 0 专栏介绍1 基于采样的规划算法2 概率路图基本原理3 PRM算法流程4 PRM参数分析4.1 采样点数4.2 阈值 d max ⁡ \mathrm{d}_{\max} dmax​ 0 专栏介绍 🔥附C/Python/Matlab全套代码🔥课程设计、毕业设计、创新竞赛必备!详细介绍全局规划…

nginx简单介绍

文章目录 1. 下载并解压2. 80端口被占用,更改nginx默认的监听端口3. 访问nginx4. 在linux上安装nginx5. nginx常用命令6. nginx.conf 1. 下载并解压 官网下载 2. 80端口被占用,更改nginx默认的监听端口 更改conf/nginx.conf文件 3. 访问nginx ht…

[译] 实战 React 18 中的 Suspense

> 原文:https://dev.to/darkmavis1980/a-practical-example-of-suspense-in-react-18-3lln React 18 带来了很多变化,它不会破坏你已经编写过的代码,并且有很多改进和一些新概念。 它也让很多开发人员,包括我,意识到…

vue---mixin混入

一个混入对象可以包含任意组件选项(如data、methods、created、mounted等等)。当组件使用混入对象时,所有混入对象的选项将被“混合”进入该组件本身的选项。 我们可以使用混入,向组件注入自定义的行为。 和组件注册和指令一样 vu…

Oracle Linux 9 上基于 CRI-O 安装 Kubernetes 1.27 集群

Oracle Linux 9 上基于 CRI-O 安装 Kubernetes 1.27 集群 1. 禁用 swap2. 禁用防火墙3. 将 SELinux 设置为 permissive 模式4. 安装cri-o5. 安装kubelet kubeadm kubectl6. 更新模块设置7. 初始化Kubernetes集群8. 配置集群访问9. 安装网络插件10. 验证集群 1. 禁用 swap sudo…

docker容器原样迁移完整过程(nignx例子)

我们在测试服务器上,辛辛苦苦开发,各种配置好了服务,然后想着傻瓜式的迁移部署。接下来的就是干货了 过程描述: 为了体现一个完成性的描述,我们最初拉镜像开始,一直说到迁移后的服务正常运行。 接下来以ng…

centos7 查看服务器配置信息

1.linux查看版本当前操作系统发行信息 cat /etc/centos-release cat /etc/centos-release 2、查看内核版本uname -a或者cat /proc/version 3、查看CPU参数 1)、查看 CPU 物理个数   grep physical id /proc/cpuinfo | sort -u | wc -l 2)、查看 CPU …

如何降低小程序开发费用:从项目管理到技术选型

小程序的开发费用是许多企业和初创公司的瓶颈。在本文中,我们将介绍如何通过项目管理和技术选型来降低小程序开发费用,让您的企业更加高效。我们会详细阐述如何在项目管理中制定清晰的项目计划、与开发团队密切合作、采用敏捷开发方法。在技术选型方面&a…

jmeter压测结果分析

jmeter结果查看主要在结果树和聚合报告,实际在做压测过程中不做可视化操作,用命令行执行,再查看测试报告。 python在本地起服务 cmd打开命令框执行语句:python -m http.server 9090(端口号,可自定义&…

magento webapi 接口返回 json对象

前言 现在主流的项目开发都是前后端分离,数据通过json对象格式进行传输。但是magento框架,和传统PHP框架相比,区别很大。虽然也支持以RestApi的形式传输数据,但是要么格式并非是传统jsonObject要么就是需要大量的get、set方法。本…

TypeScript学习笔记以及学习中遇到的问题

本笔记是来自翻阅xcatliu的typeScript入门教程文档、TypeScript官方文档的部分摘录、以及观看B站学习视频进行笔记记录与知识点补充、本人实际使用时遇到的问题与解决记录、碎片化接触到相关知识点合并整理而成 仅供本人洪的学习使用 hello TypeScript 一、TypeScript安装 Ty…

【QT】如何检测目录或文件中的内容被修改,可以使用QFileSystemWatcher类进行检测

目录 1. QFileSystemWatcher类的介绍2. QFileSystemWatcher的公共函数2.1 构造函数2.2 析构函数2.3 添加监控的路径2.4 返回正在监控的目录或文件2.5 从文件系统监视程序中删除指定的路径 3. QFileSystemWatcher的信号4. 测试代码4.1 操作步骤4.2 MainWindow.h4.3 MainWindow.c…

Spring之Bean的配置与实例

Spring之Bean的配置与实例 一、Bean的基础配置1. Bean基础配置【重点】配置说明代码演示运行结果 2. Bean别名配置配置说明代码演示打印结果 3. Bean作用范围配置【重点】配置说明代码演示打印结果 二、Bean的实例化1. Bean是如何创建的2. 实例化Bean的三种方式2.1 构造方法方式…

数据库系统-数据库查询实现算法之

文章目录 一、一趟扫描算法1.1 算法概述1.2 算法逻辑&物理实现1.2.1 逻辑层面1.2.2 物理层面1.2.2.1 P11.2.2.2 P21.2.2.3 P31.2.2.4 P4 1.3 迭代器构造查询实现算法1.4 关系操作的一趟扫描算法1.4 基于索引的查询实现算法 二、两趟扫描算法2.1 两趟算法基本思想2.2 多路归…

SaaS是什么?企业为什么要有SaaS系统?

什么是SaaS系统?企业为什么要有SaaS系统? 近几年,SaaS突然变成了一个热门词汇,无论是一些权威报告,还是知乎上知友们热烈的讨论,对于Saas系统可谓是各有各的见解和看法。 今天就综合几位答主的观点&#…