最近在跟着小土堆pytorch的视频跟着学习python,根据自己的理解和课程上面的知识,写了这一篇学习笔记。
1、加载数据
数据的加载是学习pytorch的第一步,我们需要加载数据,完成特征工程,对加载数据存在的一些特征来进行分析和处理,进而利用相关算法训练得到模型。
数据该如何加载呢?
首先,如果是文本之类数据的话,可以使用open()函数进行文件读取操作,对于图片的话,可以使用PIL下面的一个api,调用里面的open()方法来打开图片,如果要使用,则需要进行导包的操作。
from PIL import Image
如果导需要从某种数据源加载数据,并对这些数据进行预处理和格式化的话,利用pytorch中的Dataset类是最为方便的,也需要导包。
from torch.utils.data import Dataset
Dataset类里面定义了两种方法:
__len__()
: 返回数据集中的样本数量。__getitem__(idx)
: 根据给定的索引idx
返回一个样本
我们需要自己定义一个类,这个类继承Dataset类,并重写相关方法
需要调用系统路径,导入os模块,不要忘记了
import os
此时先定义一个自定义类,这个自定义类继承于Dataset类
class MyData(Dataset):
然后重写Dataset里面的__init__和__getitem__方法
class MyData(Dataset):
# 初始化方法,当创建MyData对象时会被调用
# root_dir: 数据集的根目录
# label_dir: 自定义的类别目录,通常是某个类别的子目录
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir # 存储根目录
self.label_dir = label_dir # 存储类别目录
# 拼接根目录和类别目录,得到完整路径
self.path = os.path.join(self.root_dir, self.label_dir)
# 获取该类别目录下的所有文件/图片名,并存储到self.img_path列表中
self.img_path = os.listdir(self.path)
# 根据索引获取数据集中的单个样本
# idx: 样本的索引
def __getitem__(self, idx):
# 获取索引对应的图片名
img_name = self.img_path[idx]
# 拼接完整的图片路径
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
# 打开图片并获取图片对象
image = Image.open(img_item_path)
# 使用类别目录作为标签
label = self.label_dir
# 返回图片对象和标签
return image, label
记得在py文件统计目录下有相关文件,我是跟着小土堆的课程,所以就下载了dataset的数据集
最后定义好相关的变量和函数即可
root_dir = "dataset/train"
ants_label_dir = "ants_image"
bees_label_dir = "bees_image"
ants_dataset = MyData(root_dir,ants_label_dir)
bees_dataset = MyData(root_dir,bees_label_dir)
这样就能得到ants_dataset和bees_dataset两个类别的数据集
2、查看数据
查看数据集的话,需要用到DataLoader数据加载类来加载数据,调用Transform来对数据进行增强,通过使用 transforms.Compose
来组合多个转换操作(这是后面要学习的)
下面的代码可以看下,但不用深究,看个大致就行。
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import torch
from torchvision import transforms
class MyData(Dataset):
def __init__(self, root_dir, label_dir, transform=None):
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir, self.label_dir)
self.img_names = os.listdir(self.path)
self.transform = transform
self.label = self.label_dir.replace("_image", "") # 假设标签是目录名的前缀
def __getitem__(self, idx):
img_name = self.img_names[idx]
img_path = os.path.join(self.path, img_name)
image = Image.open(img_path)
if self.transform:
image = self.transform(image)
return image, self.label
def __len__(self):
return len(self.img_names)
root_dir = "dataset/train"
ants_label_dir = "ants_image"
bees_label_dir = "bees_image"
# 定义数据增强
transform = transforms.Compose([
transforms.Resize((224, 224)), # 调整为适合模型输入的尺寸
transforms.ToTensor(), # 将 PIL Image 或 numpy.ndarray 转换为 torch.FloatTensor
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet 的均值和标准差
])
ants_dataset = MyData(root_dir, ants_label_dir, transform=transform)
bees_dataset = MyData(root_dir, bees_label_dir, transform=transform)
# 创建 DataLoader
batch_size = 4
train_loader = DataLoader(ants_dataset, batch_size=batch_size, shuffle=True)
# 查看数据集
for images, labels in train_loader:
print("Images batch shape:", images.shape)
print("Labels batch:", labels)