pytorch简单自定义Datasets

news2024/11/24 9:17:08

前言

  • 本文记录一下如何简单自定义pytorch中Datasets,官方教程
  • 文件层级目录如下:
    • images
      • 1.jpg
      • 2.jpg
      • 9.jpg
    • annotations_file.csv

数据说明

  • image文件夹中有需要训练的图片,annotations_file.csv中有2列,分别为image_idlabel,即图片名和其对应标签。
    在这里插入图片描述
image_idlabel
1风景
2风景
3风景
4星空
5星空
6星空
7人物
8人物
9人物

代码展示

导入必要包

import os
import torch
import pandas as pd
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import datasets
import torchvision.transforms as T

自定义Datasets

  • 自定义Datasets之前,首先我们需要准备两个信息:
    • 图片地址
    • 图片标签
class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        # 读取包含图片id和图片标签的csv文件
        self.img_labels = pd.read_csv(annotations_file, encoding="utf-8")
        # 图片存储路径
        self.img_dir = img_dir
        # 对图片进行预处理
        self.transform = transform
        # 对图片标签进行预处理,如One-hot编码
        self.target_transform = target_transform

    def __len__(self):
        # 返回样本总量
        return len(self.img_labels)

    def __getitem__(self, idx):
        # 拼接图片完整读取路径
        img_path = os.path.join(self.img_dir, str(self.img_labels.iloc[idx, 0]) + '.jpg')
        # 使用PIL库读取图片
        image = Image.open(img_path)
        # 读取图片标签
        label = self.img_labels.iloc[idx, 1]
        # 如果传入了图片预处理方法则执行
        if self.transform:
            image = self.transform(image)
        # 如果传入了标签预处理方法则执行
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

定义图片预处理方法

transform = {'train':T.Compose([
    # 将图片缩放至固定尺寸
    T.Resize((224,224)),
    # 应用CIFAR10自动增强策略
    T.AutoAugment(T.AutoAugmentPolicy.CIFAR10),
    # 像素值归一化,并转换为ternsor格式
    T.ToTensor(),])}
  • 接下来我们将Datasets实例化
annotations_file = '/kaggle/input/datasets-test/annotations_file.csv'
img_dir = '/kaggle/input/datasets-test/images'
train_data = CustomImageDataset(annotations_file = annotations_file ,img_dir = img_dir, transform = transform['train'])
  • 我们可以使用len(train_data)检查样本完整性,以及Datasets定义正确性,这里输出9,的确只有9张图片,正确无误。

使用DataLoaders加载数据

  • 因为这里数据较少,所以设置batch_size = 2,打乱数据shuffle = true,不丢弃数据drop_last=False,有关DataLoader的更多操作可以参照官方API
from torch.utils.data import DataLoader
train_dataloader = DataLoader(train_data, batch_size=2, shuffle=True, drop_last=False)
  • 使用iter()函数和next()函数,取1个batch,检查数据
train_features, train_labels = next(iter(train_dataloader))
  • 取第1个batch中的第1个图片,并将其可视化。
  • 由于train_features[0]的维度为(1,3,224,224),所以使用squeeze()函数从数组中删除单维度条目,即把为1的维度去掉。再使用permute()函数将维度变换(224,224,3),便于plt绘图。
# 维度整理
img = train_features[0].squeeze()
# 转置操作
img = img.permute(1,2,0)
# 绘图
plt.imshow(np.asarray(img))
# 关闭刻度线
plt.axis('off')
plt.show()

请添加图片描述

  • 打印图片标签print(train_labels[0]),输出'人物'

在Datasets中将字符串标签数值化

  • 我们发现上面打印出的标签为字符串,如果我们想要将其数值化,只需要在Datasets__getitem__部分改动一点
class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        # 读取包含图片id和图片标签的csv文件
        self.img_labels = pd.read_csv(annotations_file, encoding="utf-8")
        # 图片存储路径
        self.img_dir = img_dir
        # 对图片进行预处理
        self.transform = transform
        # 对图片标签进行预处理,如One-hot编码
        self.target_transform = target_transform

    def __len__(self):
        # 返回样本总量
        return len(self.img_labels)

    def __getitem__(self, idx):
        # 拼接图片完整读取路径
        img_path = os.path.join(self.img_dir, str(self.img_labels.iloc[idx, 0]) + '.jpg')
        # 使用PIL库读取图片
        image = Image.open(img_path)
        # 将字符串进行数值编码
        data_category, data_class = pd.factorize(self.img_labels.iloc[:, 1])
        label = data_category[idx]
        # 如果传入了图片预处理方法则执行
        if self.transform:
            image = self.transform(image)
        # 如果传入了标签预处理方法则执行
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
  • 这样就可以了,可以看到其实array格式的数据也是可以读取的,只要保证idx一致,且对应就可以。

划分训练集与验证集

  • 根据前面的说明,其实训练集与验证集的划分就变的很简单了,只需要4个列表/数组,train_pathtrain_labelvaild_pathvaild_label分别表示训练集图片路径、标签、验证集图片路径、标签。DataSets可以这样写:
class CustomImageDataset(Dataset):
    def __init__(self, image_id, image_label, img_dir, transform=None, target_transform=None):
        # 读取包含图片id
        self.image_id = image_id
        # 读取图片标签
        self.image_label = image_label
        # 图片存储路径
        self.img_dir = img_dir
        # 对图片进行预处理
        self.transform = transform
        # 对图片标签进行预处理,如One-hot编码
        self.target_transform = target_transform

    def __len__(self):
        # 返回样本总量
        return len(self.img_labels)

    def __getitem__(self, idx):
        # 拼接图片完整读取路径
        img_path = os.path.join(self.img_dir, str(self.image_id[idx]) + '.jpg')
        # 使用PIL库读取图片
        image = Image.open(img_path)
        # 读取图片标签
        label = self.image.label[idx]
        # 如果传入了图片预处理方法则执行
        if self.transform:
            image = self.transform(image)
        # 如果传入了标签预处理方法则执行
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
  • 实例化和加载器可以写作:
transform = {'train':T.Compose([T.Resize((224,224)),T.AutoAugment(T.AutoAugmentPolicy.CIFAR10),T.ToTensor(),]),
             'valid':T.Compose([T.Resize((224,224)),T.ToTensor(),])}
             
# 实例化训练数据
train_data = CustomImageDataset(train_path, train_label, img_dir = './',transform = transform['train'])
# 实例化训练数据
valid_data = CustomImageDataset(valid_path, valid_label, img_dir = './',transform = transform['valid'])

# 训练集数据加载器
train_dataloader = DataLoader(train_data, batch_size=2, shuffle=True, drop_last=False)
# 验证集数据加载器
valid_dataloader = DataLoader(valid_data, batch_size=2, shuffle=True, drop_last=False)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/157566.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python的23种设计模式(完整版带源码实例)

作者:虚坏叔叔 博客:https://xuhss.com 早餐店不会开到晚上,想吃的人早就来了!😄 Python的23种设计模式 一 什么是设计模式 设计模式是面对各种问题进行提炼和抽象而形成的解决方案。这些设计方案是前人不断试验&…

【入门篇】1 # 复杂度分析(上):如何分析、统计算法的执行效率和资源消耗?

说明 【数据结构与算法之美】专栏学习笔记。 什么是复杂度? 复杂度也叫渐进复杂度,包括时间复杂度和空间复杂度,用来分析算法执行效率与数据规模之间的增长关系,可以粗略地表示,越高阶复杂度的算法,执行…

时域脉冲通信采用高斯脉冲且使用PAM调制的Matlab简易演示仿真

时域脉冲通信采用高斯脉冲且使用PAM调制的Matlab简易演示仿真 环境 matlab 2016a 指标 1 将声音信号转为二进制码 2 PAM调制 3 采用高斯脉冲 流程 代码 [OriginVoice,fs]audioread(voice.m4a) ; OriginVoiceOriginVoice(:,2); Nlength(OriginVoice); % 计算信号x的长度 …

算法训练营 day15 二叉树 层序遍历 翻转二叉树 对称二叉树

算法训练营 day15 二叉树 层序遍历 翻转二叉树 对称二叉树 层序遍历 102. 二叉树的层序遍历 - 力扣(LeetCode) 给你二叉树的根节点 root ,返回其节点值的 层序遍历 。 (即逐层地,从左到右访问所有节点)。…

标签设计打印软件:LabelJoy 6.23.0 Crack

LabelJoy 专业条码软件 生成25种条形码 从数据源导入条码 计算自动校验 商业条形码标签软件 兼容 Excel、Access、MySQL、Oracle 11.000 个预装的纸张布局 支持任何打印机 通过 3 个步骤创建和打印标签: 选择布局 创建您的标签 开始打印 最好的标签打印软件&#xf…

kafka-1

文章目录1.启动2.创建主题3.发送消息4.消费消息5.使用kafka connect将现有的数据导入到kafka中6.使用kafka streams处理kafka中的events6.终止服务集群配置要点创建主题要点主题分区变更主题副本可变更吗?创建生产者要点> tar -xzf kafka_2.12-3.3.1.tgz1.启动 …

Mac生成和查看ssh key

从 git 上拉取或者提交代码每次都需要输入账号密码,这样很麻烦。我们可以在电脑上生成一个 ssh key,然后把ssh key添加到 git 中,就可以不用每次去输账号密码了。下面就介绍一下怎么在自己的 Mac 中生成和查看 ssh key。一、Mac生成SSH Key打…

【环境】idea远程debug

工作中,遇到问题的时候,想知道上下文中对应的参数值是什么?这时候,1、我们可以全靠逻辑分析。费脑,不一定对。2、打印日志,打印的信息不一定全,换包,麻烦3、远程debug。 1、配置ide…

pytorch二维码识别

二维码图片的生成 利用captcha可以生成二维码图片 # -*- coding: UTF-8 -*- from captcha.image import ImageCaptcha # pip install captcha from PIL import Image import random import time import os # 验证码中的字符 # string.digits string.ascii_uppercase NUMBER…

整理了一周近万字讲解linux基础开发工具vim,gdb,gcc,yum等的使用

文章目录 前言一、yum的使用二、vim的使用三 . gcc/g的使用四 . gdb的使用总结前言 想用linux开发一些软件等必须要会的几种开发工具是必不可少的,在yum vim gcc gdb中指令繁杂的是vim和gdb这两个工具,至于yum和gcc的指令就比较简单了。 一、yum的使用…

【SpringMVC】拦截器

目录 一、概念 二、自定义拦截器的三个实现方法 三、自定义拦截器执行流程 四、使用 五、拦截器和过滤器 相关文章(可以关注我的SpringMVC专栏) SpingMVC专栏SpingMVC专栏一、概念在学习拦截器之前,我们得先了解一下它是个什么❓ SpringMVC…

SAP ABAP调用标准事务码

这里介绍常见的几种在开发中常用到的事务代码跳转功能。 1、最常用到的是“SET PARAMETER”语句赋值,然后再使用“CALL TRANSACTION”语句跳转屏幕。 比如采购订单、销售订单、交货单、采购发票、销售发票等事务代码,均可以利用给参数赋值来直接跳转&am…

零售及仓储数字化整理解决方案

价格管控 皮克价格管控方案可实现门店与企业信息管理平台的数据同步,强化零售企业对终端的控制。同时为企业销售决策提供支持,优化门店经营活动的效率和频率。陈列管理 皮克陈列管理方案通过电子价签产品使商品陈列得到固化。 同时实现了陈列可视化&am…

ArcGIS水文分析提取河网及流域

在进行某些研究或者一些论文插图显示的时候,有时我们会碰到在部分资料中找不到一些小的河流或者流域的数据的情况,这里讲述通过DEM数据生成河网及流域。 一、数据来源 四川省高程数据来源于中国科学院资源环境科学与数据中心(中国科学院资源环…

Vue3学习之深度剖析CSS Modules和Scope

Css Modules 是通过对标签类名进行加装成一个独一无二的类名,比如.class 转换成.class_abc_123,类似于symbol,独一无二的键名 Css Scope 是通过为元素增加一个自定义属性,这个属性加上独一无二的编号,而实现作用域隔离。 原理 …

爬虫必备抓包工具——Fiddler【认识使用】

目录:1.fiddler (抓包工具)1.1 引入:HTTP/https代理(正向代理)1.2 拓展:反向代理:1.2 初识Fiddler①什么是抓包?抓包有什么用?②浅谈fiddler:③fi…

Unity_Skybox自定义插件可实现日夜更替Polyverse Skies | Low Poly

又又一个天空盒,不过这个做的还是比较完善的。。。不会出现买家秀和买家秀差别大问题 此Skybox插件特色提供: 可扩展,自定义很多的Skybox Shader预制体几个,虽然都是夜晚样式(缺白天)若干预设值</

对NIO和BIO的进一步理解

疑问 在之前的学习中&#xff0c;只提到BIO是阻塞IO&#xff0c;在建立连接和读写事件时会阻塞线程。NIO是非阻塞IO&#xff0c;基于事件注册&#xff0c;通过Selector进行切换Channel&#xff0c;不会阻塞线程。对于这种解释&#xff0c;还是带有一些疑问的。Selector进行Cha…

#define 实现快捷模板类实例在eigen::Maxtrix中的应用

欢迎关注更多精彩 关注我&#xff0c;学习常用算法与数据结构&#xff0c;一题多解&#xff0c;降维打击。 背景 在eigen库中&#xff0c;矩阵类原来的用法是 Matrix<Type, row, col>。 为了方便用户&#xff0c;库中还提供了用户常用的快捷类型&#xff0c;比如Matrix…

Java-String的API

一、length()package 做题; import java.lang.reflect.Array; import java.security.PublicKey; import java.util.Arrays; import java.util.Scanner;import javax.naming.StringRefAddr;public class Main {public static void main(String[] args) {Scanner sc new Scanne…