CIFAR-100数据集的加载和预处理教程

news2025/1/13 9:28:26

一、CIFAR-100数据集介绍

CIFAR-100(Canadian Institute for Advanced Research - 100 classes)是一个经典的图像分类数据集,用于计算机视觉领域的研究和算法测试。它是CIFAR-10数据集的扩展版本,包含了更多的类别,用于更具挑战性的任务。

CIFAR-100包含了100个不同的类别,每个类别都包含600张32x32像素的彩色图像。

这100个类别被划分为20个大类别,每个大类别包含5个小类别。这个层次结构使得数据集更加丰富,包含了各种各样的对象和场景。每张图像的大小是32x32像素,包含RGB三个通道。

用途: CIFAR-100常被用于评估图像分类算法的性能。由于图像分辨率相对较低,它在实际中可能不太适用于一些复杂的计算机视觉任务,但对于学术研究和算法开发而言是一个常见的基准数据集。

二、下载并加载CIFAR-100数据集

import torch
from torch.utils.data import Dataset,DataLoader
import torchvision
import torchvision.transforms as transforms

def get_train_loader(mean, std, batch_size=16, num_workers=2, shuffle=True):
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    cifar100_training = torchvision.datasets.CIFAR100(root='./data', train=True, download=True,
                                                      transform=transform_train)
    cifar100_training_loader = DataLoader(
        cifar100_training, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)

    return cifar100_training_loader

def get_val_loader(mean, std, batch_size=16, num_workers=2, shuffle=True):
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    cifar100_test = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
    cifar100_test_loader = DataLoader(
        cifar100_test, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)

    return cifar100_test_loader

这里我们采用的是torchvision下载CIFAR-100数据集并将其保存到指定的路径,定义这两个函数 get_train_loader 和 get_val_loader 分别用于获取训练集和验证集的数据加载器,并进行了预处理和增强的操作。

三、检测数据加载情况

博主曾经在这上面吃过很多的亏,一般我们遇到维度不匹配的情况,通常会认为是网络的问题,但我会告诉你也有可能是数据加载的部分,这种开源数据集还好,我们项目上用的是自制的数据集,它的图片可能真的就是有些问题,比如你明明是用PIL加载图片,按理来说应该就是三通道无疑才对,但事实是就是存在通道为1的情况。

所以,为了让我们具备严谨的工程能力,为将来自己的项目打下基础,哪怕是开源数据集,我们也要进行测试。

一般来说,主要看到就是它的维度是否是正确的,还有它是否能够正确的显示。

在上面我们进行预处理操作,所以应该先进行反归一化:

def denormalize(tensor, mean, std):
    """反归一化操作,将归一化后的张量转换回原始范围."""
    if not torch.is_tensor(tensor):
        raise TypeError("Input should be a torch tensor.")

    for t, m, s in zip(tensor, mean, std):
        t.mul_(s).add_(m)

    return tensor

而要看如何正常的显示,我们当然不希望单张的显示,这样似乎太慢了,所以这里我们按照批量大小进行显示:

def show_batch(images, labels):
    import matplotlib
    matplotlib.use('TkAgg')
    images = denormalize(images, mean, std)
    img_grid = make_grid(images, nrow=4, padding=10, normalize=True)
    plt.imshow(img_grid.permute(1, 2, 0))
    plt.title(f"Labels: {labels}")
    plt.show()

测试代码:

if __name__=="__main__":
    import matplotlib.pyplot as plt
    from torchvision.utils import make_grid

    CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
    CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

    def denormalize(tensor, mean, std):
        """反归一化操作,将归一化后的张量转换回原始范围."""
        if not torch.is_tensor(tensor):
            raise TypeError("Input should be a torch tensor.")

        for t, m, s in zip(tensor, mean, std):
            t.mul_(s).add_(m)

        return tensor

    mean = CIFAR100_TRAIN_MEAN
    std = CIFAR100_TRAIN_STD

    test_loader = get_val_loader(mean, std, batch_size=16, num_workers=2, shuffle=False)

    def show_batch(images, labels):
        import matplotlib
        matplotlib.use('TkAgg')
        images = denormalize(images, mean, std)
        img_grid = make_grid(images, nrow=4, padding=10, normalize=True)
        plt.imshow(img_grid.permute(1, 2, 0))
        plt.title(f"Labels: {labels}")
        plt.show()

    for images, labels in test_loader:
        show_batch(images, labels)
        # print(images.size(), labels)

最后两行就是图片批量显示与维度检测的测试,这里最好是单独的测试,即两行中一行注释,一行正常运行。

四、自定义CIFAR-100的dataset类

dataset类的以下几个要点:

  • dataset类需要继承import torch.utils.data.dataset。
  • dataset的作用是将任意格式的数据,通过读取、预处理或数据增强后以tensor的形式输出。其中任意格式的数据指可能是以文件夹名作为类别的形式、或以txt文件存储图片地址的形式。而输出则指的是经过处理后的一个 batch的tensor格式数据和对应标签。
  • dataset类需要重写的主要有三个函数要完成:__init__函数、__len__函数和__getitem__函数。
  1. __init__(self, ...) 函数:初始化数据集。在这里,你通常会加载数据,设置转换(transformations)等。这个函数在数据集创建时调用。

  2. __len__(self)函数:返回数据集的大小,即数据集中样本的数量。这个函数在调用len(dataset) 时调用。

  3. __getitem__(self,index)函数:根据给定的索引返回数据集中的一个样本。这个函数允许你通过索引访问数据集中的单个样本,以便用于模型的训练和评估。

import os
import pickle
import numpy as np

from torch.utils.data import Dataset,DataLoader

class CIFAR100Dataset(Dataset):
    def __init__(self, path, transform=None, train=False):
        if train:
            sub_path = 'train'
        else:
            sub_path = 'test'
        with open(os.path.join(path, sub_path), 'rb') as cifar100:
            self.data = pickle.load(cifar100, encoding='bytes')
        self.transform = transform

    def __len__(self):
        return len(self.data['fine_labels'.encode()])

    def __getitem__(self, index):
        label = self.data['fine_labels'.encode()][index]
        r = self.data['data'.encode()][index, :1024].reshape(32, 32)
        g = self.data['data'.encode()][index, 1024:2048].reshape(32, 32)
        b = self.data['data'.encode()][index, 2048:].reshape(32, 32)
        image = np.dstack((r, g, b))

        if self.transform:
            image = self.transform(image)
        return image, label

测试代码:

if __name__=="__main__":
    mean = CIFAR100_TRAIN_MEAN
    std = CIFAR100_TRAIN_STD

    transform_train = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    train_dataset = CIFAR100Dataset(path='./data/cifar-100-python', transform=transform_train)
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

    for images, labels in train_loader:
        show_batch(images, labels)
        # print(images.size(), labels)

附录

本章节源码

import torch
from torch.utils.data import Dataset,DataLoader
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import os
import pickle
import numpy as np

CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

__all__ = ["get_train_loader", "get_val_loader", "CIFAR100Dataset"]

class CIFAR100Dataset(Dataset):
    def __init__(self, path, transform=None, train=False):
        if train:
            sub_path = 'train'
        else:
            sub_path = 'test'
        with open(os.path.join(path, sub_path), 'rb') as cifar100:
            self.data = pickle.load(cifar100, encoding='bytes')
        self.transform = transform

    def __len__(self):
        return len(self.data['fine_labels'.encode()])

    def __getitem__(self, index):
        label = self.data['fine_labels'.encode()][index]
        r = self.data['data'.encode()][index, :1024].reshape(32, 32)
        g = self.data['data'.encode()][index, 1024:2048].reshape(32, 32)
        b = self.data['data'.encode()][index, 2048:].reshape(32, 32)
        image = np.dstack((r, g, b))

        if self.transform:
            image = self.transform(image)
        return image, label

class CIFAR100Test(Dataset):
    def __init__(self, path, transform=None):
        with open(os.path.join(path, 'test'), 'rb') as cifar100:
            self.data = pickle.load(cifar100, encoding='bytes')
        self.transform = transform

    def __len__(self):
        return len(self.data['data'.encode()])

    def __getitem__(self, index):
        label = self.data['fine_labels'.encode()][index]
        r = self.data['data'.encode()][index, :1024].reshape(32, 32)
        g = self.data['data'.encode()][index, 1024:2048].reshape(32, 32)
        b = self.data['data'.encode()][index, 2048:].reshape(32, 32)
        image = np.dstack((r, g, b))

        if self.transform:
            image = self.transform(image)
        return image, label

def get_train_loader(mean, std, batch_size=16, num_workers=2, shuffle=True):
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    cifar100_training = torchvision.datasets.CIFAR100(root='./data', train=True, download=True,
                                                      transform=transform_train)
    cifar100_training_loader = DataLoader(
        cifar100_training, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)

    return cifar100_training_loader

def get_val_loader(mean, std, batch_size=16, num_workers=2, shuffle=True):
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    cifar100_test = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
    cifar100_test_loader = DataLoader(
        cifar100_test, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)

    return cifar100_test_loader

def show_batch(images, labels):
    import matplotlib
    matplotlib.use('TkAgg')
    images = denormalize(images, CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
    img_grid = make_grid(images, nrow=4, padding=10, normalize=True)
    plt.imshow(img_grid.permute(1, 2, 0))
    plt.title(f"Labels: {labels}")
    plt.show()

def denormalize(tensor, mean, std):
    """反归一化操作,将归一化后的张量转换回原始范围."""
    if not torch.is_tensor(tensor):
        raise TypeError("Input should be a torch tensor.")

    for t, m, s in zip(tensor, mean, std):
        t.mul_(s).add_(m)

    return tensor

def main1():
    test_loader = get_val_loader(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD, batch_size=16, num_workers=2, shuffle=False)
    for images, labels in test_loader:
        show_batch(images, labels)
        # print(images.size(), labels)

if __name__=="__main__":
    
    transform_train = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
    ])
    train_dataset = CIFAR100Dataset(path='./data/cifar-100-python', transform=transform_train)
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

    for images, labels in train_loader:
        show_batch(images, labels)
        # print(images.size(), labels)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1198144.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Semantic Kernel 学习笔记1

1. 挂代理跑通openai API 2. 无需魔法跑通Azure API 下载Semantic Kernel的github代码包到本地,主要用于方便学习python->notebooks文件夹中的内容。 1. Openai API:根据上述文件夹中的.env.example示例创建.env文件,需要填写下方两个内…

51单片机应用从零开始(一)

1. 单片机在哪里 单片机是一种集成电路芯片,通常被嵌入到电子设备中用于控制和处理数据,例如家电、汽车、电子玩具、智能家居等。因此,你可以在许多电子设备中找到单片机的存在。单片机通常被放置在设备的主板或控制板上。 2. 单片机是什么…

【C语言】冒泡排序(图解)

🌈write in front :🔍个人主页 : 啊森要自信的主页 🌈作者寄语 🌈: 小菜鸟的力量不在于它的体型,而在于它内心的勇气和无限的潜能,只要你有决心,就没有什么事情是不可能的…

WebRTC简介及使用

文章目录 前言一、WebRTC 简介1、webrtc 是什么2、webrtc 可以做什么3、数据传输需要些什么4、SDP 协议5、STUN6、TURN7、ICE 二、WebRTC 整体框架三、WebRTC 功能模块1、视频相关①、视频采集---video_capture②、视频编解码---video_coding③、视频加密---video_engine_encry…

AI:83-基于深度学习的手势识别与实时控制

🚀 本文选自专栏:人工智能领域200例教程专栏 从基础到实践,深入学习。无论你是初学者还是经验丰富的老手,对于本专栏案例和项目实践都有参考学习意义。 ✨✨✨ 每一个案例都附带有在本地跑过的代码,详细讲解供大家学习,希望可以帮到大家。欢迎订阅支持,正在不断更新中,…

ArcGIS:如何迭代Shp文件所有要素并分别导出为Shp文件?

01 前言 尝试用IDL实现,奈何又涉及新的类IDLffShape,觉得实在没有必要学习的必要,毕竟不是搞开发,只是做做数据处理,没必要拿IDL不擅长的且底层的东西自己造轮子。 这里想到使用Python去解决,gdal太久没用…

安装包 amd,amd64, arm,arm64 都有什么区别

现在的安装包也不省心,有各种版本都不知道怎么选。 根据你安装的环境配置。 amd: 32位X86 amd64: 64位X86 arm: 32位ARM arm64: 64位ARM amd64是X86架构的CPU,64位版。amd64又叫X86_64。主流的桌面PC&am…

25期代码随想录算法训练营第十四天 | 二叉树 | 递归遍历、迭代遍历

目录 递归遍历前序遍历中序遍历后序遍历 迭代遍历前序遍历中序遍历后序遍历 递归遍历 前序遍历 # Definition for a binary tree node. # class TreeNode: # def __init__(self, val0, leftNone, rightNone): # self.val val # self.left left # …

Javaweb之javascript的DOM对象的详细解析

1.5.3 DOM对象 1.5.3.1 DOM介绍 DOM:Document Object Model 文档对象模型。也就是 JavaScript 将 HTML 文档的各个组成部分封装为对象。 DOM 其实我们并不陌生,之前在学习 XML 就接触过,只不过 XML 文档中的标签需要我们写代码解析&#x…

Python---split()方法 + join()方法

split()方法 split 英 /splɪt/ v. 分裂,使分裂(成不同的派别);分开,使分开(成为几个部份);(使)撕裂;分担,分享;划破&…

VueRequest——管理请求状态库

文章目录 前言一、为什么选择 VueRequest?二、使用步骤1.安装2.用例 前言 VueRequest——开发文档 VueReques——GitHub地址 在以往的业务项目中,我们经常会被 loading 状态的管理、请求的节流防抖、接口数据的缓存、分页等重复的功能实现所困扰。每次开…

积极应对云网络安全

以下是 IT 领导者需要了解的内容,才能在云网络安全方面占据上风。 如果您的组织尚未主动解决云网络安全问题,则将面临灾难的风险。等待攻击发生根本没有意义。 主动云安全会采取积极措施来发现潜在威胁并在网络攻击发生之前阻止网络攻击。 这是通过持…

【深度挖掘Java性能调优】「底层技术原理体系」深入挖掘和分析如何提升服务的性能以及执行效率(引导篇)

深入挖掘和分析如何提升服务的性能以及执行效率 前提介绍知识要点 性能概述教你看懂程序的性能案例介绍性能指标性能的参考指标性能瓶颈(木桶原理) 性能分析三大定律Amdahl定律计算公式参数解释案例分析定律总结 Gustafson定律与Amdahl定律相对立Gustafs…

C#中的扩展方法---Extension

C#中扩展方法是C# 3.0/.NET 3.x 新增特性,能够实现向现有类型中“添加”方法,以下主要介绍C#中扩展方法的声明及使用。 1、扩展方法的声明 扩展方法使能够向现有类型“添加”方法,而无需创建新的派生类型、重新编译或以其他方式修改原始类型…

如何知道一个程序为哪些信号注册了哪些信号处理函数?

https://unix.stackexchange.com/questions/379694/is-there-a-way-to-know-if-signals-are-present-in-your-application-and-which-sign 使用 strace

《Swin Transformer: Hierarchical Vision Transformer using Shifted Windows》阅读笔记

论文标题 《Swin Transformer: Hierarchical Vision Transformer using Shifted Windows》 Swin 这个词貌似来自后面的 Shifted WindowsShifted Windows:移动窗口Hierarchical:分层 作者 微软亚洲研究院出品 初读 摘要 提出 Swin Transformer 可以…

Spark的执行计划

Spark 3.0 大版本发布,Spark SQL 的优化占比将近 50%。Spark SQL 取代 Spark Core,成为新一代的引擎内核,所有其他子框架如 Mllib、Streaming 和 Graph,都可以共享 Spark SQL 的性能优化,都能从 Spark 社区对于 Spark …

Java自学第10课:JavaBean和servlet基础

目录 目录 1 JavaBean (1)概念 (2)分类 (3)使用 2 servlet (1)代码结构 (2)常用接口 (3)如何开发 1 新建servlet 2 配置 1…

索尼RSV文件怎么恢复为MP4视频

索尼相机RSV是什么文件? 如果您的相机是索尼SONY A7S3,A7M4,FX3,FX3,FX6,或FX9等,有时录像会产生一个RSV文件,而没有MP4视频文件。RSV其实是MP4的前期文件,经我对RSV文件…

nodejs+vue+python+PHP+微信小程序-安卓- 基于小程序的高校后勤管理系统-计算机毕业设计

目 录 摘 要 I ABSTRACT II 目 录 II 第1章 绪论 1 1.1背景及意义 1 1.2 国内外研究概况 1 1.3 研究的内容 1 第2章 相关技术 3 2.1 nodejs简介 4 2.2 express框架介绍 6 2.4 MySQL数据库 4 第3章 系统分析 5 3.1 需求分析 5 3.2 系统可行性分析 5 3.2.1技术可行性:…