【PyTorch】图像多分类项目
目录
StratifiedShuffleSplit
transforms.ToTensor
Counter
StratifiedShuffleSplit
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=0)
创建StratifiedShuffleSplit对象,用于将数据集划分为训练集和测试集。
- n_splits=1:划分次数为1,大于1则多次划分,每次划分生成一组新训练集和新测试集。
- test_size=0.2:测试集比例为0.2,即测试集的大小占总样本的20%
- random_state=0:随机种子为0,类似random的种子,保证每次抽样到的数据一样
StratifiedShuffleSplit是scikit-learn库中的一个类,用于创建训练集和测试集的划分,同时保持每个类别中的样本比例一致。核心思想:分层抽样。
StratifiedShuffleSplit 类的工作原理:
先根据每个类别的样本数量将数据集划分为尽可能相等的子集(分层)
然后在这些子集中随机选择样本拆分创建训练集和测试集(随机拆分)
插入空格更好理解:Stratified Shuffle Split分层随机拆分类!
transforms.ToTensor
data_transformer = transforms.Compose([transforms.ToTensor()])
transforms.ToTensor()的作用是将PIL图像或NumPy数组转换为PyTorch张量,并且将图像的像素值从[0, 255]范围缩放到[0.0, 1.0]范围,即在[0.0, 1.0]范围内对像素值进行归一化。转换后的张量形状为(C, H, W)
Compose是 torchvision.transforms 模块的一个类,创建一个Compose对象时,需要传入一个包含一个或多个变换操作的列表。Compose对象一般包含四个变换操作:调整图像大小、从中心裁剪图像、将图像转换为张量以及归一化。
Counter
counter_train=collections.Counter(y_train)
用于统计图像标签,即每类标签图像数量,Counter是用于计数的子类字典。例如PyTorch torchvision包中STL-10数据集的训练数据集: