(机器学习深度学习常用库、框架|Pytorch篇)第三节:Pytorch之torchvision详解

news2025/1/3 3:36:55

文章目录

  • 一:torchvision概述
  • 二:torchvision.datasets
    • (1)官方数据集
    • (2)自定义数据集类
    • (3)ImageFolder手动实现
  • 三:torchvision.transforms
  • 四:torchvision.models

一:torchvision概述

torchvisiontorchvision是Pytorch的一个图形库,主要用来构建计算机视觉模型,torchvision由以下四个部分构成

  • torchvision.datasets:包括一些加载数据的函数和常用的数据集接口
  • torchvision.models:包含常用的模型结构(含预训练模型),例如AlexNet、ResNet等等
  • torchvision.transforms:包含一些常见的图片变换,例如裁剪、旋转等等
  • torchvision.utils:其他用法

二:torchvision.datasets

torchvision.datasets:该模块下既有官方提供的数据集,也有自定义数据集的类,两者都是torch.utils.data.Dataset的子类,因此可以直接输入到torch.utils.data.DataLoader中去

(1)官方数据集

torchvision.datasets中提供的官方数据如下,这些数据集详细介绍见此文:数据集介绍

MNIST
Fashion-MNIST
KMNIST
EMNIST
FakeData
COCO
Captions
Detection
LSUN
​ImageFolder
DatasetFolder
Imagenet-12
CIFAR
STL10
SVHN
PhotoTour
SBU
Flickr
VOC
Cityscapes
...

这里我们以MNIST数据集为例,演示一下这些官方数据集如何加载,其余数据集的加载和MNIST一致

如下,使用torchvision.datasets.MNIST加载MNIST数据集

train_data = dataset.MNIST(root='./mnist/',
                           train=True,
                           transform=transforms.ToTensor(),
                           download=True)
test_data = dataset.MNIST(root='./mnist/',
                           train=False,
                           transform=transforms.ToTensor(),
                           download=False)
  • root:表示数据集待存放的目录
  • train:如果为true将会使用训练集的数据集(training.pt),如果为false将会使用测试集数据集(test.pt
  • download:如果为true将会从网络上下载并放入root中,如果数据集已下载则不会再次下载
  • transform:接受PIL图片并返回转换后的图片,常用的就是转换为tensor(这里便会调用torchvision.transform

数据集加载成功后,文件布局如下

在这里插入图片描述

(2)自定义数据集类

这里的自定义数据集类指的主要是torchvision.datasets.ImageFolder(),它继承自 torchvision.datasets.DatasetFolder(),后者又继承自 torchvision.datasets.VisionDataset(),而VisionDataset 则是 torch.utils.data.Dataset 的子类

torchvision.datasets.CIFAR数据集为例说明如何使用torchvision.datasets.ImageFolder(),这里的torchvision.datasets.CIFAR我已经将其转换为png格式存储

  • 下载链接
  • CIFAR10有60000张图片,共分为10个类别,其中50000张为训练图片(每个类别5000张),10000张为测试图片(每个类别1000张)

图片文件布局如下,torchvision.datasets.ImageFolder()要求你的图片数据必须按照以下方式进行组织

在这里插入图片描述

torchvision.datasets.ImageFoler参数说明

  • root:图片存储的根目录,即各类别文件夹所在目录的上一级目录
  • transform:对图片进行预处理的操作(函数),原始图片作为输入,返回一个转换后的图片
  • target_transform:对图片类别进行预处理的操作,输入为 target,输出对其的转换。如果不传该参数,即对 target 不做任何转换,返回的顺序索引 0,1, 2…
  • loader:表示数据集加载方式,通常默认加载方式即可
torchvision.datasets.ImageFolder(root,transform,target_transform,loader)

如下,使用torchvision.datasets.ImageFoler对前面的图片进行加载

  • 注意transforms部分可暂时忽略
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop((28, 28)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(90),
    transforms.RandomGrayscale(0.1),
    transforms.ColorJitter(0.3, 0.3, 0.3, 0.3),
    transforms.ToTensor()
])

train_dataset = torchvision.datasets.ImageFolder(root='./data/train', transform=train_transforms)
test_dataset = torchvision.datasets.ImageFolder(root='./data/test', transform=transforms.ToTensor)
print(len(train_dataset))
print(len(test_dataset))

在这里插入图片描述

同时,通过torchvision.datasets.ImageFolder生成的train_datasettest_dataset还有如下3个成员变量

  • self.classes:使用一个list保存类别名称
  • self.class_to_idx:类别对应的索引
  • self.imgs:是一个list,每个元素是一个tuple,每个tuple保存的是(img-path, class)
print(train_dataset.classes[: 5])
print("-"*30)
print(train_dataset.class_to_idx)
print("-"*30)
print(train_dataset.imgs[: 5])

在这里插入图片描述

(3)ImageFolder手动实现

仍然以上述CIFAR10数据集为例,我们手动实现一下ImageFolder,这对你理解它大有帮助

import torchvision.datasets
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import os
from PIL import Image
import numpy as np
import glob


# 类别名字
label_name = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck"
]
# 类比名字映射索引
label_dict = {}
for idx, name in enumerate(label_name):
    label_dict[name] = idx


def default_loader(path):
    return Image.open(path).convert("RGB")

train_transforms = transforms.Compose([
    transforms.RandomResizedCrop((28, 28)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(90),
    transforms.RandomGrayscale(0.1),
    transforms.ColorJitter(0.3, 0.3, 0.3, 0.3),
    transforms.ToTensor()
])

class MyDataset(Dataset):
    """
        im_list:是一个列表,每一个元素是图片路径
        transform:对图片进行增强
        loader:使用PIL对图片进行加载
    """
    def __init__(self, im_list, transform=None, loader=default_loader):
        super(MyDataset, self).__init__()
        # imgs为二维列表,每一个子列表中第一个元素存储im_list,第二个通过label_dict映射为索引
        imgs = []

        for im_item in im_list:
            # 路径'./data/test/airplane/aeroplane_s_000002.png'中倒数第二个是标签名
            im_label_name = im_item.split("\\")[-2]
            imgs.append([im_item, label_dict[im_label_name]])

        self.imgs = imgs
        self.transform = transform
        self.loader = loader

    def __getitem__(self, index):
        im__path, im_label = self.imgs[index]

        # 会调用PIL加载图片数据
        im_data = self.loader(im__path)
        # 如果给了transoform那么就对图片进行增强
        if self.transform is not None:
            im_data = self.transform(im_data)

        return im_data, im_label

    def __len__(self):
        return len(self.imgs)


if __name__ == '__main__':
    im_train_list = glob.glob(r'./data/train/*/*.png')
    im_test_list = glob.glob(r'./data/test/*/*.png')

    train_dataset = MyDataset(im_train_list, transform=train_transforms)
    test_dataset = MyDataset(im_test_list, transform=transforms.ToTensor())
    print(len(train_dataset))
    print(len(test_dataset))

    train_loader = DataLoader(dataset=train_dataset, batch_size=6, shuffle=True, num_workers=0)
    test_loader = DataLoader(dataset=test_dataset, batch_size=6, shuffle=False, num_workers=0)

在这里插入图片描述

三:torchvision.transforms

torchvision.transforms:该模块是Pytorch中的图像预处理包,包含了一些常用的图像变换,主要实现对数据集的预处理、数据增强,转化为tensor等操作

使用时如果有很多变换,那么一般会使用Compose将这些步骤给整合到一起

train_transforms = transforms.Compose([
    transforms.RandomResizedCrop((28, 28)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(90),
    transforms.RandomGrayscale(0.1),
    transforms.ColorJitter(0.3, 0.3, 0.3, 0.3),
    transforms.ToTensor()
])

train_dataset = torchvision.datasets.ImageFolder(root='./data/train', transform=train_transforms

如果变换时只有一种,那么一般会直接给到形参,比如最常使用到的ToTensor

test_dataset = torchvision.datasets.ImageFolder(root='./data/test', transform=transforms.ToTensor)

torchvision.transforms涉及变换主要有以下4类

  • 裁剪

    • 中心裁剪transforms.CenterCrop
    • 随机裁剪transforms.RandomCrop
    • 随机长宽比裁剪transforms.RandomResizedCrop
    • 上下左右中心裁剪transforms.FiveCrop
    • 上下左右中心裁剪后翻转transforms.TenCrop
  • 翻转和旋转

    • 依概率p水平翻转transforms.RandomHorizontalFlip(p=0.5)
    • 依概率p垂直翻转transforms.RandomVerticalFlip(p=0.5)
    • 随机旋转transforms.RandomRotation
  • 图像变换和转换

    • 变换为某一尺寸transforms.Resize
    • 标准化transforms.Normalize
    • 转化为tensor并归一化transforms.ToTensor
    • 填充transforms.Pad
    • 修改亮度、对比度和饱和度transforms.ColorJitter
    • 转化为灰度图transforms.Grayscale
    • 线性变化transforms.LinearTransformation
    • 仿射变换transforms.RandomAffine
    • 依概率p转化为灰度图transforms.RandomGrayscale
    • 将数据转化为PILImagetransforms.ToPILImage
  • 其他操作

    • transforms操作使数据增强更灵活transforms.RandomChoice(transforms)
    • 从给定的一系列transforms中选定一个操作transforms.RandomApply(transforms, p=0.5)
    • 给一个transform加上概率进行操作transforms.RandomOrder

四:torchvision.models

torchvision.models:该模块提供了很多图像处理中的常用模型,并且提供了与训练版本,主要有

  • AlexNet: AlexNet variant from the “One weird trick” paper.
  • VGG: VGG-11, VGG-13, VGG-16, VGG-19 (with and without batch normalization)
  • ResNet: ResNet-18, ResNet-34, ResNet-50, ResNet-101, ResNet-152
  • SqueezeNet: SqueezeNet 1.0, and SqueezeNet 1.1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/137944.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【元宇宙欧米说】左手追星,右手造富——用Web3加持娱乐行业

娱乐圈如何才能与资本市场保持步调一致?Web3浪潮来袭,文娱行业如何才能踏上Web3世界的巨轮? 1月4日下午三点,VegaIdol联合创始人Linnea将以**“左手追星,右手造富——用Web3加持娱乐行业”为题,与大家共同…

界面组件DevExpress WinForms v22.2 -全新升级的皮肤和矢量图标

DevExpress WinForms拥有180组件和UI库,能为Windows Forms平台创建具有影响力的业务解决方案。DevExpress WinForms能完美构建流畅、美观且易于使用的应用程序,无论是Office风格的界面,还是分析处理大批量的业务数据,它都能轻松胜…

RabbitMQ的简单介绍与使用

前言:大家好,我是小威,24届毕业生,曾经在某央企公司实习,目前入职某税务公司。本篇文章将记录和分享RabbitMQ相关的知识点。 本篇文章记录的基础知识,适合在学Java的小白,也适合复习中&#xff…

乐视宣布每周工作4天半

老板跑了,公司不但没倒,而且员工还过上了不加班不内卷的神仙生活。 典型的老虎不在家,规矩自己定啊! 神仙日子 前段时间,网上流传着一则消息,说乐视目前还有400多名员工,靠着《甄嬛传》版权和…

Cadence PCB仿真使用Allegro PCB SI 创建含差分对网络元器件的IBIS模型图文教程

⏪《上一篇》   🏡《总目录》   ⏩《下一篇》 1,概述 本文简单介绍使用Allegro PCB SI软件为BRD PCB设计文件中的含有差分对网络的元器件创建IBIS模型的方法。 2,创建方法 第1步:确定打开PCB文件的软件是 Allegro PCB SI 如果不是Allegro PCB SI,可执行File→Chan…

回望2022,依然值得仰望星空

转眼间 2022 年已经过去,这是我在 CSDN 创作的第二年,在文章的创作上也是脱离“博客新手”身份,正式蜕变为“博客老手”的一年,各方面收获颇丰。2021 初见 CSDN来到 CSDN 是在 2020 年的 11 月份,但是那时候并没有开始…

传感器与传感器通道

传感器 Def:以一定精确度 把 被测量转换为与之有确定对应关系的,便于应用的某种物理量的测量系统。 作用: 捕获并转换信息,非电量物理参数转换为电参数。 e.g: 速度 ->电压,电流 组成: 敏感元件(直接感受…

css移动端适配最佳实践

移动端适配,在移动端里经常有遇到,在不同分辨率移动端设备精确还原UI设计稿,这是一个令人抓狂的问题,好在有flex,box布局解决了自适应很大一部分问题。 在开始本文之前主要介绍几种笔者常用的适配方案 1、设置meta标…

美赛Day1

1 层次分析法 评价类问题 1.1 模型介绍 1.1.1 模型介绍 在对B的评价中,判断A个物体哪个最好。将B分为k个可以评价的方面分别进行打分(每个方面A个物体的分数和为1),最终对A个物体的k个方面加权求和进行比较。 1.1.2 解题思路…

剑指政企数智办公市场,通信厂商融云有何看家本领?

近年来,数字经济正在加速赋能千行百业,我国的政务办公也正加速由数字化向智能化深度扩展。在线办公市场从公有云到政企私有云的热度,已然节节攀升。近日,作为通信厂商被熟知的融云推出了“百幄”数智办公平台,正式宣布…

05数据结构——顺序表与链表

开始系统学习算法啦!为后面力扣和蓝桥杯的刷题做准备!这个专栏将记录自己学习算法是的笔记,包括概念,算法运行过程,以及代码实现,希望能给大家带来帮助,感兴趣的小伙伴欢迎评论区留言或者私信博…

【DETR】DETR训练VOC数据集/自定义数据集

训练DETR一、数据准备二、配置DETRReferences一、数据准备 DETR用的是COCO格式的数据集。 如果要用DETR训练自己的数据集,直接利用Labelimg标注成COCO格式。 如果是VOC数据集的话,要做一个格式转换。网上一大堆格式转换的代码都很乱,所以自己…

java基于springboot外卖系统在线订餐系统app源码厨艺论坛APP

简介 本项目主要包括了外卖订餐系统(在线订餐和外卖配送)、厨艺论坛系统、管理员后台、用户中心等功能。用户注册后可以选择餐桌在线点餐支付,也可以选择外卖配送到家的方式。 演示视频 https://www.bilibili.com/video/BV1xv411t7JD/?sha…

Thinkphp5框架简单理解

说明 该文章来源于同事lu2ker转载至此处,更多文章可参考:https://github.com/lu2ker/ 目录说明TP5框架简单理解1. 架构总览1.1 控制器/操作1.2 MVC模式流程1.3 类库自动加载1.4 URL访问检测1.5 路由模式1.5.1 普通模式1.5.2 混合模式1.5.4 强制路由1.6 …

数据结构与算法学习——栈结构

在程序设计中,一定接触过“堆栈”的概念。其实,“栈 ” 和 “堆 ” 是两个不同的概念。这里,栈是一种特殊的数据结构,在中断处理特别是重要数据的现场保护有着重要意义。 什么是栈结构 从数据的逻辑结构来看,栈结构其…

59. 微调(fine-tuning)代码实现

1. 热狗识别 让我们通过具体案例演示微调:热狗识别。 我们将在一个小型数据集上微调ResNet模型。该模型已在ImageNet数据集上进行了预训练。 这个小型数据集包含数千张包含热狗和不包含热狗的图像,我们将使用微调模型来识别图像中是否包含热狗。 %matp…

专访中银金科:数字驱动成为新的增长引擎,未来业务转化是关键

大数据和信息科技正在逐步颠覆银行业过往的业务模式。建立以数据驱动为核心,以优化客户体验为目标的可持续营销理念,逐渐成为行业的共识。但是,伴随着银行业数字化转型进程加速发展,海量客户数据和低效营销之间的矛盾日益凸显。在…

Linux apt 命令

apt(Advanced Packaging Tool)是一个在 Debian 和 Ubuntu 中的 Shell 前端软件包管理器。 apt 命令提供了查找、安装、升级、删除某一个、一组甚至全部软件包的命令,而且命令简洁而又好记。 apt 命令执行需要超级管理员权限(root)。 apt 语…

23.2、Junit单元测试反射注解

Java代码执行的三个阶段 Junit单元测试: * 测试分类: 1. 黑盒测试:不需要写代码,给输入值,看程序是否能够输出期望的值。 2. 白盒测试:需要写代码的。关注程序具体的执行流程。 * Junit使用&#…

洛谷千题详解 | P1030 [NOIP2001 普及组] 求先序排列【C/C++、pascal语言】

博主主页:Yu仙笙 专栏地址:洛谷千题详解 目录 题目描述 输入格式 输出格式 输入输出样例 解析: C源码: C源码2: pascal源码: C源码: --------------------------------------------------------…