如何使用pytorch的Dataset, 来定义自己的Dataset

news2025/1/12 1:52:36

Dataset与DataLoader的关系

在这里插入图片描述
在这里插入图片描述

  1. Dataset: 构建一个数据集,其中含有所有的数据样本
  2. DataLoader:将构建好的Dataset,通过shuffle、划分batch、多线程num_workers运行的方式,加载到可训练的迭代容器。
import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    """创建自己的数据集"""
    def __init__(self):
        """初始化构建数据集所需要的参数"""
        pass

    def __getitem__(self, index):
        """来获取数据集中样本的索引"""
        pass

    def __len__(self):
        """获取数据集中的样本个数"""
        pass

# 实例化自定义的数据集
dataset = MyDataset()
# 将自定义的数据集加载到可训练的迭代容器
train_loader = DataLoader(dataset=dataset,  # 自定义的数据集
                          batch_size=32,  # 数据集中小批量的大小
                          shuffle=True,  # 是否要打乱数据集中样本的次序
                          num_workers=2)  # 是否要并行

实战1:CSV数据集(结构化数据集)

import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    """创建自己的数据集"""
    def __init__(self, filepath):
        """初始化构建数据集所需要的参数"""
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0]  # 查看数据集中样本的个数
        self.x_data = torch.from_numpy(xy[:, :-1])
        self.y_data = torch.from_numpy(xy[:, [-1]])
        print("数据已准备好......")

    def __getitem__(self, index):
        """为了支持下标操作, 即索引dataset[index]:来获取数据集中样本的索引"""
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        """为了使用len(dataset):获取数据集中的样本个数"""
        return self.len

file = "D:\\BaiduNetdiskDownload\\Dataset_Dataload\\diabetes1.csv"

""" 1.使用 MyDataset类 构建自己的dataset """
mydataset = MyDataset(file)
""" 2.使用 DataLoader 构建train_loader """
train_loader = DataLoader(dataset=mydataset,
                          batch_size=32,
                          shuffle=True,
                          num_workers=0)

class MyModel(torch.nn.Module):
    """定义自己的模型"""
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmooid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.sigmooid(self.linear1(x))
        x = self.sigmooid(self.linear2(x))
        x = self.sigmooid(self.linear3(x))
        return x

# 实例化模型
model = MyModel()

# 定义损失函数
criterion = torch.nn.BCELoss(size_average=True)
# 定义优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)


if __name__ == "__main__":
    for epoch in range(10):
        for i, data in enumerate(train_loader, 0):
            # 1. 准备数据
            inputs, labels = data

            # 2. 前向传播
            y_pred= model(inputs)
            loss = criterion(y_pred, labels)
            print(epoch, i, loss.item())

            # 3. 反向传播
            optimizer.zero_grad()
            loss.backward()

            # 4. 梯度更新
            optimizer.step()

在这里插入图片描述

实战2:图片数据集

├── flower_data
—├── flower_photos(解压的数据集文件夹,3670个样本)
—├── train(生成的训练集,3306个样本)
—└── val(生成的验证集,364个样本)

主函数文件main.py
import os

import torch
from torchvision import transforms

from my_dataset import MyDataSet
from utils import read_split_data, plot_data_loader_image

root = "../data/flower_data/flower_photos"  # 数据集所在根目录


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(root)

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    train_data_set = MyDataSet(images_path=train_images_path,
                               images_class=train_images_label,
                               transform=data_transform["train"])

    batch_size = 8
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_data_set,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=nw,
                                               collate_fn=train_data_set.collate_fn)

    # plot_data_loader_image(train_loader)

    for epoch in range(100):
	    for step, data in enumerate(train_loader):
	        images, labels = data
	        # 然后在进行相应的训练操作即可


if __name__ == '__main__':
    main()

自定义数据集文件my_dataset.py
from PIL import Image
import torch
from torch.utils.data import Dataset


class MyDataSet(Dataset):
    """自定义数据集"""

    def __init__(self, images_path: list, images_class: list, transform=None):
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, item):
        img = Image.open(self.images_path[item])
        # RGB为彩色图片,L为灰度图片
        if img.mode != 'RGB':
            raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
        label = self.images_class[item]

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    @staticmethod
    def collate_fn(batch):
        # 官方实现的default_collate可以参考
        # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        images, labels = tuple(zip(*batch))

        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels


功能文件utils.py(训练集、验证集的划分与可视化)
import os
import json
import pickle
import random

import matplotlib.pyplot as plt


def read_split_data(root: str, val_rate: float = 0.2):
    random.seed(0)  # 保证随机结果可复现
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)  # 判断路径是否存在

    # 遍历文件夹,一个文件夹对应一个类别
    flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    # 排序,保证顺序一致
    flower_class.sort()
    # 生成类别名称以及对应的数字索引: 字典{’花名‘:0,’花名‘:1,···}
    class_indices = dict((k, v) for v, k in enumerate(flower_class))
    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)  # 将花名与对应的序号分行保存
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []  # 存储训练集的所有图片路径
    train_images_label = []  # 存储训练集图片对应索引信息
    val_images_path = []  # 存储验证集的所有图片路径
    val_images_label = []  # 存储验证集图片对应索引信息
    every_class_num = []  # 存储每个类别的样本总数
    supported = [".jpg", ".JPG", ".png", ".PNG"]  # 支持的文件后缀类型
    # 遍历每个文件夹下的文件
    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        # 遍历获取supported支持的所有文件路径
        images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
                  if os.path.splitext(i)[-1] in supported]
        # 获取该类别对应的索引
        image_class = class_indices[cla]
        # 记录该类别的样本数量
        every_class_num.append(len(images))
        # 按比例随机采样验证样本
        val_path = random.sample(images, k=int(len(images) * val_rate))

        for img_path in images:
            if img_path in val_path:  # 如果该路径在采样的验证集样本中则存入验证集
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:  # 否则存入训练集
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))

    plot_image = True
    if plot_image:
        # 绘制每种类别个数柱状图
        plt.bar(range(len(flower_class)), every_class_num, align='center')
        # 将横坐标0,1,2,3,4替换为相应的类别名称
        plt.xticks(range(len(flower_class)), flower_class)
        # 在柱状图上添加数值标签
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        # 设置x坐标
        plt.xlabel('image class')
        # 设置y坐标
        plt.ylabel('number of images')
        # 设置柱状图的标题
        plt.title('flower class distribution')
        plt.show()

    return train_images_path, train_images_label, val_images_path, val_images_label


def plot_data_loader_image(data_loader):
    batch_size = data_loader.batch_size
    plot_num = min(batch_size, 4)

    json_path = './class_indices.json'
    assert os.path.exists(json_path), json_path + " does not exist."
    json_file = open(json_path, 'r')
    class_indices = json.load(json_file)

    for data in data_loader:
        images, labels = data
        for i in range(plot_num):
            # [C, H, W] -> [H, W, C]
            img = images[i].numpy().transpose(1, 2, 0)
            # 反Normalize操作
            img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
            label = labels[i].item()
            plt.subplot(1, plot_num, i+1)
            plt.xlabel(class_indices[str(label)])
            plt.xticks([])  # 去掉x轴的刻度
            plt.yticks([])  # 去掉y轴的刻度
            plt.imshow(img.astype('uint8'))
        plt.show()


def write_pickle(list_info: list, file_name: str):
    with open(file_name, 'wb') as f:
        pickle.dump(list_info, f)


def read_pickle(file_name: str) -> list:
    with open(file_name, 'rb') as f:
        info_list = pickle.load(f)
        return info_list

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1401727.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ElasticSearch 7.x现网运行问题汇集1

问题描述: 现网ElasticSearch health状态变为red,有分片无法assign。如下摘录explain的结果部分: "note": "No shard was specified in the explain API request, so this response explains a randomly chosen unassigned s…

基于BERT对中文邮件内容分类

用BERT做中文邮件内容分类 项目背景与意义项目思路数据集介绍环境配置数据加载与预处理自定义数据集模型训练加载BERT预训练模型开始训练 预测效果 项目背景与意义 本文是《用BERT做中文邮件内容分类》系列的第二篇,该系列项目持续更新中。系列的起源是《使用Paddl…

【前端设计】card

欢迎来到前端设计专栏&#xff0c;本专栏收藏了一些好看且实用的前端作品&#xff0c;使用简单的html、css语法打造创意有趣的作品&#xff0c;为网站加入更多高级创意的元素。 html <!DOCTYPE html> <html lang"en"> <head><meta charset&quo…

SQL注入实战:http报文包讲解、http头注入

一&#xff1a;http报文包讲解 HTTP(超文本传输协议)是今天所有web应用程序使用的通信协议。最初HTTP只是一个为获取基于文本的静态资源而开发的简单协议&#xff0c;后来人们以各种形式扩展和利用它.使其能够支持如今常见的复杂分布式应用程序。HTTP使用一种用于消息的模型:客…

SpringBoot项目中集成Kaptcha

1.Kaptcha简介 Kaptcha是一个流行的Java库&#xff0c;用于生成验证码&#xff08;CAPTCHA&#xff09;图片。CAPTCHA是“Completely Automated Public Turing test to tell Computers and Humans Apart”的缩写&#xff0c;通常用于在线表单验证以防止机器人或自动化工具的滥用…

Meta 标签的力量:如何利用它们提高网站的可见性(上)

&#x1f90d; 前端开发工程师、技术日更博主、已过CET6 &#x1f368; 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 &#x1f560; 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 &#x1f35a; 蓝桥云课签约作者、上架课程《Vue.js 和 E…

IOT pwn

已经过了填坑的黄金时期 环境搭建 交叉编译工具链 很多开源项目需要交叉编译到特定架构上&#xff0c;因此需要安装对应的交叉编译工具链。 sudo apt install gcc-arm-linux-gnueabi g-arm-linux-gnueabi -y sudo apt install gcc-aarch64-linux-gnu g-aarch64-linux-gnu -…

Dubbo的几个序列化方式

欢迎订阅专栏&#xff0c;会分享Dubbo里面相关的技术实现 这篇文章就不详细的介绍每种序列化方式的实现细节&#xff0c;大家可以自行去问度娘&#xff0c;我也会找一些资料。需要注意的是&#xff0c;这个先后顺序不表示性能优越 ObjectInput、ObjectOutput 这两是Dubbo序列…

看书标记【R语言数据分析项目精解:理论、方法、实战 9】

看书标记——R语言 Chapter 9 文本挖掘——点评数据展示策略9.1项目背景、目标和方案9.1.1项目背景9.1.2项目目标9.1.3项目方案1.建立评论文本质量量化指标2.建立用户相似度模型3.对用户评论进行情感性分析 9.2项目技术理论简介9.2.1评论文本质量量化指标模型1.主题覆盖量2.评论…

电脑可以连接wifi,甚至可以qq聊天,但就是不能用浏览器上网,一直显示未检测出入户网线的解决方案

今天回到家&#xff0c;准备办公却发现电脑可以连接wifi&#xff0c;甚至可以qq聊天&#xff0c;但就是不能用浏览器上网&#xff0c;一直显示未检测出入户网线的解决方案&#xff0c;小白也可以看懂 以下有几种解决方案&#xff0c;不妨都试试&#xff0c;估计可以解决95%的相…

lv14 内核定时器 11

一、时钟中断 硬件有一个时钟装置&#xff0c;该装置每隔一定时间发出一个时钟中断&#xff08;称为一次时钟嘀嗒-tick&#xff09;&#xff0c;对应的中断处理程序就将全局变量jiffies_64加1 jiffies_64 是一个全局64位整型, jiffies全局变量为其低32位的全局变量&#xff0…

web架构师编辑器内容-图层拖动排序功能的开发

新的学习方法 用手写简单方法实现一个功能然后用比较成熟的第三方解决方案即能学习原理又能学习第三方库的使用 从两个DEMO开始 Vue Draggable Next: Vue Draggable NextReact Sortable HOC: React Sortable HOC 列表排序的三个阶段 拖动开始&#xff08;dragstart&#x…

Spring-AOP入门案例

文章目录 Spring-AOP入门案例概念:通知(Advice)切入点(Pointcut )切面&#xff08;Aspect&#xff09; 目标对象(target)代理对象(Proxy)顾问&#xff08;Advisor)连接点(JoinPoint) 简单需求&#xff1a;在接口执行前输出当前系统时间Demo原始未添加aop前1 项目包结构2 创建相…

springCloud的ribbon和feign

ribbon方式调用 就是将原来的具体地址&#xff0c;改为了通过服务名去调用。注册中心中有多个服务&#xff0c;相同服务名&#xff0c;就会算作可以调用的服务。 首先得有一个注册中心&#xff0c;然后各种服务注册进去&#xff0c;然后利用ribbon或者feign去调用。 ribbon是直…

imgaug库图像增强指南(34):揭秘【iaa.Clouds】——打造梦幻般的云朵效果

引言 在深度学习和计算机视觉的世界里&#xff0c;数据是模型训练的基石&#xff0c;其质量与数量直接影响着模型的性能。然而&#xff0c;获取大量高质量的标注数据往往需要耗费大量的时间和资源。正因如此&#xff0c;数据增强技术应运而生&#xff0c;成为了解决这一问题的…

【数据结构与算法】归并排序详解:归并排序算法,归并排序非递归实现

一、归并排序 归并排序是一种经典的排序算法&#xff0c;它使用了分治法的思想。下面是归并排序的算法思想&#xff1a; 递归地将数组划分成较小的子数组&#xff0c;直到每个子数组的长度为1或者0。将相邻的子数组合并&#xff0c;形成更大的已排序的数组&#xff0c;直到最…

Python Timer定时器:控制函数在特定时间执行

Thread类有一个Timer子类&#xff0c;该子类可用于控制指定函数在特定时间内执行一次。例如如下程序&#xff1a; from threading import Timerdef hello():print("hello, world") # 指定10秒后执行hello函数 t Timer(10.0, hello) t.start() 上面程序使用 Timer …

MySQL-B-tree和B+tree区别

B-tree&#xff08;平衡树&#xff09;和Btree&#xff08;平衡树的一种变种&#xff09;是两种常见的树状数据结构&#xff0c;用于构建索引以提高数据库的查询性能。它们在一些方面有相似之处&#xff0c;但也有一些关键的区别。以下是B-tree和Btree的主要区别&#xff1a; …

Macos数据库管理软件:Navicat Premium for Mac 16.3.5中文版

Navicat Premium 16 for Mac是一款强大的数据库管理和开发工具&#xff0c;支持多种数据库系统&#xff0c;如MySQL、Oracle、SQL Server等。它提供了直观的用户界面和丰富的功能&#xff0c;使用户能够轻松地创建、管理和维护数据库。 软件下载&#xff1a;Navicat Premium fo…

【Unity学习笔记】Unity TestRunner使用

转载请注明出处&#xff1a;&#x1f517;https://blog.csdn.net/weixin_44013533/article/details/135733479 作者&#xff1a;CSDN|Ringleader| 参考&#xff1a; Input testingGetting started with Unity Test FrameworkHowToRunUnityUnitTest如果对Unity的newInputSystem感…