第六章 番外篇:webdataset

news2024/12/23 17:42:16

参考教程:
https://github.com/pytorch/pytorch/issues/38419
https://zhuanlan.zhihu.com/p/412772439
https://webdataset.github.io/webdataset/gettingstarted/


文章目录

  • 背景
  • WebDataset
  • webdataset的生成
  • webdataset的加载
  • 示例代码

背景

训练数据通常是以个体的方式存储的,就像我们在第一章下载并处理成png格式后的cifar10数据,它以’xxx.png’的文件形式存放在一个一个独立的空间中。
随着数据集变得越来越大,这样的存放形式就不是那么高效和便捷。在进行模型训练时,也会因为数据的IO瓶颈拖慢训练的速度。
在使用Dataset中的数据时,我们的__getitem__(self, idx)函数会根据数据的index检索数据。在训练时,我们一般都会使用shuffle = True来完成数据的随机读取,这样索引的index也是无效的,当图片数据直接存放在系统上时,对文件的访问需要花费大量的代价。
这个问题可以使用sequential storage formats and sharding来解决。就像tensorflow中使用的TFRecord格式,它将训练集/测试集打包在一起使用,文件里存储的就是序列化的tf.Example。Pytorch是没有这种专属的数据存储格式的。

WebDataset

WebDataset提供了一种序列化存储大规模数据的方法,它将数据保存在tar包中,但是在使用时不需要对tar包进行解压。这种形式提供了高效的I/O,并且不管是在本地还是云端数据上都表现很不错。

webdataset的生成

webdataset是一个tar文件,所以你直接使用tar命令就可以进行文件的生成。

tar --sort=name -cf dataset.tar dataset/

我们也可以使用python调用webdataset的包,来进行文件的写入操作。
以下面的代码为例,下方的代码想要将现有的MNIST数据存放到’mnist.tar’文件中,因此它按照顺序将数据一个一个多写入了文件里。

dataset = torchvision.datasets.MNIST(root="./temp", download=True) # 获得MNIST数据
sink = wds.TarWriter("mnist.tar") # 使用TarWriter,准备将数据写入mnist.tar
for index, (input, output) in enumerate(dataset):
    if index%1000==0:
        print(f"{index:6d}", end="\r", flush=True, file=sys.stderr) # 每写入1000个数据,输出一些状态
    sink.write({
        "__key__": "sample%06d" % index, # 当前的数据的index
        "input.pyd": input, # 数据的input
        "output.pyd": output, # 数据的target
    })
sink.close() # 关闭当前文件。

这里的sink_write写入了是一个dict,其中’key’这一项决定了你想保存的数据的前缀名,’input.pyd’是你的input的数据的后缀,它同时也决定了你的数据存放的格式。
比如说这里使用的’pyd’,就是我们之前说过的pickle格式,它可以保证数据的完整性,以不压缩的形式存储数据,缺点是不能被其它的语言读取。
在你明确知道数据的类型的情况下,你也可以使用别的格式来存放数据,比如说对于图片,你可以使用‘ppm’,‘png’,'jpg’等格式,对于图片的标签,已知数据标签是整数的形式时,可以使用’cls’格式。

webdataset的加载

对于一个存入tar的webdataset的数据,你可以通过它的url对它进行读取,这个url可以是云端地址,也可以是本地路径。

import webdataset as wds
dataset = wds.WebDataset(url)

我们在讲数据存入tar时,writer根据我们定义的数据格式对数据进行了encode,所以我们直接读取到的数据是还没有decode的数据。
在教程中给了这样一个例子。
在这里插入图片描述
直接获取到的数据格式是bytes的格式。
你可以数据进行一些处理,webdataset提供一种链式的数据处理方法,比如上面的数据,你就可以使用下面的方法处理。

dataset = (
    wds.WebDataset(url)
    .shuffle(100)
    .decode("rgb")
    .to_tuple("jpg;png", "json")
)

这里的decode传入的’rgb’属于headler,webdataset提供了一些自带的imageheadler。帮助使用者进行数据类型转换。imagespecs = { "l8": ("numpy", "uint8", "l"), "rgb8": ("numpy", "uint8", "rgb"), "rgba8": ("numpy", "uint8", "rgba"), "l": ("numpy", "float", "l"), "rgb": ("numpy", "float", "rgb"), "rgba": ("numpy", "float", "rgba"), "torchl8": ("torch", "uint8", "l"), "torchrgb8": ("torch", "uint8", "rgb"), "torchrgba8": ("torch", "uint8", "rgba"), "torchl": ("torch", "float", "l"), "torchrgb": ("torch", "float", "rgb"), "torch": ("torch", "float", "rgb"), "torchrgba": ("torch", "float", "rgba"), "pill": ("pil", None, "l"), "pil": ("pil", None, "rgb"), "pilrgb": ("pil", None, "rgb"), "pilrgba": ("pil", None, "rgba"), }
webdataset提供了多种数据的decode方式的示例,你也可以自定义decode的方法。具体的源码可以查看https://github.com/webdataset/webdataset/blob/main/webdataset/autodecode.py。

decoders = {
    "txt": lambda data: data.decode("utf-8"),
    "text": lambda data: data.decode("utf-8"),
    "transcript": lambda data: data.decode("utf-8"),
    "cls": lambda data: int(data),
    "cls2": lambda data: int(data),
    "index": lambda data: int(data),
    "inx": lambda data: int(data),
    "id": lambda data: int(data),
    "json": lambda data: json.loads(data),
    "jsn": lambda data: json.loads(data),
    "pyd": lambda data: pickle.loads(data),
    "pickle": lambda data: pickle.loads(data),
    "pth": lambda data: torch_loads(data),
    "ten": tenbin_loads,
    "tb": tenbin_loads,
    "mp": msgpack_loads,
    "msg": msgpack_loads,
    "npy": npy_loads,
    "npz": lambda data: np.load(io.BytesIO(data)),
    "cbor": cbor_loads,
}

如果是想要自己定义decode的方法,可以使用以下类似的方法。以下的方法中定义了my_decoder方法,这方法会判断dataset中sample的key是否为jpg,如果不是则忽略,是的话才会返回结果。要注意这里直接获得的数据类型都是bytes,你可以使用类似于**imageio.imread(io.BytesIO(value))**处理数据,将它转为图片。

def my_decoder(key, value):
        if not key.endswith(".jpg"):
            return None
        assert isinstance(value, bytes)
        return value

dataset = wds.WebDataset(url).shuffle(1000).decode(my_decoder)

示例代码

最后给出一个简单的webdataset多进程存储的方法,这里使用的dataset中返回sample是dict形式,最后以pickle的形式存放到指定数量的tar中。

import multiprocessing as mp
import webdataset as wds
import pickle
import os

def write_samples(dataset, tar_index, sample_index,save_dir):
    for t_idx, s_idx in zip(tar_index, sample_index):
        fname = os.path.join(save_dir,str(t_idx)+'.tar')
        stream = wds.TarWriter(fname)
        for idx in s_idx:
            data = dataset[idx]
            sample = {}
            sample['__key__'] = "sample%06d" % idx
            for key, value in data.items():
                sample[key +'.pyd'] = value
            stream.write(sample)
        stream.close()

def dataset2tar(dataset, save_dir,num_tars, num_workers):
    num_len = len(dataset)
    data_index = [i for i in range(num_len)]
    samples = [data_index[i::num_tars] for i in range(num_tars)]
    tar_index = list(range(num_tars))
    jobs = []
    for i in range(num_workers):
        job = mp.Process(target = write_samples,args=(dataset,tar_index[i::num_workers],samples[i::num_workers],save_dir))
        job.start()
        jobs.append(job)
   
    for job in jobs:
        job.join()
    
def pyd_decoder(key, data):
    if not key.endswith(".pyd"):
        return None
    result = pickle.loads(data)
    return result

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/653035.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

霹雳吧啦 目标检测 学习笔记

霹雳吧啦Wz的个人空间-霹雳吧啦Wz个人主页-哔哩哔哩视频 目标检测篇github地址;GitHub - WZMIAOMIAO/deep-learning-for-image-processing: deep learning for image processing including classification and object-detection etc. 数据集 实例分割vs语义分割&a…

【强烈推荐】 十多款2023年必备国内外王炸级AI工具 (免费 精品 好用) 让你秒变神一样的装逼佬感受10倍生产力 (7) AI语言模型

🚀 个人主页 极客小俊 ✍🏻 作者简介:web开发者、设计师、技术分享博主 🐋 希望大家多多支持一下, 我们一起进步!😄 🏅 如果文章对你有帮助的话,欢迎评论 💬点赞&#x1…

云安全的第一站:CSPM

在企业数字化转型和云计算技术的加持下,企业上云趋势势不可挡。与此同时,数据量加大,网络攻击日趋频繁,对企业来说,包括云计算安全在内的网络安全部署的重要性日益显现。 在Gartner2022年CIO技术执行官问卷调查中&…

chatgpt赋能python:Python怎么绕过短信验证

Python怎么绕过短信验证 短信验证以及其他形式的验证码已经成为了许多网站和应用程序保护用户隐私的常见方式。然而,对于某些特定的情况,用户可能需要绕过这些验证码,例如自动化测试或者爬取数据。那么,在Python中,我…

安装Hive

安装Hive 准备 安装Java环境:Hive需要Java环境支持,所以需要先安装Java。安装文档:http://t.csdn.cn/deBJu 安装MySQL数据库。http://t.csdn.cn/d24pN 下载Hive 下载Hive的二进制文件。 链接:https://pan.baidu.com/s/1fdg7…

管理类联考——英语二——技巧篇——写作——书信作文——经典方法论

第一节 书信作文谋篇布局 考研英语从2005年开始考查书信作文,迄今为止共考查过几十次。书信作文考查的信件种类繁多,其中建议信是考查最为频繁的信件类型。从考查内容来看,校园学习生活、职业发展、民生热点成为重点考查对象,这一…

hadoop 相关环境搭建

21.Windows下安装Hadoop; Hive MySQL版_hadoop hive windows安装_学无止境的大象的博客-CSDN博客 https://www.cnblogs.com/liugp/p/16244600.html 备注。因为beeline一直报错,最有一怒之下把hive的lib下所有jar都拷贝到hadoop的share\hadoop\common\lib…

2023 年 5 大机器人趋势

原创 | 文 BFT机器人 国际机器人联合会报告 法兰克福,2023 年 2 月 16 日——全球操作机器人的存量创下约 350 万台的新纪录——安装价值估计达到 157 亿美元。国际机器人联合会分析了 2023 年影响机器人技术和自动化的 5 大趋势。 2023 年 5 大机器人趋势 © 国…

2000-2021年全国1km分辨率的逐日PM10栅格数据

空气质量数据是在我们日常研究中经常使用的数据!之前我们分享了来自于Zendo平台的1km分辨率的PM2.5栅格数据(可查看之前的文章获悉详情): 2000-2021年全国1km分辨率的逐日PM2.5栅格数据 2000-2021年全国1km分辨率的逐月PM2.5栅格…

双功能螯合剂Me-Tetrazine PEG7 NOTA,应用于生物和材料科学的研究中

文章关键词:双功能螯合剂,大环化合物 MeTz-PEG7-NOTA,NOTA PEG7 Me-Tetrazine,甲基四嗪-PEG7-NOTA (文章编辑来源于:西安凯新生物科技有限公司小编WMJ)​ 一、Product structure:…

组合逻辑毛刺消除

目录 组合逻辑毛刺消除 1、简介 2、实验任务 3、程序设计 1、组合逻辑输出加寄存器 2、信号同步法 (1)信号延时同步法 (2)状态机控制 3、格雷码计数器 4、仿真验证 组合逻辑毛刺消除 信号在 IC/FPGA 器件中通过逻辑单元…

管理类联考——英语——翻译篇——新题型——经典方法论

第一节 英语(一)翻译 根据考试大纲,考研英语(一R翻译部分主要考查考生准确理解概念或结构较复杂的英语文字材料的能力。具体考查方式是要求考生阅读一篇约400词的文章,并将其中5个画线部分(约150词)译成汉语,要求译文准确、完整、通顺。 可以看出&#…

js数组高阶函数——filter()方法

js数组高阶函数——filter方法 filter()方法⭐⭐⭐例1⭐⭐⭐例2⭐⭐⭐例3⭐⭐⭐例4⭐⭐⭐例5 filter()方法 ⭐一般来说,filter() 方法用于过滤数组中的元素,并返回一个新数组。 语法: array.f…

Python多线程编程详解

概要 进程(process)指的是正在运行的程序的实例,当我们执行某个程序时,进程就被操作系统创建了。而线程(thread)则包含于进程之中,是操作系统能够进行运算调度的最小单元,多个线程可…

【AntDB数据库】AntDB数据库告警管理

告警历史 功能概述 数据库系统的主机、单节点集群的被监测指标达到告警阀值时,AMOPS就会产生告警并展示在告警分类页面上。 告警分类页面提供告警搜索查看功能,用户可以指定监控项、集群、事件级别、时间范围和告警对象对告警进行搜索。 查询的告警数…

Android12之执行adb disable-verity后android无法启动(一百五十六)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 人生格言: 人生…

AotucCrawler 快速爬取图片

AotucCrawler 快速爬取图片 今天介绍一款自动化爬取图片项目。 GitHub: GitHub - YoongiKim/AutoCrawler: Google, Naver multiprocess image web crawler (Selenium) Google, Naver multiprocess image web crawler (Selenium) 关键字 爬虫网站:Google、Naver &…

【深度学习】2-3 神经网络-输出层设计

前馈神经网络(Feedforward Neural Network),之前介绍的单层感知机、多层感知机等都属于前馈神经网络,它之所以称为前馈(Feedforward),或许与其信息往前流有关:数据从输入开始,流过中间计算过程,最后达到输出…

springboot-内置Tomcat

一、springboot的特性之一 基于springboot的特性 自动装配Configuretion 注解 二、springboot内置Tomcat步骤 直接看SpringApplication方法的代码块 总纲: 1、在SpringApplication.run 初始化了一个上下文ConfigurableApplicationContext configurableApplica…

服务负载均衡Ribbon

服务负载均衡Ribbon Ribbon 介绍Ribbon 案例Ribbon 负载均衡策略Ribbon 负载均衡算法设置自定义负载均衡算法 Ribbon 介绍 Ribbon 是一个的客服端负载均衡工具,它是基于 Netflix Ribbon 实现的。它不像 Spring Cloud 服务注册中心、配置中心、API 网关那样独立部署…