【加载数据--自定义自己的Dataset类】

news2024/12/26 23:23:59

【加载数据自定义自己的Dataset类】

  • 1 加载数据
  • 2 数据转换
  • 3 自定义Dataset类
  • 4 划分训练集和测试集
  • 5 提取一批次数据并绘制样例图

假设有四种天气图片数据全部存放与一个文件夹中,如下图所示:

├─dataset2
│      cloudy1.jpg
│      cloudy10.jpg
│      cloudy100.jpg
│      cloudy101.jpg
│      cloudy102.jpg
│      cloudy103.jpg
│      cloudy104.jpg
│      cloudy105.jpg
......

1 加载数据

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import torchvision
import glob
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image

import glob
img_dir = r'./dataset2/*.jpg'
imgs = glob.glob(img_dir) # 读取所有图片路径
print(imgs[:3]) # 打印前3张图片

species = ['cloudy', 'rain', 'shine', 'sunrise']

species_to_idx = dict((c, i) for i, c in enumerate(species))		# 建立类别和序号字典
print(species_to_idx)

idx_to_species = dict((v, k) for k, v in species_to_idx.items())	# 反转类别和序号
print(idx_to_species)

输出如下:

['./dataset2\\cloudy1.jpg',
 './dataset2\\cloudy10.jpg',
 './dataset2\\cloudy100.jpg']
 
 {'cloudy': 0, 'rain': 1, 'shine': 2, 'sunrise': 3}

{0: 'cloudy', 1: 'rain', 2: 'shine', 3: 'sunrise'}

读取路径加载序号作为标签

labels = []
for img in imgs:
    for i, c in enumerate(species):
        if c in img:
            labels.append(i)

print(labels[:3])

输出如下:

[0, 0, 0]

方法1:提前划分训练集和测试集,使用乱序后的index进行划分

np.random.seed(2022)
index = np.random.permutation(count)
imgs = np.array(imgs)[index]
labels = np.array(labels, dtype=np.int64)[index]

sep = int(count*0.8)
train_imgs = imgs[ :sep]
train_labels = labels[ :sep]
test_imgs = imgs[sep: ]
test_labels = labels[sep: ]

2 数据转换

transforms = transforms.Compose([
    transforms.Resize((96, 96)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
])

3 自定义Dataset类

class WT_dataset(Dataset):
    def __init__(self, imgs_path, lables):
        self.imgs_path = imgs_path
        self.lables = lables

    def __getitem__(self, index):
        img_path = self.imgs_path[index]
        lable = self.lables[index]
        
        pil_img = Image.open(img_path)
        pil_img = pil_img.convert("RGB")
        pil_img = transforms(pil_img)
        return pil_img, lable

    def __len__(self):
        return len(self.imgs_path)

# 加载数据
dataset = WT_dataset(imgs, labels)

4 划分训练集和测试集

count = len(dataset)
print(count)

# 方法2:划分训练集和测试集
train_count = int(0.8*count)
test_count = count - train_count
train_dataset, test_dataset = data.random_split(dataset, [train_count, test_count])
print(len(train_dataset), len(test_dataset))

# 批量加载数据
BTACH_SIZE = 16
train_dl = torch.utils.data.DataLoader(
                                       train_dataset,
                                       batch_size=BTACH_SIZE,
                                       shuffle=True
)

test_dl = torch.utils.data.DataLoader(
                                       test_dataset,
                                       batch_size=BTACH_SIZE,
)

5 提取一批次数据并绘制样例图

imgs, labels = next(iter(train_dl))	#提取一批次数据
print(imgs.shape)
im = imgs[0].permute(1, 2, 0)	# 将通道所在列放在后
print(im.shape)


plt.figure(figsize=(12, 8))
for i, (img, label) in enumerate(zip(imgs[:6], labels[:6])):
    img = (img.permute(1, 2, 0).numpy() + 1)/2
    plt.subplot(2, 3, i+1)
    plt.title(idx_to_species.get(label.item()))
    plt.imshow(img)
plt.savefig('pics/example1.jpg', dpi=400)

输出如下:

torch.Size([16, 3, 96, 96])

torch.Size([3, 96, 96])

torch.Size([96, 96, 3])

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1040190.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用matlab产生二维动态曲线视频文件具体举例

使用matlab产生二维动态曲线视频文件举例 在进行有些函数变化过程时候,需要用到直观的动态显示,本博文将举例说明利用Matlab编程进行二维动态曲线的生成视频文件。 一、问题描述 利用matlab编程实现 y 1 s i n ( t ) , y 2 c o s ( t ) , y 3 s i …

JAVA_多线程的实现方式

线程的状态 方式一&#xff1a; public class Thread1 extends Thread {Overridepublic void run() {synchronized (this) {for (int i 0; i < 100; i) {System.out.println(getName() "" i);}}} } Thread1 thread1 new Thread1(); thread1.start(); 方式二…

用flex实现grid布局

1. css代码 .flexColumn(columns, gutterSize) {display: flex;flex-flow: row wrap;margin: calc(gutterSize / -2);> div {flex: 0 0 calc(100% / columns);padding: calc(gutterSize / 2);box-sizing: border-box;} }2.用法 .grid-show-item3 {width: 100%;display: fl…

天洑软件祝贺“星耀蓉城 篮能可贵”四城邀请赛暨西北工业大学校友篮球嘉年华活动圆满举办

9月16日&#xff0c;由西北工业大学成都校友会主办&#xff0c;成都市新都区文化体育和旅游局承办&#xff0c;天洑软件、泸州老窖、陕西省篮球协会、龙湖地产、奥伦达科技、南京全信、成都鑫长源、四川银行、优仿科技、华科机电、小鱼易连、四川省篮球协会、GAME7篮球俱乐部赞…

云安全之访问控制介绍

访问控制技术背景 信息系统自身的复杂性、网络的广泛可接入性等因素&#xff0c;系统面临日益增多的安全威胁&#xff0c;安全问题日益突出&#xff0c;其中一个重要的问题是如何有效地保护系统的资源不被窃取和破坏。 访问控制技术内容包括访问控制策略、访问控制模型、访问…

virtualbox无界面打开linux虚拟机的bat脚本,以及idea(代替Xshell)连接linux虚拟机的方法

virtualbox无界面打开linux虚拟机的bat脚本&#xff0c;以及idea连接linux虚拟机的方法 命令行运行代码成功运行的效果图 idea连接linux虚拟机的方法【重要】查看虚拟机的IP地址idea中选择菜单&#xff08;该功能可代替Xshell软件&#xff09;配置设置连接成功进入idea中的命令…

MySQL数据库入门到精通9--运维篇

1. 日志 1.1 错误日志 错误日志是 MySQL 中最重要的日志之一&#xff0c;它记录了当 mysqld 启动和停止时&#xff0c;以及服务器在运行过程中发生任何严重错误时的相关信息。当数据库出现任何故障导致无法正常使用时&#xff0c;建议首先查看此日志。 该日志是默认开启的&am…

解决安装 RabbitMQ 安装不成功的问题

由于RabbitMQ是基于erlang的&#xff0c;所以&#xff0c;在正式安装RabbitMQ之前&#xff0c;需要先安装一下erlang。 1、下载mq https://www.rabbitmq.com/download.html 2、下载erlang&#xff08;点击下载路径根据下载的MQ版本对应下载erl版本&#xff09; https://www.…

客户端负载均衡_什么是负载均衡

为什么需要负载均衡 俗话说在生产队薅羊毛不能逮着一只羊薅&#xff0c;在微服务领域也是这个道理。面对一个庞大的微服务集群&#xff0c;如果你每次发起服务调用都只盯着那一两台服务器&#xff0c;在大用户访问量的情况下&#xff0c;这几台被薅羊毛的服务器一定会不堪重负…

澳大利亚新版《2023年消费品(36个月以下儿童玩具) 安全标准》发布 旨在降低危险小零件的伤害

2023年9月4日&#xff0c;澳大利亚政府发布了新的儿童玩具强制性安全标准《2023年消费品(36个月以下儿童玩具)安全标准》&#xff08;Consumer Goods (Toys for Children up to and including 36 Months of Age) Safety Standard 2023&#xff09;。该强制性标准旨在尽可能地降…

2023/9/24总结

Redis Redis 是一个基于内存的键值数据库 安装 Window下Redis的安装和部署详细图文教程&#xff08;Redis的安装和可视化工具的使用&#xff09;_redis安装_明金同学的博客-CSDN博客 出现上面代表安装成功了 redis 一共有 16 个库 安装后 再安装图形化界面 图形界面十分方便…

Java————网络编程

一 、网络编程基础 1. 为什么需要网络编程 用户在浏览器中&#xff0c;打开在线视频网站&#xff0c; 如优酷看视频&#xff0c;实质是通过网络&#xff0c; 获取到网络上的一个视频资源。 与本地打开视频文件类似&#xff0c;只是视频文件这个资源的来源是网络。 相比本地资…

ubuntu20.04部署ntp服务器ntpd(ntpdate )

文章目录 步骤1. 安装NTP2. 配置NTP3. 重启NTP服务4. 检查NTP服务状态5. 验证NTP同步ntpq -p检查本地ntp服务是否正常服务器不能连外网&#xff0c;如何配置&#xff1f; ntpdate -q xxx查询ntp服务器时间 步骤 1. 安装NTP 首先&#xff0c;在终端中更新你的包列表&#xff0…

vue项目 H5 动态设置浏览器标题

1&#xff0c;先将要展示的标题存本地 if (that.PromotionInfo.Title) {localStorage.setItem("AcTitle", that.PromotionInfo.Title)} 2,现在路由meta中设置标题&#xff0c;再在路由守卫中设置 import Vue from vue import Router from vue-router import prom…

游戏录屏软件推荐,教你录制高清游戏视频

“有没有好用的游戏录屏软件推荐呀&#xff0c;最近当上了游戏主播&#xff0c;平台要求每天都要发一个游戏视频&#xff0c;可是我的游戏录屏软件太拉胯了&#xff0c;录制出来的视频非常糊&#xff0c;导致平台审核不通过&#xff0c;所以想问问大家有没有游戏录屏软件推荐一…

机器视觉检测在流水线上的技术应用

机器视觉在流水线上的应用机器视觉系统的主要功能可以简单概括为&#xff1a;定位、识别、测量、缺陷检测等。相对于人工或传统机械方式而言&#xff0c;机器视觉系统具有速度快、精度高、准确性高等一系列优点。随着工业现代化发展&#xff0c;机器视觉已经广泛应用于各大领域…

【Git】轻松学会 Git:实现 Git 的分支管理

文章目录 前言一、对分支的理解二、分支的创建三、分支的切换3.1 切换到 dev 分支3.2 在 dev 分支上进行文件的修改和提交3.2 来回切换 master 和 dev 分支&#xff0c;查看修改的内容 四、分支的合并五、分支的删除六、冲突的合并6.1 模拟制造冲突6.2 解决冲突 七、分支管理策…

openGauss学习笔记-79 openGauss 数据库管理-内存优化表MOT管理-内存表特性-MOT应用场景

文章目录 openGauss学习笔记-79 openGauss 数据库管理-内存优化表MOT管理-内存表特性-MOT应用场景79 MOT应用场景 openGauss学习笔记-79 openGauss 数据库管理-内存优化表MOT管理-内存表特性-MOT应用场景 本节介绍了openGauss内存优化表&#xff08;Memory-Optimized Table&am…

Java基于基于微信小程序的快递柜管理系统

博主介绍&#xff1a;✌程序员徐师兄、7年大厂程序员经历。全网粉丝30W、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 文章目录 第一章&#xff1a;简介第二章、***\*开发环境&#xff1a;\******后端&#xff1a;****前端&am…