PyTorch DataLoader整理函数详解【collate_fn】

news2024/9/28 1:20:40

DataLoader 是 PyTorch 中最常用的类之一。 而且,它是你首先学习的内容之一。 该类有很多参数,但最有可能的是,你将使用其中的大约三个参数(dataset、shuffle 和 batch_size)。 今天我想解释一下 collate_fn 的含义—根据我的经验,我发现它让初学者感到困惑。 我们将简要探讨 PyTorch 如何创建批数据,并了解如何根据需要修改默认行为。

在线工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 可编程3D场景编辑器 

1、批创建流程

每个深度学习课程中最重要的信息之一是我们批量执行训练/推理。 大多数时候,一个批次只是一些堆叠的数据样本。 但在某些情况下,我们想修改它的创建方式。

首先,让我们研究一下默认情况下会发生什么。 假设我们有以下玩具数据集。 它包含四个示例,每个示例三个功能。

import torch
from torch.utils.data import DataLoader
import numpy as np

data = np.array([
    [0.1, 7.4, 0],
    [-0.2, 5.3, 0],
    [0.2, 8.2, 1],
    [0.2, 7.7, 1]])
print(data)

如果我们向加载程序请求一个批次,我们将看到以下内容(请注意,我设置了 shuffle=False 以消除随机性):

loader = DataLoader(data, batch_size=2, shuffle=False)
batch = next(iter(loader))
print(batch)

# tensor([[ 0.1000,  7.4000,  0.0000],
#         [-0.2000,  5.3000,  0.0000]], dtype=torch.float64)

结果毫不奇怪,但让我们正式描述一下已经做了什么:

  • 加载器从数据集中选择了 2 个样本。
  • 这些样本被转换为张量(2 个大小为 3 的样本)。
  • 创建并返回一个新的张量 (2x3)。

默认设置还允许我们使用字典。 让我们看一个例子:

from pprint import pprint
# now dataset is a list of dicts
dict_data = [
    {'x1': 0.1, 'x2': 7.4, 'y': 0},
    {'x1': -0.2, 'x2': 5.3, 'y': 0},
    {'x1': 0.2, 'x2': 8.2, 'y': 1},
    {'x1': 0.2, 'x2': 7.7, 'y': 10},
]
pprint(dict_data)
# [{'x1': 0.1, 'x2': 7.4, 'y': 0},
# {'x1': -0.2, 'x2': 5.3, 'y': 0},
# {'x1': 0.2, 'x2': 8.2, 'y': 1},
# {'x1': 0.2, 'x2': 7.7, 'y': 10}]

loader = DataLoader(dict_data, batch_size=2, shuffle=False)
batch = next(iter(loader))
pprint(batch)
# {'x1': tensor([ 0.1000, -0.2000], dtype=torch.float64),
#  'x2': tensor([7.4000, 5.3000], dtype=torch.float64),
#  'y': tensor([0, 0])}

加载器足够聪明,可以正确地从字典列表中重新打包数据。 当你的数据采用 JSONL 格式(我个人更喜欢这种格式而不是 CSV)时,此功能非常方便。

2、自定义collate函数

如果默认规则如此智能,为什么我们需要创建自定义collate规则呢? 默认设置有一个很大的限制——批数据必须处于同一维度。 假设我们有一个 NLP 任务,并且数据是分词后的文本。

# values are token indices but it does not matter - it can be any kind of variable-size data
nlp_data = [
    {'tokenized_input': [1, 4, 5, 9, 3, 2],
     'label':0},
    {'tokenized_input': [1, 7, 3, 14, 48, 7, 23, 154, 2],
     'label':0},
    {'tokenized_input': [1, 30, 67, 117, 21, 15, 2],
     'label':1},
    {'tokenized_input': [1, 17, 2],
     'label':0},
]
loader = DataLoader(nlp_data, batch_size=2, shuffle=False)
batch = next(iter(loader))

上面的代码不会工作并引发错误:

/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/collate.py in default_collate(batch)
     80         elem_size = len(next(it))
     81         if not all(len(elem) == elem_size for elem in it):
---> 82             raise RuntimeError('each element in list of batch should be of equal size')
     83         transposed = zip(*batch)
     84         return [default_collate(samples) for samples in transposed]

RuntimeError: each element in list of batch should be of equal size

错误消息表明不可能创建非矩形张量。 顺便说一句,可以看到触发错误的是 default_collate函数。

我们可以做什么? 有两种解决方案:

  • 将整个数据集填充到最长的样本。
  • 在批创建期间动态填充。

第一个解决方案可能看起来更简单—只需将所有样本扩展到最长的样本即可。 但有一个问题—我们会浪费内存和计算能力(它们在 GPU 上很昂贵!)来处理 padding,这并不影响结果。 如果我们的数据中有一些长序列,而且大多数序列都相对较短,那就尤其痛苦。 在这种情况下,我们主要是处理填充而不是数据!

如果我们将整个数据集填充到最长的序列,会浪费大量空间!

另一种方法是动态填充数据。 当选择该批的样本时,我们只将它们填充到最长的样本。 如果我们另外按长度对数据进行排序,则填充将是最小的。 如果有一些非常长的序列,它们只会影响它们的批次,而不是整个数据集。

好吧,但是如何实现呢? 只需创建一个自定义 collate_fn , 这很简单:

from torch.nn.utils.rnn import pad_sequence #(1)

def custom_collate(data): #(2)
    inputs = [torch.tensor(d['tokenized_input']) for d in data] #(3)
    labels = [d['label'] for d in data]

    inputs = pad_sequence(inputs, batch_first=True) #(4)
    labels = torch.tensor(labels) #(5)

    return { #(6)
        'tokenized_input': inputs,
        'label': labels
    }

loader = DataLoader(
  	nlp_data, 
    batch_size=2, 
    shuffle=False, 
    collate_fn=custom_collate
) #(7)

iter_loader = iter(loader)
batch1 = next(iter_loader)
pprint(batch1)
batch2 = next(iter_loader)
pprint(batch2)

# {'label': tensor([0, 0]),
#  'tokenized_input': tensor([
#   [  1,   4,   5,   9,   3,   2,   0,   0,   0],
#   [  1,   7,   3,  14,  48,   7,  23, 154,   2]
# ])}

# {'label': tensor([1, 0]),
#  'tokenized_input': tensor([
#   [  1,  30,  67, 117,  21,  15,   2],
#   [  1,  17,   2,   0,   0,   0,   0]])}

代码说明如下:

  • 我们使用 pad_sequence进行填充
  • Collate 函数要传入单个参数 - 样本列表。 在这种情况下,它将是一个字典列表,但它也可以是一个元组列表等——具体取决于数据集。
  • 当数据出现时,如果格式为“字典列表”,我们需要遍历它并为所有输入和标签创建一个单独的列表。 与此同时, tokenized_input 被转换为一维张量(它是一个整数列表)。
  • 执行填充。
  • 由于标签是整数列表,我们将其转换为张量。
  • 返回格式化的批次。
  • 在加载器中设置我们的自定义整理函数。

正如我们所看到的,批的格式与字典的默认排序规则相同。 我们清楚地看到填充量很小。

3、结束语

创建自定义整理函数可能不是最常见的任务,但你绝对需要知道如何去做。


原文链接:PyTorch collate_fn详解 - BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1223125.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【开源】基于JAVA的快递管理系统

项目编号: S 007 ,文末获取源码。 \color{red}{项目编号:S007,文末获取源码。} 项目编号:S007,文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、研究内容2.1 数据中心模块2.2 快递类型模块2.3 快…

windows 10 更新永久关闭

1 winR 输入:services.msc 编辑: 关闭:

【快速解决】实验三 简单注册的实现《Android程序设计》实验报告

目录 前言 实验要求 实验三 简单注册的实现 实验目的: 实验内容: 实验提示: 无 三、遇到的问题总结(如果有问题,请总结。如果没问题请写“无”) 正文开始 第一步建立项目 第二步选择empty views a…

Linux-top命令解释

Linux-top命令解释 常用参数查看所有逻辑核的运行情况:1查看指定进程的情况:-p pid显示进程的完整命令:-c 面板指标解释第一行top第二行tasks第三行%Cpu第四行Mem第五行Swap第六行各进程监控PID:进程IDUSER:进程所有者…

什么是BT种子!磁力链接又是如何工作的?

目录 一.什么是BT?1.BT简介:1.1.BT是目前最热门的下载方式之一1.2.BT服务器是通过一种传销的方式来实现文件共享的 2.小知识:2.1.你知道吗BT下载和常规下载到底有哪些不同2.2.BT下载的灵魂:种子2.3.当下载结束后,如果未…

Redis维护缓存的方案选择

Redis中间件常常被用作缓存,而当使用了缓存的时候,缓存中数据的维护,往往是需要重点关注的,尤其是重点考虑的是数据一致性问题。以下是维护数据库缓存的一些常用方案。 1、先删除缓存,再更新数据库 导致数据不一致的…

vulnhub靶机Momentum

下载地址:https://download.vulnhub.com/momentum/Momentum.ova 主机发现 目标192.168.21.129 端口扫描 端口版本扫描 漏洞扫描 扫出来点目录简单看看 发现js里面有一点东西 这里面告诉了我们了web文件有id传值,而且有aes加密还有密钥 跟二没有啥区别&…

ai的潜力和中短期的未来预测

内容来源:rickawsb ​对于描述ai的潜力和中短期的未来预测,我认为到目前为止可能没有比这篇推文总结得更好的了。 我读了三次。 文章起源于一个用户感叹openai升级chatgpt后,支持pdf上传功能,直接让不少的靠这个功能吃饭的创业公…

FISCOBCOS入门(十)Truffle测试helloworld智能合约

本文带你从零开始搭建truffle以及编写迁移脚本和测试文件,并对测试文件的代码进行解释,让你更深入的理解truffle测试智能合约的原理,制作不易,望一键三连 在windos终端内安装truffle npm install -g truffle 安装truffle时可能出现网络报错,多试几次即可 truffle --vers…

CTFhub-RCE-综合过滤练习

%0a、%0d、%0D%0A burp 抓包 修改请求为 POST /?127.0.0.1%0als 列出当前目录 返回包 http://challenge-135e46015a30567b.sandbox.ctfhub.com:10800/?ip127.0.0.1%0acd%09*here%0ac%27a%27t%09* _311632412323588.php

YOLOv8独家改进: Inner-IoU基于辅助边框的IoU损失,高效结合 GIoU, DIoU, CIoU,SIoU 等 | 2023.11

💡💡💡本文独家改进:Inner-IoU引入尺度因子 ratio 控制辅助边框的尺度大小用于计算损失,并与现有的基于 IoU ( GIoU, DIoU, CIoU,SIoU )损失进行有效结合 推荐指数:5颗星 新颖指数:5颗星 💡💡💡Yolov8魔术师,独家首发创新(原创),适用于Yolov5…

苍穹外卖--员工分页查询

请求参数封装: Data public class EmployeePageQueryDTO implements Serializable {//员工姓名private String name;//页码private int page;//每页显示记录数private int pageSize;}请求结果封装: public class PageResult implements Serializable {…

完整版解答!2023年数维杯国际大学生数学建模挑战赛B题

B题完整版全部5问,问题解答、代码,完整论文、模型的建立和求解、各种图表代码已更新! 大家好,目前已完成2023数维杯国际赛B题全部5问的代码和完整论文已更新,部分展示如下: 部分解答图表 问题分析 B题前三…

c++中的String

文章目录 String定义对象的方式成员函数operatorbegin/endsizecapacityreserversizeoperator/append/push_backoperator[]/at String String是一个类模版,可以定义一个字符/字符串对象。 字符顺序表 定义对象的方式 定义方式有很多重要的就这几种 string s1;stri…

屏蔽bing搜索框的今日热点

中国版的Bing简直比百度还恶心了,“今日热点”要是在搜索设置里关闭了就没法提供搜索建议了,不关吧看着又烦人,就像下图这样。另外还有右上角的下载bing app和Rewards图标也闲着没啥用,Rewards图标还老有小红点,真受不…

【C++】类和对象(5)--运算符重载

目录 一 概念 二 运算符重载的实现 三 关于时间的所有运算符重载 四 默认赋值运算符 五 const取地址操作符重载 一 概念 C为了增强代码的可读性引入了运算符重载,运算符重载是具有特殊函数名的函数,也具有其返回值类型,函数名字以及参数…

实验六:Android的网络编程基础

实验六:Android 的网络编程基础 6.1 实验目的 本次实验的目的是让大家熟悉 Android 开发中的如何获取天气预报,包括了 解和熟悉 WebView、WebService 使用、网络编程事件处理等内容。 6.2 实验要求 熟悉和掌握 WebView 使用 了解 Android 的网络编程…

cs与msf联动

实验环境 cs4.4(4.5版本不知道为啥实现不了) cs服务器与msf在同一台vps上 本地win7虚拟机 cs派生会话给msf 首先cs正常上线win7,这就不多说了,然后说如何将会话派生给msf cs准备 选择Foreign,这里可以选HTTP,也可以选HTTPS…

LLM大模型权重量化实战

大型语言模型 (LLM) 以其广泛的计算要求而闻名。 通常,模型的大小是通过将参数数量(大小)乘以这些值的精度(数据类型)来计算的。 然而,为了节省内存,可以通过称为量化的过程使用较低精度的数据类…

庖丁解牛:NIO核心概念与机制详解

文章目录 Pre输入/输出Why NIO流与块的比较通道和缓冲区概述什么是缓冲区?缓冲区类型什么是通道?通道类型 NIO 中的读和写概述Demo : 从文件中读取1. 从FileInputStream中获取Channel2. 创建ByteBuffer缓冲区3. 将数据从Channle读取到Buffer中 Demo : 写…