Pytorch中Dataset和dadaloader的理解

news2024/11/17 16:40:12

不同的数据集在形式上千差万别,为了能够统一用于模型的训练,Pytorch框架下定义了一个dataset类和一个dataloader类。

dataset用于获取数据集中的样本,dataloader 用于抽取部分样本用于训练。比如说一个用于分割任务的图像数据集的结构如图1所示,一个样本由原图像和对应的mask组成。

图1 典型数据集的结构

为了获取数据集,典型的代码如下

from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import os
from torchvision import transforms

# 定义数据集
train_data_dir = 'dataset/train'
train_GT_dir = 'dataset/train_GT'

class MyData(Dataset):
    def __init__(self, imgdir, maskdir,transform):
        self.imgdir = imgdir
        self.maskdir = maskdir
        self.transform = transform
        self.img_list = os.listdir(self.imgdir)
        self.mask_list= os.listdir(self.maskdir)
        self.img_list.sort()
        self.mask_list.sort()

    def __getitem__(self, idx):
        img_name = self.img_list[idx]
        mask_name =self.mask_list[idx]
        img_item_path = os.path.join(self.imgdir, img_name)
        mask_item_path =os.path.join(self.maskdir,mask_name)

        img =Image.open(img_item_path)
        mask =Image.open(mask_item_path)

        img = self.transform(img)
        mask = self.transform(mask)

        return img, mask

    def __len__(self):
        assert len(self.img_list) == len(self.mask_list)
        return len(self.img_list)

if __name__ == '__main__':
    transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
    train_data_dir = 'dataset/train'
    train_GT_dir = 'dataset/train_GT'
    dataset = MyData(train_data_dir, train_GT_dir ,transform)
    dataloader = DataLoader(dataset, batch_size=4, num_workers=0)
    for step, (img,mask) in enumerate(dataloader):
        print(step)
        print(img.shape)
        print(mask.shape)
        if step>0:
            break

程序运行的结果如下:

返回了一个batch的img 和mask 的尺寸,说明数据集抽取成功了.

在建立数据集的过程中需用重写__getitem()__和__len()__方法即可。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1407571.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

蓝桥杯---三羊献瑞

观察下面的加法算式: 其中,相同的汉字代表相同的数字,不同的汉字代表不同的数字。 请你填写“三羊献瑞”所代表的4位数字(答案唯一),不要填写任何多余内容。 答案 代码 public class _03三羊献瑞 {public static void main(String[] args) {//c 生 b 瑞 g 献 d 辉…

Flink多流转换(1)—— 分流合流

目录 分流 代码示例 使用侧输出流 合流 联合(Union) 连接(Connect) 简单划分的话,多流转换可以分为“分流”和“合流”两大类 目前分流的操作一般是通过侧输出流(side output)来实现&…

【后端技术】术有千法,道本归一

目录 1.概述 2.机器的问题 2.1.计算 2.2.存储 2.3.传输 3.人的问题 3.1.代码工程的管理 3.2.过程的把控 4.总结 1.概述 术有千法,道本归一。 之所以这样说,是因为当前出现的纷繁复杂的后端技术,其本质其实都是为了解决同一套问题。…

蓝桥杯题目解析 --矩形切割

先看题: 题目中的例子解析: 5*3 切一刀最大的,肯定是3*3,切完后只剩2*3,切一刀最大的,肯定是2*2,切完后只剩2*1,切一刀最大的,肯定是1*1,切完后只剩1*1&…

浅谈手机APP测试(流程)

前言 APP测试是一个广泛的概念,根据每个app的应用场景不一样,测试的方向也略微的不同,在测试过程中需要灵活应用自身所知的测试手段。 今天就跟大家简单聊聊手机APP测试的一些相关内容。 APP开发流程 (1) 拿到需求分…

【STM32】USB程序烧录需要重新上电 软件复位方法

文章目录 一、问题二、解决思路2.1 直接插拔USB2.2 给芯片复位 三、解决方法3.1 别人的解决方法3.2 在下载界面进行设置 一、问题 最近学习STM32的USB功能,主要是想要使用虚拟串口功能(VCP),发现每次烧录之后都需要重新上电才可以…

Parade Series - Android Studio

硬件支持 CPU i7 RAM 16Gb -------------- ------- Java 3Gb Android 33GbJava Enviroment C:\ ├─ Java │ ├─ jdk1.8.0_181 │ ├─ jre1.8.0_181 │ ├─ maven-3.8.5 │ └─ gradle-6.5 └─ Cache├─ gr…

Python实现力扣经典面试题——删除有序数组中的重复项

题目:删除有序数组中的重复项 给你一个 非严格递增排列 的数组 nums ,请你 原地 删除重复出现的元素,使每个元素 只出现一次 ,返回删除后数组的新长度。元素的 相对顺序 应该保持 一致 。然后返回 nums 中唯一元素的个数。考虑 nu…

在Word中插入高亮/好看代码

Md2All 一个markdown工具 Md2All 网址 代码一定要用code 高亮主题可选 atom-one-light > 复制到word > 调节字体可选Cnsolas, 间距等 效果 另一个高亮工具 效果

算法/结构/理论复习1---理论基础----更新中

算法/结构/理论 雪花算法CAP理论BASE理论分布式事务的解决方案数据结构树(Tree)二叉树二叉查找树平衡查找树红黑树(重点)BTree(重点)BTree 雪花算法 雪花算法主要是为了解决在分布式中id的生成问题 分布式id的生成规则是:全局唯一,不可以出现重复的id号,趋势递增 雪花算法指的…

ITK编译及安装

文章目录 前言CMake配置选项说明运行VS2015编译及安装VTK转ITKITK转VTK参考文献 最近想利用ITK读取整个Dicom图像到内存,再将读取到的ITK数据转换到VTK。于是乎,开始了一段ITK编译之路。以下将记录一些有用的信息,以备后用。 前言 DICOM图像…

Spring扩展点在微服务应用(待完善)

ApplicationListener扩展 nacos注册服务, 监听容器发布事件 # 容器发布事件 AbstractAutoServiceRegistration#onApplicationEvent # 接收事件吗,注册服务到nacos NacosServiceRegistry#register Lifecycle扩展 #订阅服务实例更改的事件 NamingService#…

超实用桌面助手!时间、日期、天气,一目了然!完全免费!

文章目录 📖 介绍 📖🏡 环境 🏡📒 使用方法 📒⚓️ 相关链接 ⚓️ 📖 介绍 📖 这是一款我根据自己的需求写的一个桌面小工具,自己一直在用,现在分享给需要的朋…

RabbitMQ发布确认

生产者将信道设置成 confirm 模式,一旦信道进入 confirm 模式, 所有在该信道上面发布的消息都将会被指派一个唯一的 ID(从 1 开始),一旦消息被投递到所有匹配的队列之后, broker就会发送一个确认给生产者(包含消息的唯一 ID)&…

【每日一题】最长交替子数组

文章目录 Tag题目来源解题思路方法一:双层循环方法二:单层循环 写在最后 Tag 【双层循环】【单层循环】【数组】【2024-01-23】 题目来源 2765. 最长交替子数组 解题思路 两个方法,一个是双层循环,一个是单层循环。 方法一&am…

Likeshop单商户SaaS商城源码系统-商家用过都说太香啦!

在互联网快速发展的时代,拥有一个个性化、功能丰富的在线商城是企业拓展市场、提高用户粘性的重要手段。 我是一名电商从业者,同时也是一个热衷于DIY的人,我总喜欢在自己的店铺中加入自己的一些想法和创意。然而,一般的电商平台无…

【思路合集】talking head generation+stable diffusion

1 以DiffusionVideoEditing为baseline: 改进方向 针对于自回归训练方式可能导致的漂移问题: 训练时,在前一帧上引入小量的面部扭曲,模拟在生成过程中自然发生的扭曲。促使模型查看身份帧以进行修正。在像VoxCeleb或LRS这样的具…

EasyX的安装与使用(VisualStudio C++免费绘图库)

EasyX Graphics Library 是针对 Visual C 的免费绘图库 安装教程 安装到Visual C 2010 EasyX 安装完毕。 在VC2010中建立控制台工程 工程建好后,鼠标右键点击工程名,并选择属性 安装到Visual C 2010 EasyX 安装完毕。 安装示例程序 easyxdemo.cpp 在VC…

Vulnhub-dc4

靶场下载 https://download.vulnhub.com/dc/DC-4.zip 信息收集 判断目标靶机的存活地址: # nmap -sT --min-rate 10000 -p- 192.168.1.91 -oN port.nmap Starting Nmap 7.94 ( https://nmap.org ) at 2024-01-21 16:36 CST Stats: 0:00:03 elapsed; 0 hosts completed (1 up…

机器学习 | 掌握Matplotlib的可视化图表操作

Matplotlib是python的一个数据可视化库,用于创建静态、动态和交互式图表。它可以制作多种类型的图表,如折线图、散点图、柱状图、饼图、直方图、3D 图形等。以渐进、交互式方式实现数据可视化。当然博主也不能面面俱到的讲解到所有内容,详情请…