【深度学习实战(11)】搭建训练框架之dataset,dataloader

news2024/11/27 10:23:34

一、dataset和dataloader要点说明

在我们搭建自己的网络时,往往需要定义自己的datasetdataloader,将图像和标签数据送入模型。
(1)在我们定义dataset时,需要继承torch.utils.data.dataset,再重写三个方法:

  • init方法,主要用来定义数据的预处理
  • getitem方法,数据增强;返回数据的item和label
  • len方法,返回数据数量

(2)在我们定义dataloader时,需要考虑下面几个参数:

  • dataset :使用哪个数据集
  • batch_size:将数据集拆成一组多少个进行训练
  • shuffle:是否需要打乱数据
  • num_workers:几个mini_batch并行计算,一般<=你的电脑cpu数目
  • collect_fn:数据打包方式

(3)通过迭代的方式,按批次,获取dataloader中的数据

(4)关系图

在这里插入图片描述

二、核心代码框架

import os
import cv2
from torchvision import transforms
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader


# -------------------------------------------------------------#
#   自定义dataset需要继承torch.utils.data.dataset,
#   再重写def __init__,def __len__,def __getitem__三个方法
# -------------------------------------------------------------#
class YourDataset(Dataset):
    def __init__(self,  root_path):
        super(YourDataset, self).__init__()
        self.root_path = root_path
        #-------------------------------------------------------------------------#
        #   获取样本名,以jpg原始图片为参考,修改后缀名为json,png,获取json,png标签文件路径
        #-------------------------------------------------------------------------#
        self.sample_names = []
        jpg_path = os.path.join(os.path.join(self.root_path, "images"),)
        for file in os.listdir(jpg_path):
            if file.endswith(".jpg"):
                self.sample_names.append(os.path.splitext(file)[0]) # 去掉.json

    def __len__(self):
        #----------------------#
        #   返回数据数量
        #----------------------#
        return len(self.sample_names)

    def __getitem__(self, index):
        name = self.sample_names[index]

        # ----------------------#
        #   读取图像
        # ----------------------#
        img_path = os.path.join(os.path.join(self.root_path, "images"), name + '.jpg')
        image = cv2.imread(img_path)
        # ----------------------#
        #   读取标签
        # ----------------------#
        label_path = os.path.join(os.path.join(self.root_path, "jsons"), name + '.json')
        with open(label_path) as label_file:
            points = self.get_data_from_json(label_file)
        #----------------------#
        #   图像数据增强
        #----------------------#
        image = self.random_color(image)
        #----------------------#
        #   标签归一化
        #----------------------#
        labels = self.convert_labels(points)
        return image,  labels

# -------------------------------------#
#   图片和标签格式转换后,按批次(batch)打包
# -------------------------------------#
def dataloader_collate_fn(batch):
    images = []
    labels = []
    for img, label in batch:
        images.append(transforms.ToTensor()(img))
        labels.append(label)
    return images, labels


if __name__ == '__main__':
    # -------------------------------------#
    #   构建dataset
    # -------------------------------------#
    path = './data/train'
    train_dataset = YourDataset(path)

    # -------------------------------------#
    #   构建Dataloader
    # -------------------------------------#
    dataset = train_dataset
    batch_size = 32
    shuffle = True
    num_workers = 0
    collate_fn = dataloader_collate_fn
    sampler = None
    train_gen = DataLoader(dataset=dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers, pin_memory=True,drop_last=True, collate_fn=collate_fn, sampler=sampler)
    # ---------------------------------------------#
    #   通过迭代的方式,一批一批读取训练集中的图像和标签数据
    # ---------------------------------------------#
    for iter, batch in enumerate(train_gen):
        images,  labels = batch

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1625590.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

文本高效拆分内容,根据空行高效拆分文本内容,文本文档管理更轻松

文本文档是我们日常生活和工作中不可或缺的一部分。然而&#xff0c;随着文本内容的不断增加&#xff0c;如何高效、有序地管理这些文档成为了一个挑战。传统的文本编辑工具往往无法满足我们对于文档整理的需求&#xff0c;而手动整理又费时费力。现在&#xff0c;我们为您带来…

Java实战:确定给定日期是一年的第几天

本次实战&#xff0c;我们将探讨如何确定给定日期是一年中的第几天。为此&#xff0c;我们提供了三种不同的方法&#xff0c;每种方法都有其独特的实现方式和适用场景。 方法一&#xff1a;不使用数组 这种方法通过Scanner类获取用户的输入&#xff0c;包括年份、月份和日期。…

从虚拟化走向云原生,红帽OpenShift“一手托两家”

汽车行业已经迈入“软件定义汽车”的新时代。吉利汽车很清醒地意识到&#xff0c;只有通过云原生技术和数字化转型&#xff0c;才能巩固其作为中国领先汽车制造商的地位。 和很多传统企业一样&#xff0c;吉利汽车在走向云原生的过程中也经历了稳态业务与敏态业务并存带来的前所…

WEB攻防-PHP特性-函数缺陷对比

目录 和 MD5函数 intval ​strpos in_array preg_match str_replace 和 使用 时&#xff0c;如果两个比较的操作数类型不同&#xff0c;PHP 会尝试将它们转换为相同的类型&#xff0c;然后再进行比较。 使用 进行比较时&#xff0c;不仅比较值&#xff0c;还比较变量…

网贷大数据黑名单要多久才能变正常?

网贷大数据黑名单是指个人在网贷平台申请贷款时&#xff0c;因为信用记录较差而被列入黑名单&#xff0c;无法获得贷款或者贷款额度受到限制的情况。网贷大数据黑名单的具体时间因个人信用状况、所属平台政策以及银行审核标准不同而异&#xff0c;一般来说&#xff0c;需要一定…

FebHost:注册国外域名优先考虑可用性还是成本?

在选择域名后缀时&#xff0c;应该优先考虑可用性还是成本&#xff1f;这主要取决于您的具体情况。这两个因素都很重要&#xff0c;您应根据自己的需求进行权衡。 可用性方面&#xff1a;热门的域名后缀&#xff0c;如.com和.net&#xff0c;通常需求量较大&#xff0c;因此可…

数字安全实操AG网址漏洞扫描原理与技术手段分析

在数字化世界的大舞台上&#xff0c;网络安全如同守护者一般&#xff0c;默默保卫着我们的虚拟疆界。当我们在享受互联网带来的便利时&#xff0c;一场无形的战争正在上演。黑客们利用各种手段试图攻破网站的安全防线&#xff0c;而防守方则依靠先进的技术和策略来抵御入侵。其…

安卓studio插件开发(一)本地搭建工程

下载idea 社区版本 建立IDE Plugin工程 点击create就行&#xff0c;新建立的工程长这样 比较重要的文件 build.gradle&#xff1a;配置工程的参数 plugin.xml&#xff1a;设置插件的Action位置 build.gradle.kts内容如下&#xff1a; plugins {id("java")id(&quo…

【VTKExamples::Modelling】第四期 MarchingSquares

很高兴在雪易的CSDN遇见你 VTK技术爱好者 QQ:870202403 公众号:VTK忠粉 前言 本文分享VTK样例MarchingSquares,并解析接口vtkMarchingSquares,希望对各位小伙伴有所帮助! 感谢各位小伙伴的点赞+关注,小易会继续努力分享,一起进步! 你的点赞就是我的动力(^U…

c# winform 控件皮肤

控件皮肤下载&#xff1a; https://download.csdn.net/download/m0_46973223/89225992 步骤&#xff1a; 第一步 将IrisSkin4.dll文件放在debug文件下&#xff0c;选一个或者多个后缀名为.ssk文件&#xff08;各个皮肤文件&#xff09;放在debug文件下。 第二步 解决方案资…

【算法刷题】手撕LRU算法(原理、图解、核心思想)

文章目录 1.LRU算法1.1相关概念1.2图解举例1.3基于HashMap和双向链表实现1.3.1核心思想1.3.2代码解读1.3.3全部代码 1.LRU算法 1.1相关概念 LRU&#xff08;Least Recently Used&#xff0c;最近最久未使用算法&#xff09;&#xff1a; 定义&#xff1a;根据页面调入内存后的…

nvm管理多个node版本,快速来回切换node版本

前言 文章基于 windows环境 使用nvm安装多版本nodejs。 最近公司有的项目比较老需要降低node版本才能运行&#xff0c;由于来回进行卸载不同版本的node比较麻烦&#xff1b;所以需要使用node工程多版本管理&#xff0c;后面自己就简单捯饬了一下nvm来管理node&#xff0c;顺便…

【资源分享】Latex2024安装教程

::: block-1 “时问桫椤”是一个致力于为本科生到研究生教育阶段提供帮助的不太正式的公众号。我们旨在在大家感到困惑、痛苦或面临困难时伸出援手。通过总结广大研究生的经验&#xff0c;帮助大家尽早适应研究生生活&#xff0c;尽快了解科研的本质。祝一切顺利&#xff01;—…

路由重分布的概念与配置

路由重分布的概念 l 路由重分布是指连接不同路由域&#xff08;自治系统&#xff09;的边界路由器&#xff0c;它在路由协议之间交换和通告路由信息 从一种协议&#xff08;含静态/直连路由&#xff09;到另一种协议 同一种协议的多个实例 路由重分布的背景 网络出口位置…

宝宝洗衣机买什么样的好?诚意推荐四款实力超群的婴儿洗衣机

近几年家用洗衣机标准容积的大大增加&#xff0c;从5Kg、6Kg升级到9Kg、10Kg。大容量洗衣机满足了家庭中清洗大件衣物、床上用品的需求。但由于普通大型洗衣机所洗衣物混杂&#xff0c;很多时候由于宝宝小件衣物数量不多&#xff0c;却也并不适合放在一起扔进大型洗衣机中清洗。…

lesson05:C++内存管理

1.内存分布 2.c中动态内存管理 3.operator new和operator delete函数 4.new和delete实现原理 1.内存分布 1.1常见的内存分布 1.2相关问题 答案&#xff1a;CCCAA AAADAB 我们讲以下易错的部分&#xff1a; 7.数组char2是在栈上开的空间&#xff0c;然后将"a…

golang学习笔记(net/http库基本使用)

关于net/http库 我们先看看标准库net/http如何处理一个请求。 import ("fmt""log""net/http" )var count 0func main() {http.HandleFunc("/", handler)http.HandleFunc("/count", counter)log.Fatal(http.ListenAndServ…

STM32_舵机的实战

一、配置相应的管脚 二、写代码

【OceanBase诊断调优】——hpet(高精度时钟源)引起的CPU高问题排查

最近总结一些诊断OCeanBase的一些经验&#xff0c;出一个【OceanBase诊断调优】专题出来&#xff0c;也欢迎大家贡献自己的诊断OceanBase的方法。 1. 前言 昨天在问答区帮忙排查一个用户CPU高的问题&#xff0c;帖子链接&#xff1a;《刚刚新安装的OceanBase集群&#xff0c;…

leetcode 221 最大正方形面积

示例 3&#xff1a; 输入&#xff1a;matrix [["0"]] 输出&#xff1a;0 # 最大正方形面积 def max_square(matrix):m len(matrix)n len(matrix[0])if m 0 or n 0::return Nonemax_side 1dp [[0] * (n 1) for _ in range(m 1)]for i in range(1, m 1):fo…