pytorch基础语法学习:数据预处理transforms模块

news2024/12/25 12:19:23

来源:投稿 作者:阿克西
编辑:学姐

建议搭配视频食用

视频链接:https://ai.deepshare.net/detail/p_5df0ad9a09d37_qYqVmt85/6

系列其他文章传送门:

pytorch基础语法学习:数据读取机制Dataloader与Dataset

pytorch基础语法(一)

pytorch基础语法(二)

1.transforms运行机制

torchvision是pytorch的计算机视觉工具包,主要有以下三个模块:

  • torchvision.transforms:提供了常用的一系列图像预处理方法,例如数据的标准化,中心化,旋转,翻转等。

  • torchvision.datasets:定义了一系列常用的公开数据集的datasets,比如MNIST,CIFAR-10,ImageNet等。

  • torchvision.model:提供了常用的预训练模型,例如AlexNet,VGG,ResNet,GoogLeNet等。

torchvision.transforms:常用的图像预处理方法

  • 数据中心化,数据标准化

  • 缩放,裁剪,旋转,翻转,填充

  • 噪声添加,灰度变换,线性变换,仿射变换

  • 亮度、饱和度及对比度变换

深度学习是由数据驱动的,数据的数量以及分布对模型的优劣起到决定性作用,所以需要对数据进行一定的预处理以及数据增强,用来提升模型的泛化能力。

上图是1张原始图片经过数据增强之后生成的一系列数据,一共有64张图片。对图片进行数据增强可以丰富训练数据,提高模型的泛化能力。因为如果数据增强生成了与测试样本很相似的图片,那么模型的泛化能力自然可以得到提高。

使用上一节中介绍的人民币二分类实验的代码的数据预处理部分:

2.断点调试

# ============================ step 1/5 数据 ============================
# 这部分设置数据的路径
split_dir = os.path.join("C:/Users/10530/Desktop/pytorch/rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

# 设置数据标准化的均值和标准差
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# transforms.Compose的功能是将一系列的transforms方法进行有序的组合包装,
# 在具体实现的时候,会依次按顺序对图像进行操作
train_transform = transforms.Compose([
    transforms.Resize((32, 32)),  # 将图像缩放到32*32的大小
    transforms.RandomCrop(32, padding=4),  # 对数据进行随机的裁剪
    # 将图片转成张量的形式同时会进行归一化操作,把像素值的区间从0-255归一化到0-1
    transforms.ToTensor(),
    # 标准化操作,将数据的均值变为0,标准差变为1
    transforms.Normalize(norm_mean, norm_std),
])

# 验证集的预处理的方法,对比训练集,少了RandomCrop这一部分,
# 因为在验证集中是不需要对数据进行数据增强的
valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

同样,在模型训练样本读取位置设置断点,进行debug:

点击step into按键,在跳转后的代码中进行一个是否采用多进程的判断:

点击step over,选择单进程的运行机制,再点击step into按键,进入dataloader.py界面:

光标设置在index = self._next_index() # may raise StopIteration这一行,点击Run to Cursor,程序就会运行到光标所在的行。这一步的作用是获取Index,也就是要读取哪些数据。点击step over,得到Index就可以进入dataset_fetcher.fetch(index),根据索引去获取数据。点击step into进入到fetch函数:

在fetch函数中,代码data = [self.dataset[idx] for idx in possibly_batched_index]使用了列表生成式,调用了dataset,接着点击step over与step into进入dataset所在的代码位置,dataset代码位于类RMBDataset(Dataset)中的__getitem__()函数:

在getitem()中根据索引去获取图片的路径以及标签,然后采用代码img = Image.open(path_img).convert('RGB') # 0~255打开图片,读取进来的图片是一个PIL的数据类型,然后在getitem中调用transform()进行图像预处理操作,在代码处img = self.transform(img)通过step into进入transforms.py中的def 「call」()函数

「call」()函数是一个for循环,也就是依次有序地从compose中去调用预处理方法,第一个预处理方法是t(img),其功能是是Resize缩放;第二个功能是裁剪,第三个功能是进行张量操作,第四个功能是进行归一化;对compose的四个功能循环结束之后,就会返回代码处img = self.transform(img)。

transform是在__getitem__()中调用,并且在__getitem__()中实现数据预处理,然后通过__getitem__返回一个样本。

执行step out操作返回fetch()函数,接着就是不断地循环index获取一个batch_size大小的数据,最后在return的时候调用collate_fn()函数,将数据整理成一个batch_data的形式。

然后执行step out操作返回到dataloader.py中的__next__()函数中,然后再执行执行step out操作回到训练代码中,接着数据就读取进来了。这就是pytorch数据读取和transforms的运行机制。

回顾上面的数据读取流程图,transforms是在getitem中使用的,在getitem中读取一张图片,然后对这一张图片进行一系列预处理,返回图片以及标签。

了解了transforms的机制,现在学习一个比较常用的预处理方法,数据的标准化transforms.Normalize。

3.数据标准化transforms.normalize

3.1 定义

功能:逐channel的对图像进行标准化,即数据的均值变为0,标准差变为1。

计算公式:output =\frac{(input - mean)}{std}

  • mean:各通道的均值

  • std:各通道的标准差

  • inplace:是否原位操作

transform.Normalize(mean,
                    std,
                    inplace=False)

3.2 断点调试

回到代码中看一下normalize的具体实现方法,transform是在dataset的getitem中实现的,所以可以直接去dataset的getitem函数中设置断点:

进行debug操作,点击step into进入详细代码环境,进入了transforms.py中的call()函数中,在call函数中循环transforms。

点击step over执行多次,到normalize实现

接着点击step into查看normalize的实现,来到了normalize()类中的__call__()函数中,代码只有一行,实际上这行代码是调用了pytorch中的function中normalize方法。pytorch的function提供了很多常用的函数。

接着使用step into查看normalize中的具体实现。

def normalize(tensor, mean, std, inplace=False):
    """Normalize a tensor image with mean and standard deviation.

    .. note::
        This transform acts out of place by default, i.e., it does not mutates the input tensor.

    See :class:`~torchvision.transforms.Normalize` for more details.

    Args:
        tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace(bool,optional): Bool to make this operation inplace.

    Returns:
        Tensor: Normalized Tensor image.
    """
    if not _is_tensor_image(tensor):  # 输入的合法性判断
        raise TypeError('tensor is not a torch image.')

    if not inplace:       # 判断是否需要原地操作
        tensor = tensor.clone()

    dtype = tensor.dtype
    # 获取均值与标准差,将list形式转变为张量形式
    mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
    std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
    tensor.sub_(mean[:, None, None]).div_(std[:, None, None]) # 归一化公式
    return tensor

首先是输入的合法性判断,输入的是tensor,也就是原始的图像,接着判断是否要原地操作,如果不是inplace就需要将张量复制一份到新的内存空间中。下面的代码就是获取数据的均值和标准差,并将数据转换为张量。注意在sub_和div_后面有下划线,意思是进行原位操作,这样就完成了数据标准化的操作。

3.3 标准化作用

对数据进行标准化之后可以加快模型的收敛。

之前的逻辑回归代码bias=1,发现迭代次数360次即可得到99%的准确率,损失loss=0.05。

当修改bias=5时,发现需要迭代960次模型才能收敛,loss=0.14,得到99%的准确率。

原因:模型初始化一般有0均值,需要逐渐靠近最优分类平面。

bias=5的初始化距离分类平面较远

可以看出,如果训练数据有良好的分布或者权重有良好的初始化,可以加速模型的训练。

点击下方卡片《学姐带你玩AI》🚀🚀🚀

关注回复“500”领取300+经典论文合集&讲解视频

码字不易,欢迎大家点赞评论收藏!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/560192.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

3D点云数据转为俯瞰图Python实现代码

我主要是参考了英文博客来撰写本篇文章,仅作为个人学习笔记参考使用。 文章目录 一、点云数据二、图像与点云坐标三、创建点云数据的鸟瞰视图3.1 鸟瞰图的相关坐标轴3.2 限制点云数据范围3.3 将点位置映射到像素位置3.4 切换到新的零点3.5 像素值3.6 创建图像矩阵3.…

IOS最新版开通GPT-PLUS方法

前提,美国IP魔法 不多说了 1.拥有一个美区apple id账号 可以买,也可以自己申请 自己申请就打开魔法到apple官网注册,用gmail邮箱,然后地址用美国地址生成器,记得选免税州 2.充值礼品卡 支付宝可以充值礼品卡&…

大模型总是「胡说八道」怎么办?手把手教你如何应对!

随着 ChatGPT 的出现,「AI 幻觉」一词被频繁提及。那么,什么是 AI 幻觉?简单来说,就是大模型在一本正经地胡说八道。 不止 ChatGPT,其他大语言模型也经常如此,究其根本是大语言模型在训练的过程中存在数据偏…

驱动开发-----io模型总结(2023-5-23)

1.非阻塞模型 在我们使用open函数时,将打开的驱动设置为O_NONBLOCK时,当我们用read函数去读取硬件数据时,无论硬件是否有数据,都会往下执行,不会被阻塞在这里 2.阻塞模型 在我们使用open函数时,没有设置…

C++学习之路-变量和基本内置类型

变量和基本内置类型 一、基本内置类型1.1 算数类型1.2 带符号类型和无符号类型1.3 类型转换含有无符号类型的表达式 1.4 字面值常量整形和浮点型字面值字符和字符串字面值转义序列指定字面值的类型 二、变量2.1 变量的定义初始化列表初始化默认初始化 2.2 变量声明和定义的关系…

斐波那契数列数列相关简化1

斐波那契数列问题介绍: 斐波那契数列(Fibonacci sequence),又称黄金分割数列,因数学家莱昂纳多斐波那契(Leonardo Fibonacci)以兔子繁殖为例子而引入,故又称为“兔子数列”&#xf…

包管理工具详解npm、yarn、cnpm、npx、pnpm

目录: 1 npm包管理工具 2 package配置文件 3 npm install原理 4 yarn、cnpm、npx 5 发布自己的开发包 6 pnpm使用和原理 当我们使用npm install xxxx 的时候会添加一个node_module和2个json文件: package.json是配置信息文件,  这个配…

Go完整即时通讯项目及Go的生态介绍

Go完整即时通讯项目 项目架构: 1 编写基本服务端-Server server.go package mainimport ("fmt""net" )// 定义服务端 type Server struct {ip stringport int }// 创建一个Server func NewServer(ip string, port int) *Server {return …

Jenkins + docker-compose 在 Centos 上搭建部署

一、前期准备 1. 检查 CentOS上 是否安装 docker 可以使用以下命令: sudo docker version 如果已经安装了Docker,它将显示有关Docker版本和构建信息的输出。如果未安装Docker,将收到有关命令未找到的错误消息。 2. 检查是否安装 docker-…

cookie-机制

目录 一、基础概念 二、cookie的处理方式 一、基础概念 1、cookie是存储在客户端的一组键值对 2、web中cookie的典型应用:免密登陆 3、cookie和爬虫之间的关联 有时,对一张页面进行请求的时候,如果请求的过程中不携带cookie的话&#xf…

Openai+Coursera: ChatGPT Prompt Engineering(四)

想和大家分享一下最近学习的Coursera和openai联合打造ChatGPT Prompt Engineering在线课程.以下是我写的关于该课程的前两篇博客: ChatGPT Prompt Engineering(一)ChatGPT Prompt Engineering(二)ChatGPT Prompt Engineering(三) 今天我们来学习第三部分内容&…

Java on Azure Tooling 4月更新|路线图更新及 Azure Toolkit for IntelliJ 增强

作者:Jialuo Gan - Program Manager, Developer Division at Microsoft 排版:Alan Wang 大家好,欢迎来到 Java on Azure 工具产品的4月更新。让我们首先来谈谈我们对未来几个月的 Java on Azure 开发工具的投资。在这次更新中,我们…

js - 闭包

1、闭包的概念 闭包:函数嵌套函数,内层函数访问了外层函数的局部变量。 // 闭包 function func1() {let a 9;let b 8;function func2() {console.log("a", a); // a 9}func2(); } func1(); 分析: 需要访问的变量会被放到闭包…

【云原生|Kubernetes】05-Pod的存储卷(Volume)

【云原生Kubernetes】05-Pod的存储卷(Volume) 文章目录 【云原生Kubernetes】05-Pod的存储卷(Volume)简介Volume类型解析emptyDirHostPathgcePersistentDiskNFSiscsiglusterfsceph其他volume 简介 Volume 是Pod 中能够被多个容器访问的共享目录。 Kubern…

ChatGPT可以帮助开发人员的8种方式...

“适应或灭亡”是科技界的口头禅,如果您是开发人员,则尤其如此。 由于技术的动态发展,开发人员面临着比大多数人更大的压力,他们要领先于适应和精通最好的工具。ChatGPT 是最新的此类工具。 虽然有人说 ChatGPT 是“工作杀手”&…

比Figma更丝滑的“Figma网页版“

随着互联网的全面普及和全球化,设计协作工具逐渐成为团队协作中不可或缺的一部分。设计师们常需要通过在线设计协作工具来完成设计任务,而 Figma 作为协作工具的佼佼者,成为了许多设计师心中的首选。但是,对于国内设计师来说&…

Leetcode406. 根据身高重建队列

Every day a Leetcode 题目来源:406. 根据身高重建队列 解法1:贪心 题解:根据身高重建队列 我们先按照身高从大到小排序(身高相同的情况下K小的在前面),这样的话,无论哪个人的身高都小于等于…

kubeadm安装集群的时候kube-proxy是如何安装的

背景 最近升级k8s集群时遇到这个问题,集群是使用kuberadm自动化脚本安装的,之前一直认为kubeadm安装的集群这些组件除了kubelet都是静态pod跑起来的。 其实kube-proxy并不是. kube-proxy是如何安装的 在使用kubeadmin安装Kubernetes集群时&#xff0c…

Echarts通过Jquery添加下拉列表动态改变展示的数据和图表

前言 在项目中,有时候我们会一些需求,比如要用Echarts绘制一个饼状图,并且要设置一个下拉列表,当我点击某个选项的时候,饼状图里面的数据会改变,图表样式也会发生改变。我们可以配合Jquery来实现这个功能。…

数字电路基础

目录 一、不同进制之间的转换 二、逻辑代数基础 三、门电路 四、组合逻辑电路 五、半导体存储电路 六、时序电路 一、不同进制之间的转换 二-十转换: 十-二转换: 二-十六转换 十六-二转换 八-二转换 二-八转换 十六-十转换: 先转换成…