【深度学习|Pytorch】torchvision.datasets.ImageFolder详解

news2024/10/12 14:22:51

ImageFolder详解

  • 1、数据准备
  • 2、ImageFolder类的定义
    • transforms.ToTensor()解析
  • 3、ImageFolder返回对象

1、数据准备

创建一个文件夹,比如叫dataset,将cat和dog文件夹都放在dataset文件夹路径下:
在这里插入图片描述

2、ImageFolder类的定义

class ImageFolder(DatasetFolder):
	def __init__(
	        self,
	        root: str,
	        transform: Optional[Callable] = None,
	        target_transform: Optional[Callable] = None,
	        loader: Callable[[str], Any] = default_loader,
	        is_valid_file: Optional[Callable[[str], bool]] = None,
	    ):

可以看到,ImageFolder类有这几个参数:
root:图片存储的根目录,即存放不同类别图片文件夹的前一个路径。
transform:即对加载的这些图片进行的前处理的方式,这里可以传入一个实例化的torchvision.Compose()对象,里面包含了各种预处理的操作。
target_transform:对图片类别进行预处理,通常来说不会用到这一步,因此可以直接不传入参数,默认图像标签没有变换,如果需要进行标签的处理,同样可以传入一个实例化的torchvision.Compose()对象。
loader:表示图像数据加载的方式,通常采用默认的加载方式,ImageFolder加载图像的方式为调用PIL库,因此图像的通道顺序是RGB而非opencv的BGR
is_valid_file:获取图像文件路径的函数,并且可以检查是否有损坏的文件。
示例代码:

ROOT_TEST = 'dataset' #dataset/cat, dataset/dog
normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize
])

# 加载训练数据集
val_dataset = ImageFolder(ROOT_TEST, transform=val_transform)

transforms.ToTensor()解析

这里需要特别说一下ToTensor()这个函数的作用,刚接触深度学习的我那时以为只是单纯的将图像的ndarray和PIL格式转成Tensor格式,后来查看了一下源码之后发现,事情并没有这么简答!

   """Convert a PIL Image or ndarray to tensor and scale the values accordingly.

    This transform does not support torchscript.

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
    if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
    or if the numpy.ndarray has dtype = np.uint8

    In the other cases, tensors are returned without scaling.

    .. note::
        Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
        transforming target image masks. See the `references`_ for implementing the transforms for image masks.

    .. _references: https://github.com/pytorch/vision/tree/main/references/segmentation
    """

这是关于ToTensor()函数的注解,这里明确指出了ToTensor()可以将PIL和ndarray格式的图像数据转成Tensor并缩放它们的值,这里的缩放他们的值的意思在下面也指出了,即将[0, 255]的像素值域归一化[0, 1.0],并且图像转换成Tensor格式之后,维度的顺序也会发生一点变化,从一开始的HWC变成了CHW的排列方式。

3、ImageFolder返回对象

以第一部分为例,我们用一个val_dataset接收了ImageFolder的返回值,那么这个Val_dataset对象里面包含了什么呢:
val_dataset.classes:存放着根目录下的子文件夹的名称(类别名称)的列表。
val_dataset.class_to_idx:存放着类别名称和各自的索引,字典类型。
val_dataset.extensions:存放着ImageFolder可以读取的图像格式名称,元组类型。
val_dataset.targets:存放着根目录下每一张图的类别索引。
val_dataset.transform:我们提供的transform的方式。
val_dataset.imgs:存放着根目录下每一张图的路径和类别索引。元组列表类型。
以上是关于这个ImageFolder返回的对象的属性的解析。

此外,我们可以通过一个for循环来遍历整个val_dataset的所有图像数据,其中val_dataset[i]是一个元组类型的数据,val_dataset[i][0]代表了前处理后的图像数据,类型为tensor,以AlexNet为例,此时的tensor应该是3 * 224 * 224的维度。val_dataset[i][1]代表了图像的类别索引。
完整示例代码:

import torch
from AlexNet import AlexNet
from torch.autograd import Variable
from torchvision import transforms
from torchvision.transforms import ToPILImage
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# ROOT_TRAIN = 'D:/pycharm/AlexNet/data/train'
ROOT_TEST = 'dataset'

# 将图像的像素值归一化到[-1,1]之间
normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize
])

# 加载训练数据集
val_dataset = ImageFolder(ROOT_TEST, transform=val_transform)

# 如果有NVIDA显卡,转到GPU训练,否则用CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 模型实例化,将模型转到device
model = AlexNet().to(device)

# 加载train.py里训练好的模型
model.load_state_dict(torch.load(r'save_model/model_best.pth'))

# 结果类型
classes = [
    "cat",
    "dog"
]

# 把Tensor转化为图片,方便可视化
show = ToPILImage()

# 进入验证阶段
model.eval()
for i in range(10):
    x, y = val_dataset[i][0], val_dataset[i][1]
    # show():显示图片
    # show(x).show()
    # torch.unsqueeze(input, dim),input(Tensor):输入张量,dim (int):插入维度的索引,最终扩展张量维度为4维
    x = Variable(torch.unsqueeze(x, dim=0).float(), requires_grad=False).to(device)
    with torch.no_grad():
        pred = model(x)
        # argmax(input):返回指定维度最大值的序号
        # 得到预测类别中最高的那一类,再把最高的这一类对应classes中的那一类
        predicted, actual = classes[torch.argmax(pred[0])], classes[y]
        # 输出预测值与真实值
        print(f'predicted:"{predicted}", actual:"{actual}"')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1567324.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

大日志精选案例四:某省级大数据集团日志审计优化实战解析

“在集团日常运营中,数据安全始终是我们关注的重点。过去,数据量大、处理速度慢,导致日志数据难以迅速获取和分析,影响业务决策。但自从引入聚铭大日志解决方案后,系统日志和用户行为数据都得到了高效处理与存储。该方…

SpringCloud Hystrix 服务熔断、服务降级防止服务雪崩

文章目录 SpringCloud Hystrix 熔断器、服务降级防止服务雪崩需求背景引入依赖启动类加Hystrix注解接口配置熔断常规配置超时断开错误率熔断请求数熔断限流 可配置项HystrixCommand.Setter参数Command Properties 服务降级 SpringCloud Hystrix 熔断器、服务降级防止服务雪崩 H…

网络安全基础之网络协议与安全威胁

OSI(OpenSystem Interconnect),即开放式系统互联。 一般都叫OSI参考模型,是ISO(国际标准化组织)组织在1985年研究的网络互联模型。 网络协议的简介: 定义:协议是网络中计算机或设备之间进行通信的一系列规则集合。 什么是规则?…

迅饶科技 X2Modbus 网关 GetUser 信息泄露漏洞复现

0x01 产品简介 X2Modbus是上海迅饶自动化科技有限公司Q开发的一款功能很强大的协议转换网关, 这里的X代表各家不同的通信协议, 2是T0的谐音表示转换, Modbus就是最终支持的标准协议是Modbus协议。用户可以根据现场设备的通信协议进行配置,转成标准的Modbus协议。在PC端仿真…

【考研数学】1800基础做完了,如何无缝衔接660和880❓

基础题做完,不要急着强化 首先做一个复盘,1800基础的正确率如何,如果70%以下的话,从错题入手,把掌握不扎实的地方再进行巩固,否则接下来做题的话效率会很低。 接下来考虑习题衔接的问题。 关于线代复习的…

(免费分享)基于微信小程序自助停取车收费系统

本项目的开发和制作主要采用Java语言编写,SpringBoot作为项目的后端开发框架,vue作为前端的快速开发框架,主要基于ES5的语法,客户端采用微信小程序作为开发。Mysql8.0作为数据库的持久化存储。 获取完整源码: 大家点赞…

mac | Windows 本地部署 Seata2.0.0,Nacos 作为配置中心、注册中心,MySQL 存储信息

1、本人环境介绍 系统 macOS sonama 14.1.1 MySQL 8.2.0 (官方默认是5.7版本) Seata 2.0.0 Nacos 2.2.3 2、下载&数据库初始化 默认你已经有 Nacos、MySQL,如果没有 Nacos 请参考我的文章 : Docker 部署 Nacos(单机…

基于51单片机和MAX1898的智能手机充电器设计

**单片机设计介绍,基于51单片机和MAX1898的智能手机充电器设计 文章目录 一 概要二、功能设计设计思路 三、 软件设计原理图 五、 程序六、 文章目录 一 概要 基于51单片机和MAX1898的智能手机充电器设计概要 一、引言 随着智能手机的普及,其电池续航…

人工智能会拥有反思能力吗?

一、背景 人工智能是否能拥有真正的反思能力,目前仍在探索和发展之中。虽然现有的AI系统可以在一定程度上进行自我学习、自我调整和优化,但是它们的“反思”还远未达到人类意义上的深度和全面性。 传统的人工智能系统依赖于预设的算法和模型&#xff0c…

如何把png格式的图片转换成ico格式的图标

前言 前段时间有朋友说想要把某个图标做成图标,用来更改某个程序或者软件的图标,以求个性化! 但是下载下来的.png图片直接更改文件扩展名为.ico之后,并不能正常使用。这时候就需要有工具把.png格式转换为.ico格式。 文件更换格式…

Python Django全文搜索库之django-haystack使用详解

概要 Django Haystack库是一个用于在Django项目中实现全文搜索功能的强大工具。它集成了各种搜索引擎,如Elasticsearch、Whoosh等,为开发者提供了灵活且高效的搜索解决方案。在本文中,将深入探讨Django Haystack库的安装、配置和应用,以及如何利用其丰富的功能来实现高级全…

【问题处理】银河麒麟操作系统实例分享,理光打印机lpr协议打印问题处理

1.问题环境 系统版本:Kylin-Desktop-V10-SP1-General-Release-xxx-20221120-x86_64 内核版本:linux 5.4.18-44kt-generic 系统版本:麒麟v10 sp1 处理器:kx6640ma 2.问题描述 问题详细描述:用户通过lpr协议去连接…

《信息技术服务 智能运维 第2部分:数据治理》国家标准2024年第一次线下编写会议成功召开

2024年3月13日~15日,由运维数据治理国标编制组主办的运维数据治理国家标准2024年第一次编写工作会议在上海成功召开。 本次会议由云智慧(北京)科技有限公司承办,来自南网数字集团信通公司、太保科技、平安银行、广发银行、广东农…

Savitzky-Golay 滤波与Kalman滤波对比

分别使用SG滤波和Kalman滤波对比了平滑RTK解算的1s基线变化高斯坐标系XYH如下图: 从红框可以看出,SG滤波一定程度反应了波动情况,kalman滤波没有反映出来(PS:当然也可能和我设置参数有关,大家可以尝试&…

说一说Redis的Bitmaps和HyperLoLog?

本篇内容对应 “Redis高级数据类型”小节 和 “7.5 网站数据统计”小节 对应视频: Redis高级数据结构 网站数据统计 1 什么是UV和DAU? DAUUV英文全称Daily Active UserUnique Visotr中文全称日活跃用户量独立访客如何统计数据通过用户ID排重统计数据通…

uniapp自定义卡片轮播图

效果图 1、封装组件 <template><view><!-- 自定义卡片轮播 --><swiper class"swiperBox" :previous-margin"swiper.margin" :next-marginswiper.margin :circular"true"change"swiperChange"><swiper-ite…

CAD Plant3D 2023 下载地址及安装教程

CAD Plant3D是一款专业的三维工厂设计软件&#xff0c;用于在工业设备和管道设计领域进行建模和绘图。它是Autodesk公司旗下的AutoCAD系列产品之一&#xff0c;专门针对工艺、石油、化工、电力等行业的设计和工程项目。 CAD Plant3D提供了一套丰富的工具和功能&#xff0c;帮助…

极简云验证 download.php 文件读取漏洞复现

0x01 产品简介 极简云验证是一款开源的网络验证系统&#xff0c;支持多应用卡密生成&#xff1a;卡密生成 单码卡密 次数卡密 会员卡密 积分卡密、卡密管理 卡密长度 卡密封禁 批量生成 批量导出 自定义卡密前缀等&#xff1b;支持多应用多用户管理&#xff1a;应用备注 应用版…

找单链表倒数第K个节点

归纳编程学习的感悟&#xff0c; 记录奋斗路上的点滴&#xff0c; 希望能帮到一样刻苦的你&#xff01; 如有不足欢迎指正&#xff01; 共同学习交流&#xff01; &#x1f30e;欢迎各位→点赞 &#x1f44d; 收藏⭐ 留言​&#x1f4dd; 但行前路&#xff0c;不负韶华&#…

(学习日记)2024.03.31:UCOSIII第二十八节:消息队列常用函数

写在前面&#xff1a; 由于时间的不足与学习的碎片化&#xff0c;写博客变得有些奢侈。 但是对于记录学习&#xff08;忘了以后能快速复习&#xff09;的渴望一天天变得强烈。 既然如此 不如以天为单位&#xff0c;以时间为顺序&#xff0c;仅仅将博客当做一个知识学习的目录&a…