2023.7.4 Dataloader切分

news2024/11/24 12:41:45

一、

如果文件夹路径是 path/to/folder with spaces/,使用以下方式输入

path/to/folder\ with\ spaces/

或者使用引号包裹路径

"path/to/folder with spaces/"

这样可以确保命令行正确解析文件夹路径,并将空格作为路径的一部分进行处理。

二、

DataLoader(object):

class DataLoader(object):
    def __next__(self):
        if self.num_workers == 0:  
            indices = next(self.sample_iter)  # Sampler
            batch = self.collate_fn([self.dataset[i] for i in indices]) # Dataset
            if self.pin_memory:
                batch = _utils.pin_memory.pin_memory_batch(batch)
            return batch
            
train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                                               num_workers=args.workers, pin_memory=True)
                                               

假设我们的数据是一组图像,每一张图像对应一个index,那么如果我们要读取数据就只需要对应的index即可,即上面代码中的indices,而选取index的方式有多种,有按顺序的,也有乱序的,所以这个工作需要Sampler完成.
下面一行。我们已经拿到了indices,那么下一步我们只需要根据index对数据进行读取即可了。
再下面的if语句的作用简单理解就是,如果pin_memory=True,那么Pytorch会采取一系列操作把数据拷贝到GPU,总之就是为了加速。
在这里插入图片描述
pytorch源码描述:

sampler (Sampler or Iterable, optional): defines the strategy to draw
samples from the dataset. Can be any Iterable with __len__
implemented. If specified, :attr:shuffle must not be specified.
对应代码:可见:所有的exception都可以从源码中找到

if sampler is not None and shuffle:
           raise ValueError('sampler option is mutually exclusive with '
                           'shuffle')

samplerbatch_sampler,都默认为None。前者的作用是生成一系列的index,而batch_sampler则是将sampler生成的indices打包分组,得到一个又一个batch的index。
在这里插入图片描述

若shuffle=True,则sampler=RandomSampler(dataset)
若shuffle=False,则sampler=SequentialSampler(dataset)Sequential Sampler产生的索引是顺序索引
所以当

data = list([17, 22, 3, 41, 8])
seq_sampler = sampler.SequentialSampler(data_source=data)
# ran_sampler = sampler.RandomSampler(data_source=data)
for index in seq_sampler: # sampler就是产生index的
    print("index: {}, data: {}".format(str(index), str(data[index])))
# SequentialSampler
index: 0, data: 17
index: 1, data: 22
index: 2, data: 3
index: 3, data: 41
index: 4, data: 8
# RandomSampler
index: 0, data: 17
index: 2, data: 3
index: 3, data: 41
index: 4, data: 8
index: 1, data: 22

SubsetRandomSampler:

SubsetRandomSampler(Sampler):
    Samples elements randomly from a given list of indices, without replacement.
    Arguments:
        indices (sequence): a sequence of indices
    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        return (self.indices[i] for i in torch.randperm(len(self.indices)))


    def __len__(self):
        return len(self.indices)

这个采样器常见的使用场景是将训练集划分成训练集和验证集,示例如下:(记住)

n_train = len(train_dataset)
split = n_train // 3
indices = random.shuffle(list(range(n_train)))
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:])
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split])
train_loader = DataLoader(..., sampler=train_sampler, ...)
valid_loader = DataLoader(..., sampler=valid_sampler, ...)

subsetRandomSampler应该用于训练集、测试集和验证集的划分,下面将data划分为train和val两个部分,iter()返回的的不是索引,而是索引对应的数据:
iter()返回的并不是随机数序列,而是通过随机数序列作为indices的索引,进而返回打乱的数据本身

sub_sampler_train = sampler.SubsetRandomSampler(indices=data[0:2])
sub_sampler_val = sampler.SubsetRandomSampler(indices=data[2:])
# Train的输出:
index: 17
index: 22
######
# Val的输出
index: 8
index: 41
index: 3

其实是没有SubSequentialSampler的:
ImportError: cannot import name 'SubsetSequentialSampler' from 'torch.utils.data.sampler'

最终代码:

用上面乱七八糟的:

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split

# 数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),  # 转为Tensor类型
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 数据归一化
])

# 加载完整的训练集
dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

#写一个子类SubsetSequentialSampler
train_index_sampler = SubsetRandomSampler
  class SubsetSequentialSampler(SubsetRandomSampler):
      def __iter__(self):
          return (self.indices[i] for i in torch.arange(len(self.indices)).int())
val_index_sampler = SubsetSequentialSampler

train_n = len(dataset)
indices = list(range(train_n))
val_size =  int(train_n * 0.2)
train_size = train_n - val_size
np.random.seed(10)
np.random.shuffle(indices)
train_idx, val_idx = indices[val_size:], indices[:val_size]
train_sampler = SubsetRandoomSampler(train_idx)
val_sampler = val_index_sampler(val_idx)

train_dataloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,num_workers=num_workers,sampler=train_sampler)
val_dataloader = torch.utils.data.DataLoader(validset, batch_size=batch_size,num_workers=num_workers,sampler=val_sampler)
test_dataloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,num_workers=num_workers)

干脆不用!!太麻烦了!!还是random_split爽:

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split

# 数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),  # 转为Tensor类型
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 数据归一化
])

# 加载完整的训练集
dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

# 将训练集分割为训练集和验证集
train_size = int(0.8 * len(dataset))
valid_size = len(dataset) - train_size
trainset, validset = random_split(dataset, [train_size, valid_size])

# 创建数据加载器
trainloader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
validloader = DataLoader(validset, batch_size=64, shuffle=False, num_workers=2)

# 加载测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/721525.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ADB自动化测试框架

一、介绍 adb的全称为Android Debug Bridge,就是起到调试桥的作用,利用adb工具的前提是在手机上打开usb调试,然后通过数据线连接电脑。在电脑上使用命令模式来操作手机:重启、进入recovery、进入fastboot、推送文件功能等。简单来…

Intellij IDEA 初学入门图文教程(八) —— IDEA 在提交代码时 Performing Code Analysis 卡死

在使用 IDEA 开发过程中,提交代码时常常会在碰到代码中的 JS 文件时卡死,进度框上显示 Performing Code Analysis,如图: 原因是 IDEA 工具默认提交代码时,分析代码功能是打开的,需要通过配置关闭下就可以了…

Linux高性能网络编程:TCP底层的收发过程

今天探索高性能网络编程,但是我觉得在谈系统API之前可以先讲一些Linux底层的收发包过程,如下这是一个简单的socket编程代码: int main() {... fd socket(AF_INET, SOCKET_STREAM, 0);bind(fd, ...);listen(fd, ...);// 如何建立连接...afd …

冒泡排序法(优化与实例演示)

冒泡排序法 冒泡排序法基本介绍 冒泡排序是一种简单而经典的排序算法,它的原理是通过不断比较相邻元素的大小并交换位置,将较大(或较小)的元素逐渐“冒泡”到数组的末尾。这个过程持续进行多轮,直到整个数组按照顺序…

【Zabbix 6.0 监控系统安装和部署】

目录 一、Zabbix 介绍1、zabbix 是什么?2、zabbix 监控原理(重点)3、Zabbix 6.0 新特性4、Zabbix 6.0 功能组件1、Zabbix Server2、数据库3、Web 界面4、Zabbix Agent5、Zabbix Proxy6、Java Gateway 二、Zabbix 6.0 部署1、部署 zabbix 服务…

idea goland 插件 struct to struct

go-struct-to-struct idea goland 插件。实现自动生成 struct 间 转换代码。 https://plugins.jetbrains.com/plugin/22196-struct-to-struct/ IntelliJ plugin that Automatically generate two struct transformations through function declarations Usage define func …

【怎么实现多组输入之EOF】

C语言怎么实现多组输入之EOF C语言之EOF介绍1、什么是EOF?2、EOF的用法3、EOF的扩展3.1、scanf返回值之EOF3.2、scanf函数的返回值有以下几种情况 4、如何是实现多组输入?4.1、多组输入---- 常规写法例程14.2、多组输入---- 实现多组输入的打印例程24.3、…

不想被卷的程序员们,应该学什么?

我真的好像感慨一下,这个世界真的给计算机应届生留活路了吗? 看着周围的同学,打算搞前端、JAVA、C、C的,一个两个去跑去应聘。你以为是00后整治职场? 真相是主打一个卑微:现阶段以学习为主(工…

探寻日本区块链游戏的未来潜力

日本的区块链游戏 日本是全球范围内游戏市场人均利润最高的国家之一。其中,《My Crypto Heroes》的首次公售金额达到了 16,000 ETH。 关键要点: 日本具有强大的游戏基础,使其成为加密游戏发展的理想地区。 日本流行的加密货币游戏包括《My…

Python中jsonpath库使用,及与xpath语法区别

jsonpath库使用 pip install jsonpath 基本语法 JSONPath语法元素和对应XPath元素的对比

Work20230705

//main.c #include "uart4.h" extern void printf(const char *fmt, ...); void delay_ms(int ms) {int i,j;for(i 0; i < ms;i)for (j 0; j < 1800; j); }int main() {while(1){//将获取到的字符1发送到终端//hal_put_char(hal_get_char()1);hal_put_string…

POSTGRESQL SQL 执行用 IN 还是 EXISTS 还是 ANY

开头还是介绍一下群&#xff0c;如果感兴趣polardb ,mongodb ,mysql ,postgresql ,redis 等有问题&#xff0c;有需求都可以加群群内有各大数据库行业大咖&#xff0c;CTO&#xff0c;可以解决你的问题。加群请联系 liuaustin3 &#xff0c;在新加的朋友会分到3群&#xff08;共…

【后端面经-计算机基础】HTTP和TCP的区别

【后端面经-计算机基础】HTTP和TCP的区别 文章目录 【后端面经-计算机基础】HTTP和TCP的区别1. OSI七层模型和相关协议2. TCP协议2.1 特点&#xff1a;2.2 报文格式2.3 三次握手和四次挥手 3. HTTP协议3.1 特点3.2 报文格式3.2 https和http 4. HTTP vs TCP5. 面试模拟参考资料 …

全网最牛,python接口自动化测试-接口sign签名(实战撸码)

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 一般公司对外的接…

groupkfold 报错:raise keyerror(f“{not foud} not in index“)

【1】使用groupkfold 的时候出现以上报错&#xff1a;索引错误&#xff0c;groups的索引和x y 的不对应 【2】源代码&#xff1a; 【3】进行修改&#xff1a; 可以成功索引&#xff01;&#xff01;&#xff01;

tomcat下上传html

html 最基本结构服务器xshelltomcat 下载是否可以访问到服务器上传html html 最基本结构 .html 后缀名 <!DOCTYPE HTML> <html><head><meta charset"utf-8"> <title>2306</title></head><body>大家好&#xff01;…

C++图形开发(7):能进行抛物线运动且触墙能反弹的小球

今天来实现一下触墙能反弹的小球、 我们之前所实现的都只是小球的上下&#xff0c;也就是y轴方向的运动&#xff08;详见&#xff1a;C图形开发&#xff08;6&#xff09;&#xff1a;落下后能弹起的小球&#xff09;&#xff0c;那么要使小球能够呈抛物线状运动&#xff0c;我…

Failed to start connector [Connector[HTTP/1.1-8080]]

1、解决Web server failed to start. Port 8080 was already in use 2、SpringBoot启动报错:“Error starting ApplicationContext. To display the conditions report re-run your application with ‘debug’ enabled.” 3、Failed to start end point associated with Proto…

015-从零搭建微服务-远程调用(一)

写在最前 如果这个项目让你有所收获&#xff0c;记得 Star 关注哦&#xff0c;这对我是非常不错的鼓励与支持。 源码地址&#xff08;后端&#xff09;&#xff1a;https://gitee.com/csps/mingyue 源码地址&#xff08;前端&#xff09;&#xff1a;https://gitee.com/csps…

如何利用Spine制作简单的2D骨骼动画

在2D游戏中&#xff0c;我们经常看到各种各样的角色动画。动画能给游戏带来生机和灵气。创作一段美妙的动画&#xff0c;不仅需要强大的软件工具&#xff0c;更需要一套完善的工作流程。 Spine就是一款针对游戏开发的2D骨骼动画编辑工具。Spine 可以提供更高效和简洁 的工作流…