pytorch学习------数据集的使用方式

news2025/1/12 23:41:31

一、前言

在深度学习中,数据量通常是都非常多,非常大的,如此大量的数据,不可能一次性的在模型中进行向前的计算和反向传播,经常我们会对整个数据进行随机的打乱顺序,把数据处理成一个个的batch,同时还会对数据进行预处理。

所以,接下来我们来学习pytorch中的数据加载的方法。

二、数据集类

2.1、Dataset基类

在torch中提供了数据集的基类torch.utils.data.Dataset,继承这个基类,我们能够非常快速的实现对数据的加载。
torch.utils.data.Dataset的源码如下:

class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

我们需要在自定义的数据集类中继承Dataset类,同时还需要实现两个方法:

  1. __len__方法,能够实现通过全局的len()方法获取其中的元素个数
  2. __getitem__方法,能够通过传入索引的方式获取数据,例如通过dataset[i]获取其中的第i条数据
  3. __add__方法不用实现,它是将多条数据合并
2.2、例子

下面通过一个例子来看看如何使用Dataset来加载数据

数据来源:http://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection

数据介绍:SMS Spam Collection是用于骚扰短信识别的经典数据集,完全来自真实短信内容,包括4831条正常短信和747条骚扰短信。正常短信和骚扰短信保存在一个文本文件中。 每行完整记录一条短信内容,每行开头通过ham和spam标识正常短信和骚扰短信

数据实例:
在这里插入图片描述
代码如下:

from torch.utils.data import Dataset

data_path = r"D:\djangoProject\practice\SMSSpamCollection"

#定义数据集类
class MyDataset(Dataset):   #继承Dataset类
    def __init__(self):
        self.lines = open(data_path,encoding='utf-8').readlines()
    def __getitem__(self, index):
        #获取索引对应位置的一条数据
        #将标签和文本分开
        cur_line = self.lines[index].strip()
        label = cur_line[:4].strip()    #strip()为了去点换行符
        content = cur_line[4:].strip()

        return  label,content   #返回元组的形式

    def __len__(self):
        #返回数据总量
        return len(self.lines)


if __name__ == '__main__':
    my_data = MyDataset()
    print((my_data[0]))
    print(len(my_data))

效果如下:
在这里插入图片描述

三、数据加载器

使用上述的方法能够进行数据的读取,但是其中还有很多内容没有实现:

  • 批处理数据(Batching the data)
  • 打乱数据(Shuffling the data)
  • 使用多线程 multiprocessing 并行加载数据。

在pytorch中torch.utils.data.DataLoader提供了上述的所用方法

DataLoader的使用方法示例:
我们将上述代码进行修改

from torch.utils.data import Dataset,DataLoader

data_path = r"D:\djangoProject\practice\SMSSpamCollection"

#定义数据集类
class MyDataset(Dataset):   #继承Dataset类
    def __init__(self):
        self.lines = open(data_path,encoding='utf-8').readlines()
    def __getitem__(self, index):
        #获取索引对应位置的一条数据
        #将标签和文本分开
        cur_line = self.lines[index].strip()
        label = cur_line[:4].strip()    #strip()为了去点换行符
        content = cur_line[4:].strip()

        return  label,content   #返回元组的形式

    def __len__(self):
        #返回数据总量
        return len(self.lines)

my_data = MyDataset()
data_load = DataLoader(dataset=my_data,batch_size=2,shuffle=True,num_workers=2)   #使用数据加载器

if __name__ == '__main__':
    #两次的数据都不一样   是因为shuffle的原因,打乱了数据的顺序
    for i in data_load:
        print(i)
        break

    for i in data_load:
        print(i)
        break

其中参数含义:

  1. dataset:提前定义的dataset的实例
  2. batch_size:传入数据的batch的大小,常用128,256等等
  3. shuffle:bool类型,表示是否在每次获取数据的时候提前打乱数据
  4. num_workers:加载数据的线程数

效果如下:
在这里插入图片描述
两次索引一样,单打印的数据是不一样的,因为使用shuffle打乱了数据,且每个元组的大小为2,是batch_size为2的原因

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1035116.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【力扣1464】数组中两元素的最大乘积

👑专栏内容:力扣刷题⛪个人主页:子夜的星的主页💕座右铭:前路未远,步履不停 目录 一、题目描述二、题目分析1、排序2、最值模拟 一、题目描述 题目链接:数组中两元素的最大乘积 给你一个整数数…

针对 SAP 的增强现实技术

增强现实技术是对现实世界的一种交互式模拟。这种功能受到各种企业和制造商的欢迎,因为它可以减少生产停机时间、快速发现问题并维护流程,从而提高运营效率。许多安卓应用都在探索增强现实技术。 使用增强现实技术(AR)的Liquid U…

【Vue入门】语法 —— 事件处理器、自定义组件、组件通信

目录 一、事件处理器 1.1 样式绑定 1.2 事件修饰符 1.3 按键修饰符 1.4 常用控制符 1.4.1 常用字符综合案例 1.4.2 修饰符 二、自定义组件 2.1 组件介绍及定义 2.2 组件通信 2.2.1 组件传参(父 -> 子) 2.2.1 组件传参(子 ->…

【Leuanghing】ANSA入门——GUI界面介绍

【Leuanghing】ANSA入门——软件介绍 大家好!今天为大家推荐一款软件 —— ANSA。ANSA是目前公认的全球最快捷的CAE前处理软件之一,也是一个功能强大的通用CAE前处理软件。 一、相关介绍 BETA CAE Systems S.A公司总部位地希腊的赛萨罗尼奇市(Thessalon…

水平基准和垂直基准

水平基准 针对不同的地理空间测量,水平基准可以是一个参考点或一个参考点集。在地球上,相同的位置上的点可能有不同的坐标,取决于使用了哪种基准。 水平基准用于测量地球上的位置(position)。因为地球不是一个完美的球…

linux————zabbix搭建

目录 一、zabbix的概述 二、构成 一、server 二、web页面 三、数据库 四、proxy 五、agent 三、zabbix监控对象 四、zabbix的常用术语 五、zabbix监控框架 一、zabbix——client架构 二、zabbix_proxy_client架构 六、zabbix部署 安装zabbix5.0存储库 ​编辑​编…

flume安装及实战

flume官方下载地址:Welcome to Apache Flume — Apache Flume 一、flume安装 (1)解压至安装目录 tar -zxf ./apache-flume-1.9.0-bin.tar.gz -C /opt/soft/ (2)配置文件flume-env.sh cd /opt/soft/flume190/conf …

Visual Studio 功能增强:CMake 目标视图

Visual Studio 中的 CMake 目标视图,允许你按 CMake 目标可视化 CMake 项目结构,并生成指定的目标库和可执行文件。 为了使此视图更易于使用,我们实施了一些新的改进,使导航 CMake 目标比以往任何时候都更容易。这包括改进了到 C…

网络安全——黑客(自学)

想自学网络安全(黑客技术)首先你得了解什么是网络安全!什么是黑客!!! 网络安全可以基于攻击和防御视角来分类,我们经常听到的 “红队”、“渗透测试” 等就是研究攻击技术,而“蓝队…

通过人才测评系统来帮助Team Leader组建团队

作为HR除了对外招聘、培训、员工关系以外,如果有部门立项需要进行团队建设,也需要HR向领导推送相关简历,看似简单的过程,实际上也是困难重重,因为大多团建需求 都不是那么详细,然后挑简历的时候&#xff0c…

9.22数电(触发器寄存器一些电路分析reg的思考)

用作存储元件的电路 新输入的信号可能使电路保持同样的状态也可能使电路进入另一种新的状态 输入信号置位与复位可以用于改变构成存储元件的电路状态 RS锁存器 通过或非门就是说输入信号中有一个是1,输出就的是0 在RS0时,RS对或非门的结果无影响&am…

【多态】为什么析构函数的名称统一处理为destructor?

析构函数的名称统一处理为destructor的目的是为了解决析构函数的重写。 而这又引出了一个问题&#xff1a;为什么要进行析构函数的重写&#xff1f; 是为了下面这种情况&#xff1a; class Person { public:~Person() { cout << "~Person" << endl; } }…

java:java.util.MissingResourceException: Cant find bundle for base name解决方式

java&#xff1a;java.util.MissingResourceException: Cant find bundle for base name解决方式 1 前言 代码执行如下&#xff1a; ResourceBundle.getBundle("res.Message",Locale.getDefault(), ReadMyProps.class.getClassLoader());或 ResourceBundle.getBu…

WPF Panel笔记

1、StackPanel默认垂直排列&#xff0c;不会自动换行&#xff0c;展示可能不全。改变内元素的排列方式&#xff0c;需要用orientation属性 2、wrapPanel默认横向排列&#xff0c;会自动换行。改变内元素的排列方式&#xff0c;需要用orientation属性。 3、DockPanel默认属性&am…

获取文件上次访问时间

版权声明 本文原创作者&#xff1a;谷哥的小弟作者博客地址&#xff1a;http://blog.csdn.net/lfdfhl Java源码 public void testGetFileTime() {try {String string "E://test.txt";File file new File(string);Path path file.toPath();BasicFileAttributes ba…

07_ElasticSearch:倒排序索引与分词Analysis

07_ElasticSearch&#xff1a;倒排序索引与分词Analysis 一、 倒排索引是什么&#xff1f;1.1 通过示例&#xff0c;简单理解下1.2 核心组成 二、倒排索引是怎么工作的&#xff1f;2.1 创建倒排索引2.2 倒排索引搜索 三、Analysis 进行分词3.1 Analyzer 由三部分组成3.2 Analyz…

ARP协议-介于数据链路层和网络层之间的协议

通过上一篇 IP协议 我们知道 目标IP目标网络 目标主机 &#x1f64b;‍ 也就是说 必须知道 接收方的接收方的 MAC地址 > 没有MAC地址无法封装 MAC帧 在网络层&#xff0c;我们可以知道目标主机的 IP 地址&#xff0c; 但是 我们不知道对方的MAC地址 。 在同一个网段&…

Win10安装Docker Desktop并运行Tutorial示例

背景 前段时间一个项目需要在开发环境直接使用 Docker &#xff0c;为了省事便计划在本地安装 Desktop 版的 Docker 。其实安装过程比较简单&#xff0c;可视化安装即可&#xff0c;主要是对安装与初步使用时遇到的问题做个记录。 下载安装 下载地址&#xff1a;https://dow…

体验facechain

安装Anaconda 下载页面&#xff1a;https://www.anaconda.com/download/ Clone代码仓 git clone https://github.com/modelscope/facechain.git --depth 1 GIT_LFS_SKIP_SMUDGE1安装依赖 cd ./facechain pip install -r requirements.txt pip install -U openmim 运行 修改…

【C++代码】找树左下角的值,路径总和,从中序与后序遍历序列构造二叉树,从前序与中序遍历序列构造二叉树--代码随想录

题目&#xff1a;找树左下角的值 给定一个二叉树的 根节点 root&#xff0c;请找出该二叉树的 最底层 最左边 节点的值。假设二叉树中至少有一个节点。 题解 使用 height 记录遍历到的节点的高度&#xff0c;curVal 记录高度在 curHeight 的最左节点的值。在深度优先搜索时&a…