来源:投稿 作者:阿克西
编辑:学姐
建议搭配视频食用
视频链接:https://ai.deepshare.net/detail/p_5df0ad9a09d37_qYqVmt85/6
系列其他文章传送门:
pytorch基础语法学习:数据读取机制Dataloader与Dataset
pytorch基础语法(一)
pytorch基础语法(二)
1.transforms运行机制
torchvision是pytorch的计算机视觉工具包,主要有以下三个模块:
-
torchvision.transforms
:提供了常用的一系列图像预处理方法,例如数据的标准化,中心化,旋转,翻转等。 -
torchvision.datasets
:定义了一系列常用的公开数据集的datasets,比如MNIST,CIFAR-10,ImageNet等。 -
torchvision.model
:提供了常用的预训练模型,例如AlexNet,VGG,ResNet,GoogLeNet等。
torchvision.transforms:常用的图像预处理方法
-
数据中心化,数据标准化
-
缩放,裁剪,旋转,翻转,填充
-
噪声添加,灰度变换,线性变换,仿射变换
-
亮度、饱和度及对比度变换
深度学习是由数据驱动的,数据的数量以及分布对模型的优劣起到决定性作用,所以需要对数据进行一定的预处理以及数据增强,用来提升模型的泛化能力。
上图是1张原始图片经过数据增强之后生成的一系列数据,一共有64张图片。对图片进行数据增强可以丰富训练数据,提高模型的泛化能力。因为如果数据增强生成了与测试样本很相似的图片,那么模型的泛化能力自然可以得到提高。
使用上一节中介绍的人民币二分类实验的代码的数据预处理部分:
2.断点调试
# ============================ step 1/5 数据 ============================
# 这部分设置数据的路径
split_dir = os.path.join("C:/Users/10530/Desktop/pytorch/rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")
# 设置数据标准化的均值和标准差
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
# transforms.Compose的功能是将一系列的transforms方法进行有序的组合包装,
# 在具体实现的时候,会依次按顺序对图像进行操作
train_transform = transforms.Compose([
transforms.Resize((32, 32)), # 将图像缩放到32*32的大小
transforms.RandomCrop(32, padding=4), # 对数据进行随机的裁剪
# 将图片转成张量的形式同时会进行归一化操作,把像素值的区间从0-255归一化到0-1
transforms.ToTensor(),
# 标准化操作,将数据的均值变为0,标准差变为1
transforms.Normalize(norm_mean, norm_std),
])
# 验证集的预处理的方法,对比训练集,少了RandomCrop这一部分,
# 因为在验证集中是不需要对数据进行数据增强的
valid_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])
同样,在模型训练样本读取位置设置断点,进行debug:
点击step into按键,在跳转后的代码中进行一个是否采用多进程的判断:
点击step over,选择单进程的运行机制,再点击step into按键,进入dataloader.py界面:
光标设置在index = self._next_index() # may raise StopIteration这一行,点击Run to Cursor,程序就会运行到光标所在的行。这一步的作用是获取Index,也就是要读取哪些数据。点击step over,得到Index就可以进入dataset_fetcher.fetch(index),根据索引去获取数据。点击step into进入到fetch函数:
在fetch函数中,代码data = [self.dataset[idx] for idx in possibly_batched_index]使用了列表生成式,调用了dataset,接着点击step over与step into进入dataset所在的代码位置,dataset代码位于类RMBDataset(Dataset)中的__getitem__()函数:
在getitem()中根据索引去获取图片的路径以及标签,然后采用代码img = Image.open(path_img).convert('RGB') # 0~255打开图片,读取进来的图片是一个PIL的数据类型,然后在getitem中调用transform()进行图像预处理操作,在代码处img = self.transform(img)通过step into进入transforms.py中的def 「call」()函数
「call」()函数是一个for循环,也就是依次有序地从compose中去调用预处理方法,第一个预处理方法是t(img),其功能是是Resize缩放;第二个功能是裁剪,第三个功能是进行张量操作,第四个功能是进行归一化;对compose的四个功能循环结束之后,就会返回代码处img = self.transform(img)。
transform是在__getitem__()中调用,并且在__getitem__()中实现数据预处理,然后通过__getitem__返回一个样本。
执行step out操作返回fetch()函数,接着就是不断地循环index获取一个batch_size大小的数据,最后在return的时候调用collate_fn()函数,将数据整理成一个batch_data的形式。
然后执行step out操作返回到dataloader.py中的__next__()函数中,然后再执行执行step out操作回到训练代码中,接着数据就读取进来了。这就是pytorch数据读取和transforms的运行机制。
回顾上面的数据读取流程图,transforms是在getitem中使用的,在getitem中读取一张图片,然后对这一张图片进行一系列预处理,返回图片以及标签。
了解了transforms的机制,现在学习一个比较常用的预处理方法,数据的标准化transforms.Normalize。
3.数据标准化transforms.normalize
3.1 定义
功能:逐channel的对图像进行标准化,即数据的均值变为0,标准差变为1。
计算公式:
-
mean:各通道的均值
-
std:各通道的标准差
-
inplace:是否原位操作
transform.Normalize(mean,
std,
inplace=False)
3.2 断点调试
回到代码中看一下normalize的具体实现方法,transform是在dataset的getitem中实现的,所以可以直接去dataset的getitem函数中设置断点:
进行debug操作,点击step into进入详细代码环境,进入了transforms.py中的call()函数中,在call函数中循环transforms。
点击step over执行多次,到normalize实现
接着点击step into查看normalize的实现,来到了normalize()类中的__call__()函数中,代码只有一行,实际上这行代码是调用了pytorch中的function中normalize方法。pytorch的function提供了很多常用的函数。
接着使用step into查看normalize中的具体实现。
def normalize(tensor, mean, std, inplace=False):
"""Normalize a tensor image with mean and standard deviation.
.. note::
This transform acts out of place by default, i.e., it does not mutates the input tensor.
See :class:`~torchvision.transforms.Normalize` for more details.
Args:
tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
mean (sequence): Sequence of means for each channel.
std (sequence): Sequence of standard deviations for each channel.
inplace(bool,optional): Bool to make this operation inplace.
Returns:
Tensor: Normalized Tensor image.
"""
if not _is_tensor_image(tensor): # 输入的合法性判断
raise TypeError('tensor is not a torch image.')
if not inplace: # 判断是否需要原地操作
tensor = tensor.clone()
dtype = tensor.dtype
# 获取均值与标准差,将list形式转变为张量形式
mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
tensor.sub_(mean[:, None, None]).div_(std[:, None, None]) # 归一化公式
return tensor
首先是输入的合法性判断,输入的是tensor,也就是原始的图像,接着判断是否要原地操作,如果不是inplace就需要将张量复制一份到新的内存空间中。下面的代码就是获取数据的均值和标准差,并将数据转换为张量。注意在sub_和div_后面有下划线,意思是进行原位操作,这样就完成了数据标准化的操作。
3.3 标准化作用
对数据进行标准化之后可以加快模型的收敛。
之前的逻辑回归代码bias=1,发现迭代次数360次即可得到99%的准确率,损失loss=0.05。
当修改bias=5时,发现需要迭代960次模型才能收敛,loss=0.14,得到99%的准确率。
原因:模型初始化一般有0均值,需要逐渐靠近最优分类平面。
bias=5的初始化距离分类平面较远
可以看出,如果训练数据有良好的分布或者权重有良好的初始化,可以加速模型的训练。
点击下方卡片《学姐带你玩AI》🚀🚀🚀
关注回复“500”领取300+经典论文合集&讲解视频
码字不易,欢迎大家点赞评论收藏!