数据集相关类代码回顾理解 | StratifiedShuffleSplit\transforms.ToTensor\Counter
目录
np.mean
transforms.Normalize
transforms.Compose
xxx.transform
np.mean
meanRGB=[np.mean(x.numpy(),axis=(1,2)) for x,_ in train_ds]
计算每个样本的(RGB)均值 。NumPy 库np.mean 函数,用于计算数组元素的平均值。接受一个数组作为输入,返回一个标量值,表示数组中所有元素的平均值。axis 参数用于指定在计算平均值时沿哪个轴进行操作,axis 参数可以是一个整数或一个整数元组。axis=(1,2) 是一个元组,表示在高度和宽度维度上计算平均值,即对每个通道(如RGB)的值进行平均。辨析来看下一步:
meanR=np.mean([m[0] for m in meanRGB])
计算所有样本的R通道均值的均值。同理m[1]、m[2]分别可以计算所有样本G、B通道均值的均值,这样得到的实际上是一个最终均值,同理可以得到最终标准差,为标准化做准备。
transforms.Normalize
transforms.Normalize([meanR, meanG, meanB], [stdR, stdG, stdB])])
对图像进行标准化,将像素值缩放到均值为0标准差为1的范围,均值和标准差分别为meanR, meanG, meanB和stdR, stdG, stdB。transforms.Normalize是PyTorch中的图像预处理函数,作用是将图像的每个像素值减去均值,然后再除以标准差,从而将像素值缩放到均值为0,标准差为1的范围。需要注意的是,这个过程实际对应的是标准化Standardization,而非常规的线性归一化,归一化通常在标准化之前组合使用。例如线性归一到[0,1]范围,通常通过transforms.ToTensor实现,参照http://t.csdnimg.cn/5FSvB。
有资料将Standardization标准化称为zero-mean unit-variance normalization零均值单位方差归一化,本质相同
transforms.Compose
test0_transformer = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([meanR, meanG, meanB], [stdR, stdG, stdB]),
])
transforms.Compose 是 torchvision.transforms 模块中的一个类,用于将多个图像变换操作组合成一个整体。以一个包含多个变换操作的列表作为输入,并按照列表中的顺序依次对图像进行变换。上述代码就将以下操作组合成了一个整体:图像转化为张量,像素值从 [0, 255] 范围归一缩放到 [0, 1] 范围,通道顺序从 HWC(高度、宽度、通道)转换为 CHW(通道、高度、宽度),并对图像进行标准化,将像素值缩放到均值为0标准差为1的范围,均值和标准差分别为meanR, meanG, meanB和stdR, stdG, stdB。例如前面transforms.Normalize提到的,这里就包括了归一化和标准化的组合使用。
transforms.ToTensor()的作用是将PIL图像或NumPy数组转换为PyTorch张量,并且将图像的像素值从[0, 255]范围缩放到[0.0, 1.0]范围,即在[0.0, 1.0]范围内对像素值进行归一化。转换后的张量形状为(C, H, W)。
xxx.transform
test0_ds.transform=test0_transformer
test0_ds.transform 是一个属性,test0_ds.transform=test0_transformer将测试数据集的转换器赋给测试数据集的transform属性,意味着在从 test0_ds 中获取样本时,每个样本都会先经过test0_transformer进行预处理。