pytorch中dataloader自定义数据集

news2025/4/28 9:29:38

前言

在深度学习中我们需要使用自己的数据集做训练,因此需要将自定义的数据和标签加载到pytorch里面的dataloader里,也就是自实现一个dataloader。

数据集处理

以花卉识别项目为例,我们分别做出图片的训练集和测试集,训练集的标签和测试集的标签

flower_data/
├── train_filelist/
│   ├── image_0001.jpg
│   └── ...
├── val_filelist/
│   ├── image_1001.jpg
│   └── ...
├── train.txt  # 格式:文件名 标签
└── val.txt

 数据目录的组织方式如上所示。

首先看图片的处理。图片只要做好编号放在同一个文件夹里就好了。

再看标签的处理。标签处理我们自己规定了一种形式,就是图像文件的名称+空格+分类标签。

可以看到前面第一列数据是图像名称,第二列数据是图像的分组,同样的数字为一组。比如分组为0的图像就是同一种花朵。

自定义dataset

源码

import os.path
import numpy as np
import torch
from PIL import Image  # 从PIL库导入Image类
from torch.utils.data import Dataset


class FlowerDataSet(Dataset):
    """花朵分类任务数据集类,继承自torch的Dataset类"""

    def __init__(self, root_dir, ann_file, transform=None):
        """
        初始化数据集实例

        Args:
            root_dir (str): 数据集根目录路径
            ann_file (str): 标注文件路径
            transform (callable, optional): 数据预处理变换函数
        """
        self.ann_file = ann_file
        self.root_dir = root_dir
        # 加载图片路径与标签的映射字典 {文件名: 标签}
        self.image_label = self.load_annotations()
        # 构建完整图片路径列表 [root_dir/文件名1, ...]
        self.image = [os.path.join(self.root_dir, img) for img in list(self.image_label.keys())]
        # 构建标签列表 [标签1, 标签2, ...]
        self.label = [lbl for lbl in list(self.image_label.values())]  # 重命名为lbl避免与导入的label冲突
        self.transform = transform

    def __len__(self):
        """返回数据集样本数量"""
        return len(self.image)

    def __getitem__(self, index):
        """
        获取单个样本数据

        Args:
            index (int): 样本索引

        Returns:
            tuple: (预处理后的图像数据, 对应的标签)
        """
        # 打开图片文件
        image = Image.open(self.image[index])
        # 获取对应标签
        label = self.label[index]

        # 应用数据预处理
        if self.transform:
            image = self.transform(image)

        # 将标签转换为torch张量
        label = torch.from_numpy(np.array(label))
        return image, label

    def load_annotations(self):
        """
        加载标注文件,解析图片文件名和标签的映射关系

        Returns:
            dict: {图片文件名: 对应标签} 的字典
        """
        data_infos = {}
        with open(self.ann_file) as f:
            # 读取所有行并分割,每行格式应为 "文件名 标签"
            samples = [x.strip().split(' ') for x in f.readlines()]
            for filename, label in samples:
                # 将标签转换为int64类型的numpy数组
                data_infos[filename] = np.array(label, dtype=np.int64)
        return data_infos

解析

1、将标签数据进行读取,组成一个哈希表,哈希表的键是图像的文件名称,哈希表的值是分组标签。

    def load_annotations(self):
        """
        加载标注文件,解析图片文件名和标签的映射关系

        Returns:
            dict: {图片文件名: 对应标签} 的字典
        """
        data_infos = {}
        with open(self.ann_file) as f:
            # 读取所有行并分割,每行格式应为 "文件名 标签"
            samples = [x.strip().split(' ') for x in f.readlines()]
            for filename, label in samples:
                # 将标签转换为int64类型的numpy数组
                data_infos[filename] = np.array(label, dtype=np.int64)
        return data_infos

上面的代码里,在录入标签的时候使用数组进行记录,这是为了兼容多标签的场景。如果不考虑兼容问题,仅考虑在单标签场景下的简单实现,可以用下面的代码:

def load_annotations(self):
    data_infos = {}
    with open(self.ann_file) as f:
        for line in f:
            filename, label = line.strip().split()  # 直接解包
            data_infos[filename] = int(label)        # 存为 Python 整数
    return data_infos

# 在 __getitem__ 中直接转为张量
label = torch.tensor(self.labels[index], dtype=torch.long)

2、遍历哈希表,将文件名和标签分别存在两个数组里。这里注意,为了方便后面dataloader按照batch去读取图片,这里要将图片的全路径加到文件名里。

        # 构建完整图片路径列表 [root_dir/文件名1, ...]
        self.image = [os.path.join(self.root_dir, img) for img in list(self.image_label.keys())]
        # 构建标签列表 [标签1, 标签2, ...]
        self.label = [lbl for lbl in list(self.image_label.values())]  # 重命名为lbl避免与导入的label冲突

3、在dataloader向显卡/cpu加载数据的时候会调用getitem方法。比如一个batch里有64个数据,dataloader就会调用64次该方法,将64组图片和标签全部获取后交给运算单元去处理。

    def __getitem__(self, index):
        """
        获取单个样本数据

        Args:
            index (int): 样本索引

        Returns:
            tuple: (预处理后的图像数据, 对应的标签)
        """
        # 打开图片文件
        image = Image.open(self.image[index])
        # 获取对应标签
        label = self.label[index]

        # 应用数据预处理
        if self.transform:
            image = self.transform(image)

        # 将标签转换为torch张量
        label = torch.from_numpy(np.array(label))
        return image, label

测试dataloader

import os
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms
from dataloader import FlowerDataSet  # 假设你的数据集类在dataloader.py中


def denormalize(image_tensor):
    """将归一化的图像张量转换为可显示的格式"""
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image = image_tensor.numpy().transpose((1, 2, 0))  # 转换维度顺序
    image = std * image + mean  # 反归一化
    image = np.clip(image, 0, 1)  # 限制像素值范围
    return image


def test_dataloader():
    # 定义数据预处理
    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize(64),
            transforms.RandomRotation(45),
            transforms.CenterCrop(64),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'valid': transforms.Compose([
            transforms.Resize(64),
            transforms.CenterCrop(64),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    }

    # 检查文件路径是否存在
    print("[1/5] 检查文件路径...")
    required_files = {
        'train_txt': './flower_data/train.txt',
        'val_txt': './flower_data/val.txt',
        'train_dir': './flower_data/train_filelist',
        'val_dir': './flower_data/val_filelist'
    }

    for name, path in required_files.items():
        if not os.path.exists(path):
            print(f"❌ 文件/目录不存在: {path}")
            return
        print(f"✅ {name}: {path} 存在")

    # 初始化数据集
    print("\n[2/5] 加载数据集...")
    try:
        train_dataset = FlowerDataSet(
            root_dir=required_files['train_dir'],
            ann_file=required_files['train_txt'],
            transform=data_transforms['train']
        )
        val_dataset = FlowerDataSet(
            root_dir=required_files['val_dir'],
            ann_file=required_files['val_txt'],
            transform=data_transforms['valid']
        )
        print("✅ 数据集加载成功")
    except Exception as e:
        print(f"❌ 数据集加载失败: {str(e)}")
        return

    # 打印数据集信息
    print("\n[3/5] 数据集统计:")
    print(f"训练集样本数: {len(train_dataset)}")
    print(f"验证集样本数: {len(val_dataset)}")

    # 检查单个样本
    print("\n[4/5] 检查单个样本:")
    sample_idx = 0
    try:
        img, label = train_dataset[sample_idx]
        print(f"图像张量形状: {img.shape} (应接近 torch.Size([3, 64, 64]))")
        print(f"标签类型: {type(label)} (应为 torch.Tensor)")
        print(f"标签值: {label.item()} (应为整数)")
    except Exception as e:
        print(f"❌ 样本检查失败: {str(e)}")

    # 可视化样本
    print("\n[5/5] 可视化训练集样本...")
    try:
        plt.figure(figsize=(8, 8))
        img_show = denormalize(img)
        plt.imshow(img_show)
        plt.title(f"Label: {label.item()}")
        plt.axis('off')
        plt.show()
    except Exception as e:
        print(f"❌ 可视化失败: {str(e)}")

    # 检查DataLoader
    print("\n[附加] 检查DataLoader:")
    train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)

    for loader, name in [(train_loader, '训练集'), (val_loader, '验证集')]:
        print(f"\n{name} DataLoader测试:")
        try:
            batch = next(iter(loader))
            images, labels = batch
            print(f"批次图像形状: {images.shape} (应接近 [batch, 3, 64, 64])")
            print(f"批次标签示例: {labels[:5].numpy()}")
            print(f"像素值范围: [{images.min():.3f}, {images.max():.3f}]")
        except Exception as e:
            print(f"❌ {name} DataLoader错误: {str(e)}")


if __name__ == '__main__':
    test_dataloader()

在测试代码中,分别测试了文件路径,dataset是否正常创建,dataset样本数量,dataset样本格式,dataset数据可视化,dataloader数据样式。

在打印日志的时候需要注意,dataset和dataloader里面的变量都是张量形式的,所以需要转换成python标量再打印。比如从dataset里取出的标签label是一个一维张量,需要通过label.item()进行转换。

 在遍历的时候为了简化代码,将两个dataloader放在同一个循环语句中处理,并且通过增加name变量来区分两个dataloader。

for loader, name in [(train_loader, '训练集'), (val_loader, '验证集')]:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2326654.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SQL Server:触发器

在 SQL Server Management Studio (SSMS) 中查看数据库触发器的方法如下: 方法一:通过对象资源管理器 连接到 SQL Server 打开 SSMS,连接到目标数据库所在的服务器。 定位到数据库 在左侧的 对象资源管理器 中,展开目标数据库&a…

标题:利用 Rork 打造定制旅游计划应用程序:一步到位的指南

引言: 在数字化时代,旅游计划应用程序已经成为旅行者不可或缺的工具。但开发一个定制的旅游应用可能需要耗费大量时间与精力。好消息是,Rork 提供了一种快捷且智能的解决方案,让你能轻松实现创意。以下是使用 Rork 创建一个定制旅…

WebSocket原理详解(二)

WebSocket原理详解(一)-CSDN博客 目录 1.WebSocket协议的帧数据详解 1.1.帧结构 1.2.生成数据帧 2.WebSocket协议控制帧结构详解 2.1.关闭帧 2.2.ping帧 2.3.pong帧 3.WebSocket心跳机制 1.WebSocket协议的帧数据详解 1.1.帧结构 WebSocket客户端与服务器通信的最小单…

计算声音信号波形的谐波

计算声音信号波形的谐波 1、效果 2、定义 在振动分析中,谐波通常指的是信号中频率是基频整数倍的成分。基频是振动的主要频率,而谐波可能由机械系统中的非线性因素引起。 3、流程 1. 信号生成:生成或加载振动信号数据(模拟或实际数据)。 2. 预处理:预处理数据,如去噪…

RepoReporter 仿照`TortoiseSVN`项目监视器,能够同时支持SVN和Git仓库

RepoReporter 项目地址 RepoReporter 一个仓库监视器,仿照TortoiseSVN项目监视器,能够同时支持SVN和Git仓库。 工作和学习会用到很多的仓库,每天都要花费大量的时间在频繁切换文件夹来查看日志上。 Git 的 GUI 工具琳琅满目,Git…

UI设计系统:如何构建一套高效的设计规范?

UI设计系统:如何构建一套高效的设计规范? 1. 色彩系统的建立与应用 色彩系统是设计系统的基础之一,它不仅影响界面的整体美感,还对用户体验有着深远的影响。首先,设计师需要定义主色调、辅助色和强调色,并…

【计算机网络】记录一次校园网无法上网的解决方法

问题现象 环境:实训室教室内时间:近期突然出现 (推测是学校在施工,部分设备可能出现问题)症状: 连接校园网 SWXY-WIFI 后: 连接速度极慢偶发无 IP 分配(DHCP 失败)即使分…

第二十一章:Python-Plotly库实现数据动态可视化

Plotly是一个强大的Python可视化库,支持创建高质量的静态、动态和交互式图表。它特别擅长于绘制三维图形,能够直观地展示复杂的数据关系。本文将介绍如何使用Plotly库实现函数的二维和三维可视化,并提供一些优美的三维函数示例。资源绑定附上…

系统思考反馈

最近交付的都是一些持续性的项目,越来越感觉到,系统思考和第五项修炼不只是简单的一门课程,它们能真正融入到我们的日常工作和业务中,帮助我们用更清晰的思维方式解决复杂问题,推动团队协作,激发创新。 特…

【C++】vector常用方法总结

📝前言: 在C中string常用方法总结中我们讲述了string的常见用法,vector中许多接口与string类似,作者水平有限,所以这篇文章我们主要通过读vector官方文档的方式来学习vector中一些较为常见的重要用法。 🎬个…

2025年数智化电商产业带发展研究报告260+份汇总解读|附PDF下载

原文链接:https://tecdat.cn/?p41286 在数字技术与实体经济深度融合的当下,数智化产业带正成为经济发展的关键引擎。 从云南鲜花产业带的直播热销到深圳3C数码的智能转型,数智化正重塑产业格局。2023年数字经济规模突破53.9万亿元&#xff…

Linux中常用服务器监测命令(性能测试监控服务器实用指令)

1.查看进程 ps -ef|grep 进程名以下指令需要先安装:sysstat,安装指令: yum install sysstat2.查看CPU使用情况(间隔1s打印一个,打印6次) sar -u 1 63.#查看内存使用(间隔1s打印一个,打印6次) sar -r 1 6

基于 GEE 的区域降水数据可视化:从数据处理到等值线绘制

目录 1 引言 2 代码功能概述 3 代码详细解析 3.1 几何对象处理与地图显示 3.2 加载 CHIRPS 降水数据 3.3 筛选不同时间段的降水数据 3.4 绘制降水时间序列图 3.5 计算并可视化短期和长期降水总量 3.6 绘制降水等值线图 4 总结 5 完整代码 6 运行结果 1 引言 在气象…

曲线拟合 | Matlab基于贝叶斯多项式的曲线拟合

效果一览 代码功能 代码功能简述 目标:实现贝叶斯多项式曲线拟合,动态展示随着数据点逐步增加,模型后验分布的更新过程。 核心步骤: 数据生成:在区间[0,1]生成带噪声的正弦曲线作为训练数据。 参数设置&#xff1a…

Qt6调试项目找不到Bluetooth Component蓝牙组件

错误如图所示 Failed to find required Qt component "Bluetooth" 解决方法:搜索打开Qt maintenance tool 工具 打开后,找到这个Qt Connectivity,勾选上就能解决该错误

JAVA- 锁机制介绍 进程锁

进程锁 基于文件的锁基于Socket的锁数据库锁分布式锁基于Redis的分布式锁基于ZooKeeper的分布式锁 实际工作中都是集群部署,通过负载均衡多台服务器工作,所以存在多个进程并发执行情况,而在每台服务器中又存在多个线程并发的情况,…

Java Spring Boot 与前端结合打造图书管理系统:技术剖析与实现

目录 运行展示引言系统整体架构后端技术实现后端代码文件前端代码文件1. 项目启动与配置2. 实体类设计3. 控制器设计4. 异常处理 前端技术实现1. 页面布局与样式2. 交互逻辑 系统功能亮点1. 分页功能2. 搜索与筛选功能3. 图书操作功能 总结 运行展示 引言 本文将详细剖析一个基…

深入剖析JavaScript多态:从原理到高性能实践

摘要 JavaScript多态作为面向对象编程的核心特性,在动态类型系统的支持下展现了独特的实现范式。本文深入解析多态的三大实现路径:参数多态、子类型多态与鸭子类型,详细揭示它们在动态类型系统中的理论基础与实践意义。结合V8引擎的优化机制…

GalTransl开源程序支持GPT-4/Claude/Deepseek/Sakura等大语言模型的Galgame自动化翻译解决方案

一、软件介绍 文末提供程序和源码下载 GalTransl是一套将数个基础功能上的微小创新与对GPT提示工程(Prompt Engineering)的深度利用相结合的Galgame自动化翻译工具,用于制作内嵌式翻译补丁。支持GPT-4/Claude/Deepseek/Sakura等大语言模型的…

TGES 2024 | 基于空间先验融合的任意尺度高光谱图像超分辨率

Arbitrary-Scale Hyperspectral Image Super-Resolution From a Fusion Perspective With Spatial Priors TGES 2024 10.1109/TGRS.2024.3481041 摘要:高分辨率高光谱图像(HR-HSI)在遥感应用中起着至关重要的作用。单HSI超分辨率&#xff…