【深度学习实战(11)】搭建自己的dataset和dataloader

news2025/1/12 8:05:11

一、dataset和dataloader要点说明

在我们搭建自己的网络时,往往需要定义自己的datasetdataloader,将图像和标签数据送入模型。
(1)在我们定义dataset时,需要继承torch.utils.data.dataset,再重写三个方法:

  • init方法,主要用来定义数据的预处理
  • getitem方法,数据增强;返回数据的item和label
  • len方法,返回数据数量

(2)在我们定义dataloader时,需要考虑下面几个参数:

  • dataset :使用哪个数据集
  • batch_size:将数据集拆成一组多少个进行训练
  • shuffle:是否需要打乱数据
  • num_workers:几个mini_batch并行计算,一般<=你的电脑cpu数目
  • collect_fn:数据打包方式

(3)通过迭代的方式,按批次,获取dataloader中的数据

(4)关系图

在这里插入图片描述

二、核心代码框架

import os
import cv2
from torchvision import transforms
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader


# -------------------------------------------------------------#
#   自定义dataset需要继承torch.utils.data.dataset,
#   再重写def __init__,def __len__,def __getitem__三个方法
# -------------------------------------------------------------#
class YourDataset(Dataset):
    def __init__(self,  root_path):
        super(YourDataset, self).__init__()
        self.root_path = root_path
        #-------------------------------------------------------------------------#
        #   获取样本名,以jpg原始图片为参考,修改后缀名为json,png,获取json,png标签文件路径
        #-------------------------------------------------------------------------#
        self.sample_names = []
        jpg_path = os.path.join(os.path.join(self.root_path, "images"),)
        for file in os.listdir(jpg_path):
            if file.endswith(".jpg"):
                self.sample_names.append(os.path.splitext(file)[0]) # 去掉.json

    def __len__(self):
        #----------------------#
        #   返回数据数量
        #----------------------#
        return len(self.sample_names)

    def __getitem__(self, index):
        name = self.sample_names[index]

        # ----------------------#
        #   读取图像
        # ----------------------#
        img_path = os.path.join(os.path.join(self.root_path, "images"), name + '.jpg')
        image = cv2.imread(img_path)
        # ----------------------#
        #   读取标签
        # ----------------------#
        label_path = os.path.join(os.path.join(self.root_path, "jsons"), name + '.json')
        with open(label_path) as label_file:
            points = self.get_data_from_json(label_file)
        #----------------------#
        #   图像数据增强
        #----------------------#
        image = self.random_color(image)
        #----------------------#
        #   标签归一化
        #----------------------#
        labels = self.convert_labels(points)
        return image,  labels

# -------------------------------------#
#   图片和标签格式转换后,按批次(batch)打包
# -------------------------------------#
def dataloader_collate_fn(batch):
    images = []
    labels = []
    for img, label in batch:
        images.append(transforms.ToTensor()(img))
        labels.append(label)
    return images, labels


if __name__ == '__main__':
    # -------------------------------------#
    #   构建dataset
    # -------------------------------------#
    path = './data/train'
    train_dataset = YourDataset(path)

    # -------------------------------------#
    #   构建Dataloader
    # -------------------------------------#
    dataset = train_dataset
    batch_size = 32
    shuffle = True
    num_workers = 0
    collate_fn = dataloader_collate_fn
    sampler = None
    train_gen = DataLoader(dataset=dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers, pin_memory=True,drop_last=True, collate_fn=collate_fn, sampler=sampler)
    # ---------------------------------------------#
    #   通过迭代的方式,一批一批读取训练集中的图像和标签数据
    # ---------------------------------------------#
    for iter, batch in enumerate(train_gen):
        images,  labels = batch

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1609533.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LeetCode in Python 55. Jump Game (跳跃游戏)

跳跃游戏的游戏规则比较简单&#xff0c;若单纯枚举所有的跳法以判断是否能到达最后一个下标需要的时间复杂度为O()&#xff0c;为此&#xff0c;本文采用贪心策略&#xff0c;从最后一个下标开始逆着向前走&#xff0c;若能跳到第一个元素则表明可以完成跳跃游戏&#xff0c;反…

Python基础学习之数据切片

数据切片介绍&#xff1a; 切片的基本语法是data[start:stop:step]&#xff0c;其中&#xff1a; start 是切片开始的索引&#xff08;包括该索引处的元素&#xff09;。 stop 是切片结束的索引&#xff08;不包括该索引处的元素&#xff09;。 step 是切片的步长&#xff0…

Go小技巧易错点100例(十五)

本期看点&#xff1a; 正文开始&#xff1a; Go程序跟踪函数的执行时间 在Go程序中我们经常会对接口执行的耗时做一个记录&#xff0c;特别是针对核心或复杂业务的时候&#xff0c;我们需要关注该业务的执行耗时&#xff0c;可以具体到某个方法&#xff0c;有一个简单有效的技…

Linux:常用软件、工具和周边知识介绍

上次也是结束了权限相关的知识&#xff1a;Linux&#xff1a;权限相关知识详解 文章目录 1.yum-管理软件包的工具1.1基本介绍1.2yum的使用1.3yum的周边生态1.4软件包介绍 2.vim-多模式的文本编辑器2.1基本介绍2.2基本模式介绍2.2.1命令模式&#xff08;Normal mode&#xff09;…

Python数据可视化库—Bokeh与Altair指南

&#x1f47d;发现宝藏 前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。【点击进入巨牛的人工智能学习网站】。 在数据科学和数据分析领域&#xff0c;数据可视化是一种强大的工具&#xff0c;可以帮助我们…

循序渐进丨使用 Python 向 MogDB 数据库批量操作数据的方法

当我们有时候需要向数据库里批量插入数据&#xff0c;或者批量导出数据时&#xff0c;除了使用传统的gsql copy命令&#xff0c;也可以通过Python的驱动psycopg2进行批量操作。本文介绍了使用psycopg2里的executemany、copy_from、copy_to、copy_expert等方式来批量操作 MogDB …

Darknet框架优化介绍

一、DarkNet框架简介 1.DarkNet的简介 Darknet是一个完全使用C语言编写的人工智能框架&#xff0c;可以使用CUDA的开源框架。主要应用于图像识别领域。 它具有可移植性好&#xff0c;安装间接&#xff0c;查看源码方便等优势&#xff0c;提供了OpenCV等附加选项&#xff0c;还…

【SpringBoot实战篇】获取用户详细信息-ThreadLocal优化

1 分析问题 对token的解析当初在拦截器中已经写过。期待的是在拦截器里写了&#xff0c;在其他地方就不写了&#xff0c;应该去复用拦截器里面得到的结果 2 解决方式-ThreadLocal 2.1提供线程局部变量 用来存取数据: set()/get()使用ThreadLocal存储的数据, 线程安全 2.2过程图…

Java如何用EasyExcel插件对Excel进行数据导入和数据导出

文章目录 一、EasyExcel的示例导入依赖创建实体类数据导入和导出 二、EasyExcel的作用三、EasyExcel的注解 EasyExcel是一个阿里巴巴开源的excel处理框架&#xff0c;它以使用简单、节省内存著称。在解析Excel时&#xff0c;EasyExcel没有将文件数据一次性全部加载到内存中&…

java-springmvc 01

MVC就是和Tomcat有关。 01.MVC启动的第一步&#xff0c;启动Tomcat 02.Tomcat会解析web-inf的web.xml文件

【Flutter】自动生成图片资源索引插件二:FlutterAssetsGenerator

介绍 FlutterAssetsGenerator 插件 &#xff1a;没乱码&#xff0c;生成的图片索引命名是小驼峰 目录 介绍一、安装二、使用 一、安装 1.安装FlutterAssetsGenerator 插件 生成的资源索引类可以修改名字&#xff0c;我这里改成R 2. 根目录下创建assets/images 3. 点击image…

上网行为管理系统功能介绍_上网行为管理实现的功能

上网行为管理系统是一种集成了网络监控、行为分析、策略管理和安全控制等功能的综合性软件解决方案。 它通过对企业内部网络的全面监控和深度分析&#xff0c;帮助管理者了解员工的网络使用习惯、识别潜在风险、优化网络资源配置&#xff0c;并最终实现网络安全和效率的双重提…

实在IDP文档审阅产品导引

实在IDP文档审阅&#xff1a;智能文档处理的革新者 一、引言 在数字化转型的浪潮中&#xff0c;文档处理的智能化成为企业提效的关键。实在智能科技有限公司推出的实在IDP文档审阅&#xff0c;是一款利用AI技术快速理解、处理文档的智能平台&#xff0c;旨在为企业打造专属的…

偏微分方程算法之二阶双曲型方程显式差分法

目录 一、研究目标 二、理论推导 2.1 三层显格式建立 2.2 三层显格式改进 三、算例实现 3.1 一阶显格式 3.2 二阶显格式 一、研究目标 介绍完一阶双曲型偏微分方程的几种差分格式后&#xff0c;我们继续探讨二阶方程的差分格式。这里以非齐次二阶双曲型偏微分方程的初边…

Android11 SystemUI clock plugin 插件入门

插件的编写 参照ExamplePlugin&#xff0c;需要系统签名。 需要先编译以下模块得到jar&#xff0c;引用在项目中。 m SystemUIPluginLibcom.android.systemui.permission.PLUGIN PluginManager.addPluginListener SystemUI 是如何发现 clock plugin 的&#xff1f; Syste…

【管理咨询宝藏78】MBB大型城投集团核心能力建设分析报告

本报告首发于公号“管理咨询宝藏”&#xff0c;如需阅读完整版报告内容&#xff0c;请查阅公号“管理咨询宝藏”。 【管理咨询宝藏78】MBB大型城投集团核心能力建设分析报告 【格式】PDF版本 【关键词】战略规划、商业分析、管理咨询、MBB顶级咨询公司 【强烈推荐】 这是一套…

【服务器部署篇】Linux下Redis安装

作者介绍&#xff1a;本人笔名姑苏老陈&#xff0c;从事JAVA开发工作十多年了&#xff0c;带过刚毕业的实习生&#xff0c;也带过技术团队。最近有个朋友的表弟&#xff0c;马上要大学毕业了&#xff0c;想从事JAVA开发工作&#xff0c;但不知道从何处入手。于是&#xff0c;产…

【数据结构】单链表经典算法题的巧妙解题思路

目录 题目 1.移除链表元素 2.反转链表 3.链表的中间节点 4.合并两个有序链表 5.环形链表的约瑟夫问题 解析 题目1&#xff1a;创建新链表 题目2&#xff1a;巧用三个指针 题目3&#xff1a;快慢指针 题目4&#xff1a;哨兵位节点 题目5&#xff1a;环形链表 介绍完了…

cdh cm界面HDFS爆红:不良 : 该 DataNode 当前有 1 个卷故障。 临界阈值:任意。(Linux磁盘修复)

一、表现 1.cm界面 报错卷故障 检查该节点&#xff0c;发现存储大小和其他节点不一致&#xff0c;少了一块物理磁盘 2.查看该磁盘 目录无法访问 dmesg检查发现错误 dmesg | grep error二、解决办法 移除挂载 umount /data10 #可以移除挂载盘&#xff0c;或者移除挂载目…

【爬虫】多线程爬取图片

多线程爬虫 多线程爬虫概述1.1 多线程的优势1.2 多线程的挑战 设计多线程爬虫1.1 项目设计1.2 项目流程1.3注意事项 总结 多线程爬虫概述 在当今信息爆炸的时代&#xff0c;网络爬虫&#xff08;Web Scraper&#xff09;已成为获取和分析网络数据的重要工具。而多线程爬虫&…