理解Pytorch中的collate_fn函数

news2024/11/15 10:39:43

PyTorch中的DataLoader是最常用的类之一,这个类有很多参数(14 个),但大多数情况下,你可能只会使用其中的三个:dataset、shuffle 和 batch_size。其中collate_fn是比较少用的函数,这对初学者来说是一个容易混淆的概念。下面将简要探讨 PyTorch 如何创建批次,并了解如何根据我们的需求修改其默认行为。

批处理

首先,先创建一个数据,用data变量标视

import torch
from torch.utils.data import DataLoader
import numpy as np

data = np.array([
    [0.1, 7.4, 0],
    [-0.2, 5.3, 0],
    [0.2, 8.2, 1],
    [0.2, 7.7, 1]])
print(data)

如果我们加载一个批次出来( shuffle=False以消除随机性):

loader = DataLoader(data, batch_size=2, shuffle=False)
batch = next(iter(loader))
print(batch)

# tensor([[ 0.1000,  7.4000,  0.0000],
#         [-0.2000,  5.3000,  0.0000]], dtype=torch.float64)

结果符合预期的,我们来解释一下已经做了什么:

  • 加载器从数据集中选择了 2 个样本。
  • 这些样本被转换为张量(2 个大小为 3 的样本)。
  • 创建并返回一个新的张量 (2x3)。

默认设置还允许我们使用字典。 让我们看一个例子:

from pprint import pprint
# now dataset is a list of dicts
dict_data = [
    {'x1': 0.1, 'x2': 7.4, 'y': 0},
    {'x1': -0.2, 'x2': 5.3, 'y': 0},
    {'x1': 0.2, 'x2': 8.2, 'y': 1},
    {'x1': 0.2, 'x2': 7.7, 'y': 10},
]
pprint(dict_data)
# [{'x1': 0.1, 'x2': 7.4, 'y': 0},
# {'x1': -0.2, 'x2': 5.3, 'y': 0},
# {'x1': 0.2, 'x2': 8.2, 'y': 1},
# {'x1': 0.2, 'x2': 7.7, 'y': 10}]

loader = DataLoader(dict_data, batch_size=2, shuffle=False)
batch = next(iter(loader))
pprint(batch)
# {'x1': tensor([ 0.1000, -0.2000], dtype=torch.float64),
#  'x2': tensor([7.4000, 5.3000], dtype=torch.float64),
#  'y': tensor([0, 0])}

Dataloader简单易用,可以正确地从字典列表中重新打包数据。 当你的数据采用 JSON格式时,此功能非常方便。

自定义collate函数

Dataloader默认设置能覆盖大部分场景的数据读取,但默认设置有一个很大的限制——批数据必须处于同一维度。 假设我们有一个 NLP 任务,并且数据是分词后的文本。

# values are token indices but it does not matter - it can be any kind of variable-size data
nlp_data = [
    {'tokenized_input': [1, 4, 5, 9, 3, 2],
     'label':0},
    {'tokenized_input': [1, 7, 3, 14, 48, 7, 23, 154, 2],
     'label':0},
    {'tokenized_input': [1, 30, 67, 117, 21, 15, 2],
     'label':1},
    {'tokenized_input': [1, 17, 2],
     'label':0},
]
loader = DataLoader(nlp_data, batch_size=2, shuffle=False)
batch = next(iter(loader))

这样强行去压成一个batch存储,会引发错误:

/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/collate.py in default_collate(batch)
     80         elem_size = len(next(it))
     81         if not all(len(elem) == elem_size for elem in it):
---> 82             raise RuntimeError('each element in list of batch should be of equal size')
     83         transposed = zip(*batch)
     84         return [default_collate(samples) for samples in transposed]

RuntimeError: each element in list of batch should be of equal size

报错信息显示不能创建非矩形张量。顺便说一句,可以看到触发错误的是 default_collate函数。

如何修改? 有两种解决方案:

  • 将整个数据集填充到最长的样本。
  • 在Batch data创建期间进行动态填充。

第一个方法最简单,但是非常耗内存,极端条件下,我们有1万条数据,其中9999条数据长度是10,而只有1条数据长度是1000,那么所有的数据都需要pad数值,使得长度填充到1000。这样有99%的内存占用都是无意义的。

pad-max

另一种方法是动态填充数据。 当选择这一个batch的样本时,我们只需要将数据填充到这一个batch最长的样本长度即可。另外,将数据按照长度进行排序,则填充的数量将是最小的。 如果有一些非常长的序列,它们只会影响它们的这一个batch的效率,而不是整个数据集。

pad-batch

具体如何实现呢?

from torch.nn.utils.rnn import pad_sequence #(1)

def custom_collate(data): #(2)
    inputs = [torch.tensor(d['tokenized_input']) for d in data] #(3)
    labels = [d['label'] for d in data]

    inputs = pad_sequence(inputs, batch_first=True) #(4)
    labels = torch.tensor(labels) #(5)

    return { #(6)
        'tokenized_input': inputs,
        'label': labels
    }

loader = DataLoader(
  	nlp_data, 
    batch_size=2, 
    shuffle=False, 
    collate_fn=custom_collate
) #(7)

iter_loader = iter(loader)
batch1 = next(iter_loader)
pprint(batch1)
batch2 = next(iter_loader)
pprint(batch2)

# {'label': tensor([0, 0]),
#  'tokenized_input': tensor([
#   [  1,   4,   5,   9,   3,   2,   0,   0,   0],
#   [  1,   7,   3,  14,  48,   7,  23, 154,   2]
# ])}

# {'label': tensor([1, 0]),
#  'tokenized_input': tensor([
#   [  1,  30,  67, 117,  21,  15,   2],
#   [  1,  17,   2,   0,   0,   0,   0]])}

代码功能如下:

  • 我们使用 pad_sequence进行填充
  • custom_collate作为参数传递给DataLoader
  • 在运行时对inputs进行动态填充

总结

collate_fn是一个很少用的函数,但对提升训练效率有很大的帮助。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2046287.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2024年国家数据局第一批20个“数据要素×”典型案例解析

国家数据局首批20个“数据要素”典型案例解析 1、简介1.1 背景简介1.2 典型案例分类 2、案例解析2.1 工业制造领域案例1:数据要素驱动适应多式联运需求的运输装备协同制造案例2:打造工业数据空间 赋能产业链上下游发展 2.2 现代农业领域案例3&#xff1a…

07一阶电路和二阶电路的时域分析

一阶电路和二阶电路的时域分析 时域分析、频域分析、复频域分析本应该在信号与系统,或者数字信号处理这一章节里面进行处理的。 但在电路理论中也有这些知识,那就要好好掌握一下,打个底。详细细致的部分放到信号与系统里面去掌握

Spring Web MVC入门(中)

1. 请求 访问不同的路径, 就是发送不同的请求. 在发送请求时, 可能会带⼀些参数, 所以学习Spring的请求, 主要 是学习如何传递参数到后端以及后端如何接收. 传递参数, 咱们主要是使⽤浏览器和Postman来模拟; 1.1 传递单个参数 接收单个参数,在Spring MV…

七段S型加减速算法原理及其多种形状仿真

1、基本7段S型: 七段S型加减速的位置、速度、加速度、加加速度曲线如下图所示。 加加速度: 加速度: 速度: 位置: 以上是7段S型加减速的最基本公式,在实际应用中还需要考虑到起始和终止速度大于匀速速度的情…

【JavaSE】解读Java中的toString方法

前言: 在Java中,toString方法来自java.lang.Object 类,然后所有对象都继承该Object 类。默认情况下,它的作用是返回对象的字符串表示形式。在实际开发中,重写 toString() 方法可以帮助我们以更易读的形式输出对象信息&…

Verilog基础:模块端口(port)定义的语法(2001标准)

相关阅读 Verilog基础https://blog.csdn.net/weixin_45791458/category_12263729.html?spm1001.2014.3001.5482 Verilog中的端口定义有两种风格,一种是Verilog Standard 1995风格,一种是Verilog Standard 2001风格,本文将对Verilog Standar…

C语言基础11指针

指针的引入 为函数修改实参提供支持。 为动态内存管理提供支持。 为动态数据结构提供支持。 为内存访问提供另一种途径。 指针概述 内存地址: 系统为了内存管理的方便,将内存划分为一个个的内存单元( 1 个内存单元占 1 个字节&#xff09…

自动控制——状态观测器

自动控制——状态观测器 引言 在自动控制系统中,准确地了解系统的状态对实现高性能控制至关重要。然而,在许多实际应用中,我们无法直接测量系统的所有状态变量。这时,状态观测器(State Observer)就发挥了…

【LeetCode面试150】——209长度最小的子数组

博客昵称:沈小农学编程 作者简介:一名在读硕士,定期更新相关算法面试题,欢迎关注小弟! PS:哈喽!各位CSDN的uu们,我是你的小弟沈小农,希望我的文章能帮助到你。欢迎大家在…

轻松上手MYSQL:精通正则表达式,数据匹配不再难!

🌈 个人主页:danci_ 🔥 系列专栏:《设计模式》《MYSQL》 💪🏻 制定明确可量化的目标,坚持默默的做事。 ✨欢迎加入探索MYSQL正则表达式函数之旅✨ 👋 大家好!文本学习…

【Cesium】Cesium图层请求完成的回调

有一个业务需要用到cesium图层请求完成的回调&#xff0c;翻了好久的文档终于给我找到&#x1f336;️。 是Cesium.ImageryProvider类的一个属性readyPromise 效果如下&#xff1a; Cesium图层请求完成的回调 完整代码如下&#xff1a; <html lang"en"><h…

PCDN业务推荐

神鸟云&蘑菇云最新业务推荐 &#x1f525;短Z业务-- 支持nat0~nat4 省内调度&#xff0c;晚高峰 跑量9成 配置要求: 线路&#xff1a;单条上行30M 硬件&#xff1a;32线程 64内存条 240G系统盘 1G:2T固态盘 单价&#xff1a;移动1900 电联2500 http://oss.download.…

Mysql(四)---增删查改(进阶)

文章目录 前言1.查询操作1.1.全列查询1.2.指定列查询1.3.列名为表达式查询1.4.查询中使用别名1.5.去重查询1.6.排序1.6.2.NULL 1.7.条件查询1.8.分页查询 2.修改3.删除 前言 上一篇博客&#xff0c;我们学习了一些主键的概念&#xff0c;并且分别创造了一些示例表&#xff0c;…

通过相机来获取图片

文章目录 1. 概念介绍2. 方法与细节2.1 实现方法2.2 具体细节 3. 示例代码4. 内容总结 我们在上一章回中介绍了"如何混合选择多个图片和视频文件"相关的内容&#xff0c;本章回中将介绍如何通过相机获取图片文件.闲话休提&#xff0c;让我们一起Talk Flutter吧。 1. …

Codeforces Round 966 (Div. 3)(A,B,C,D,E,F)

A. Primary Task 签到 void solve() {string s;cin>>s;bool bltrue;if(s.size()<2)blfalse;else{if(s.substr(0,2)"10"){if(s[2]0)blfalse;else if(s[2]1&&s.size()<3)blfalse; }else blfalse;}if(bl)cout<<"YES\n";else cout…

数据结构第二天

分文件编程实现&#xff1a; main.c #include <stdio.h> #include <stdlib.h> #include "seqlist.h"int main(int argc, char const *argv[]) {seqlist_p p CreateEpSeqlist();InsertIntoSeqlist(p, 0, 10);InsertIntoSeqlist(p, 1, 20);InsertIntoSeql…

打卡第四十五天:不同的子序列、两个字符串的删除操作、编辑距离

一、不同的子序列&#xff08;困难&#xff09; 题目 文章 视频 这道题目如果不是子序列&#xff0c;而是要求连续序列的&#xff0c;那就可以考虑用KMP。相对于72. 编辑距离简单了不少&#xff0c;因为本题相当于只有删除操作&#xff0c;不用考虑替换增加之类的。但相对于…

Mapreduce_Distinct数据去重

MapReduce中数据去重 输入如下的数据&#xff0c;统计其中的地址信息&#xff0c;并对输出的地址信息进行去重 实现方法&#xff1a;Map阶段输出的信息K2为想要去重的内容&#xff0c;利用Reduce阶段的聚合特点&#xff0c;对K2进行聚合&#xff0c;去重。在两阶段中&#xff…

【数据结构篇】~顺序表

顺序表前言 想要学好数据结构的三大基本功&#xff1a;1.结构体2.指针3.动态内存开辟,这三样将是贯彻整个数据结构的工具。&#xff08;可以去这里了解这三大基本功&#xff09; 顺序表也是线性表的一种&#xff0c;那线性表又是什么呢&#xff1f; 线性表&#xff08;linear …

系列:水果甜度个人手持设备检测-无损检测常用技术和方式汇总

系列:水果甜度个人手持设备检测 -- 无损检测常用技术和方式汇总 概述 无损检测以不损坏被检测对象的使用性能为前提&#xff0c;以物理或化学方法为手段&#xff0c;借助相应的设备器材&#xff0c;按照规定的技术要求&#xff0c;对材料、零部件、结构件进行有效的检验和测…