【Pytorch】计算机视觉项目——卷积神经网络TinyVGG模型图像分类(如何使用自定义数据集)

news2025/1/13 13:42:53

目录

  • 一、前言
  • 二、工作流程回顾
  • 三、详细步骤流程
    • 1. 环境配置
    • 2. 数据准备
      • 数据集下载
      • 数据存储结构&路径
      • 查看图片
    • 3. 数据转换
    • 4. 自定义数据集(Custom Dataset )
      • 4.1 方法一:使用`ImageFolder`加载数据集
        • 信息查看
        • 张量转图片
        • 创建DataLoader
      • 4.2 方法二:自定义Dataset
        • 导入所需的库
        • 类名辅助函数
        • 复刻`Imagefolder`功能
        • 可视化检查
        • 创建DatasetLoader
    • 5. 模型搭建(TinyVGG)
    • 6. 模型训练
      • 定义训练和测试循环
      • 模型训练
    • 7. 模型评估
    • 8. 模型预测(暂略)
  • 四、补充说明:自定义数据集代码
      • 1. 类的定义和初始化
      • 2. 获取数据集的大小
      • 3. 获取特定索引的数据项
      • 小结
  • 参考资料


一、前言

在上一篇笔记中,介绍了如何搭建CNN模型完成图像分类任务,项目中使用了torchvision的内置图像数据集FashionMNIST。

然而,如果使用其他的图像数据集,比如我们自己的数据集或者其他分类任务数据集,又该如何操作呢?本文主要分享如何创建自定义数据集:

  1. 使用ImageFolder创建自定义数据集的步骤流程
  2. 通过复刻ImageFolder的功能,创建自定义数据集

其他相关文章:

  • 深度学习入门笔记:总结了一些神经网络的基础概念。
  • TensorFlow专栏:《计算机视觉入门系列》介绍如何用TensorFlow框架实现卷积分类器。
  • 【Pytorch】整体工作流程代码详解(新手入门)

二、工作流程回顾

![[04-20240604174526325.webp]]

  1. 数据准备:导入数据集、设置DataLoader (这部分是本文核心重点)
  2. 模型搭建
  3. 模型训练: 设定Loss和优化器、训练&测试循环
  4. 模型评估和结果输出
  5. 模型导出保存

三、详细步骤流程

1. 环境配置

import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"

print(torch.__version__)
print(device)

![[04-20240605001014043.webp]]


2. 数据准备

数据集下载

数据集来自Food101,详细介绍见:Food-101 – Mining Discriminative Components with Random Forests

![[04 Custom Dataset-20240602173421515.webp]]

import requests
import zipfile
from pathlib import Path

# 路径设置
data_path = Path("data/")
image_path = data_path / "pizza_steak_sushi"

# 检查是否存在文件,如果没有下载数据
if image_path.is_dir():
  print(f"{image_path} directory exists.")
else:
  print(f"Did not find {image_path} directory, creating one...")
  image_path.mkdir(parents=True, exist_ok=True)
  
  # 下载数据文件
  with open(data_path / "pizza_steak_sushi.zip", "wb") as f:
    request = requests.get("https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip")
    print("Downloading pizza, steak, sushi data...")
    f.write(request.content)

  # 解压文件
  with zipfile.ZipFile(data_path / "pizza_steak_sushi.zip", "r") as zip_ref:
    print("Unzipping pizza, steak, sushi data...")
    zip_ref.extractall(image_path)

数据存储结构&路径

下载好后,在一侧的文件夹里可以找到下载好的图片文档。
![[04 Custom Dataset-20240602174550247.webp|269]]

通常在图像识别的任务中,数据集都是以这种格式存储的,例如pizza文件夹里会存所有披萨的照片,sushi文件夹则存放所有寿司照片。

定义遍历函数

在数据准备阶段,需要把文件转换成适用于Pytorch学习使用的格式,其中也包括文件对应的存储路径(文件夹名称是图像分类标签,因此,文件路径也是训练内容的一部分)

首先,定义一个函数,去遍历计数文件夹里的文件,并返回对应文件的路径。

import os

def walk_through_dir(dir_path):
  for dirpath, dirnames, filenames in os.walk(dir_path):
    print(f"There are {len(dirnames)} directories and {len(filenames)} images in '{dirpath}'.")

walk_through_dir(image_path)

在这里插入图片描述

结果显示,我们每个训练类型中都差不多有75张照片,而测试集每类约有20多张。

设置训练集和测试集路径

train_dir = image_path / "train"
test_dir = image_path / "test"

train_dir, test_dir

在这里插入图片描述

查看图片

接下来,我们随机抽取一个图片打开,有两种打开的方式:

  1. PIL.Image.open()
  2. pyplot.imshow()
import random
from PIL import Image

# 设置随机种子
random.seed(42)

# 1. 获取所有图片(.jpg结尾的文件)路径
image_path_list = list(image_path.glob("*/*/*.jpg"))

# 2. 随机选择一个图片
random_image_path = random.choice(image_path_list)

# 3. 从文件路径中获取图像分类的标签
image_class = random_image_path.parent.stem

# 4. 打开图片
img = Image.open(random_image_path)

# 5. 查看文件数据
print(f"Random image path: {random_image_path}")
print(f"Image class: {image_class}")
print(f"Image height: {img.height}")
print(f"Image width: {img.width}")

img

![[04-20240605004826175.webp|491]]

# 使用plt.imshow()查看
import numpy as np
import matplotlib.pyplot as plt

img_as_array = np.asarray(img)

plt.figure(figsize=(10, 7))
plt.imshow(img_as_array)
plt.title(f"Image class: {image_class} | Image shape: {img_as_array.shape} -> [height, width, color_channels]")
plt.axis(False);

![[04-20240605005408121.webp|469]]


3. 数据转换

在模型训练之前,无论是什么样的数据(图像、文本或者声音)都需要把它们转换成张量(tensors) 格式。

torchvision.transforms 包含了许多预构建的方法,用于格式化图像,将图像转换为张量,甚至可以对图像进行操作以进行数据增强,并通过 torchvision.transforms.Compose() 将所有步骤编译在一起。

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


# 定义图像变换过程
data_transform = transforms.Compose([
    # 调整图片大小为64*64像素
    transforms.Resize(size=(64, 64)),
    # 随机碎片翻转图像
    transforms.RandomHorizontalFlip(p=0.5), # p指概率,这里是50%的发生概率
    # 将图像转化为Pytorch张量
    transforms.ToTensor()
])

transforms.ToTensor():

  • 输入格式:对于彩色图像(RGB),输入通常是形状为 (H, W, 3) 的 numpy 数组或 PIL 图像,其中 H 是高度,W 是宽度,3 表示颜色通道(红、绿、蓝)。
  • 输出格式:形状为 (C, H, W)的PyTorch 张量,其中 C 是颜色通道数(通常为 3),H 是高度,W 是宽度。
  • 示例:假设有一张 RGB 图像,原始大小为 256x256,转换后为形状为 (3, 256, 256) 的张量,其中 3 表示 RGB 通道。如果是灰度图像,转换为 (1, H, W),因为灰度图像只有一个通道。

4. 自定义数据集(Custom Dataset )

4.1 方法一:使用ImageFolder加载数据集

datasets.ImageFolder 是 PyTorch 中用于加载图像数据的一个工具,适合用于结构化组织的图像数据集。

它可以自动读取和加载按照文件夹结构组织的图像数据集,并且自动将文件夹名称作为类标签(类别)。

通常是创建ImageFolder 数据集对象,然后将其传递给 DataLoader 以便于批量加载和处理。

使用前需要准备:

  1. 数据集根目录 root_dir
  2. 图像变换 transform

这些我们都在前面的步骤创建好了。

# Use ImageFolder to create dataset(s)
from torchvision import datasets
train_data = datasets.ImageFolder(root=train_dir, # 目标文件夹
                                  transform=data_transform, # 图像转换
                                  target_transform=None) # transforms to perform on labels (if necessary)

test_data = datasets.ImageFolder(root=test_dir,
                                 transform=data_transform)

  

print(f"Train data:\n{train_data}\nTest data:\n{test_data}")

![[04-20240605014153743.webp]]

信息查看

完成之后,可以通过下面的指令查询信息:

class_names = train_data.classes
class_names

class_dict = train_data.class_to_idx
class_dict

len(train_data), len(test_data)

结果如下
在这里插入图片描述

img, label = train_data[0][0], train_data[0][1]

print(f"Image tensor:\n{img}")
print(f"Image shape: {img.shape}")
print(f"Image datatype: {img.dtype}")
print(f"Image label: {label}")
print(f"Label datatype: {type(label)}")

![[04-20240605014407429.webp]]

这是图像转换成张量之后的结果。张量形状是[3 , 64 , 64],分类标签是0, 对应class_to_idx中的标签类别。

张量转图片

那么又如何把这些数字,以图像形式展现呢?

我们需要把现在(C, H, W)(通道数、高度、宽度)的图像数据改变为 (H, W, C)

# 重排维度顺序
img_permute = img.permute(1, 2, 0)

print(f"Original shape: {img.shape} -> [color_channels, height, width]")
print(f"Image permute shape: {img_permute.shape} -> [height, width, color_channels]")

plt.figure(figsize=(10, 7))
plt.imshow(img.permute(1, 2, 0))
plt.axis("off")
plt.title(class_names[label], fontsize=14);

![[04-20240605015144564.webp]]

创建DataLoader
# Turn train and test Datasets into DataLoaders
from torch.utils.data import DataLoader

train_dataloader = DataLoader(dataset=train_data, 
                              batch_size=1, 
                              num_workers=1,
                              shuffle=True)

test_dataloader = DataLoader(dataset=test_data, 
                             batch_size=1, 
                             num_workers=1, 
                             shuffle=False) 
                             
train_dataloader, test_dataloader

![[04-20240605015800201.webp]]

4.2 方法二:自定义Dataset

如果没有 torchvision.datasets.ImageFolder() 这样的预构建 Dataset 创建器,我们就需要自己构建一个Dataset。

导入所需的库
import os
import pathlib
import torch

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from typing import Tuple, Dict, List
类名辅助函数

编写一个辅助函数,该函数能够在给定目录路径的情况下创建类名列表和类名及其索引的字典。

# 定义搜索和创建类别字典的函数
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    # 1. 扫描文件路径目录,并获取类别标签名
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())

    # 2. 没有找到类名时报错(这种可能是目录结构问题)
    if not classes:
        raise FileNotFoundError(f"Couldn't find any classes in {directory}.")

    # 3. 将类名转换为一个数字标签字典,每个类对应一个标签
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx

find_classes(train_dir)

![[04-20240605021116643.webp]]

复刻Imagefolder功能

这部分代码比较复杂,但是完成后,可以把这段代码单独存在.py格式的文件中,后续有需要时再调用。

分步考虑和复刻Imagefolder的功能:

  1. 继承 torch.utils.data.Dataset
  2. targ_dir 参数和 transform 参数初始化我们的子类。
  3. 创建几个属性用于存储图像路径、transform、classes 和 class_to_idx(来自find_classes() 函数)。
  4. 创建一个加载图像的函数,这可以使用 PIL 或 torchvision.io
  5. 重写 torch.utils.data.Dataset__len__ 方法以返回 Dataset 中的样本数量,建议这样做但不是必须的。这是为了可以调用 len(Dataset)
  6. 重写 torch.utils.data.Dataset__getitem__ 方法以返回 Dataset 中的单个样本,这是必须的。

这段代码在文章末尾补充了另一个简单例子解释。

# 编写一个自定义数据集类(继承自 torch.utils.data.Dataset)
from torch.utils.data import Dataset

# 1. 子类化 torch.utils.data.Dataset
class ImageFolderCustom(Dataset):
    
    # 2. 使用 targ_dir 和 transform(可选)参数初始化
    def __init__(self, targ_dir: str, transform=None) -> None:
        
        # 3. 创建类属性
        # 获取所有图像路径
        self.paths = list(pathlib.Path(targ_dir).glob("*/*.jpg")) # 注意:如果有 .png 或 .jpeg 格式的图片需要更新这行代码
        # 设置变换
        self.transform = transform
        # 创建 classes 和 class_to_idx 属性
        self.classes, self.class_to_idx = find_classes(targ_dir)

    # 4. 编写加载图像的函数
    def load_image(self, index: int) -> Image.Image:
        "通过路径打开图像并返回它。"
        image_path = self.paths[index]
        return Image.open(image_path) 
    
    # 5. 重写 __len__() 方法(可选,但推荐为 torch.utils.data.Dataset 的子类实现)
    def __len__(self) -> int:
        "返回样本总数。"
        return len(self.paths)
    
    # 6. 重写 __getitem__() 方法(为 torch.utils.data.Dataset 的子类必须实现)
    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        "返回一个数据样本,数据和标签(X, y)。"
        img = self.load_image(index)
        class_name = self.paths[index].parent.name # 期望路径格式为 data_folder/class_name/image.jpeg
        class_idx = self.class_to_idx[class_name]

        # 如有必要进行变换
        if self.transform:
            return self.transform(img), class_idx # 返回数据和标签(X, y)
        else:
            return img, class_idx # 返回数据和标签(X, y)

接下来,我们另外创建一个图像转换transform过程,并且使用新创的ImageFolderCustom()创建数据集。

# 在训练集上使用数据增强
train_transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor()
])

# 测试集不需要使用增强技术。
test_transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor()
])

train_data_custom = ImageFolderCustom(targ_dir=train_dir,
                                      transform=train_transforms)

test_data_custom = ImageFolderCustom(targ_dir=test_dir,
                                     transform=test_transforms)

train_data_custom, test_data_custom

![[04-20240605023304298.webp]]

检查自建dataset效果
![[04-20240605023239521.webp]]

# 对比验证:下面输出结果应该都为“True”
print((len(train_data_custom) == len(train_data)) & (len(test_data_custom) == len(test_data)))
print(train_data_custom.classes == train_data.classes)
print(train_data_custom.class_to_idx == train_data.class_to_idx)
可视化检查
# 定义一个函数,随机抽取Dataset中的图片展现
def display_random_images(dataset: torch.utils.data.dataset.Dataset,
                          classes: List[str] = None,
                          n: int = 10,
                          display_shape: bool = True,
                          seed: int = None):

    # 2. 设置n上限
    if n > 10:
        n = 10
        display_shape = False
        print(f"For display purposes, n shouldn't be larger than 10, setting to 10 and removing shape display.")

    # 3. 随机种子
    if seed:
        random.seed(seed)

    # 4. 随机抽取
    random_samples_idx = random.sample(range(len(dataset)), k=n)

    # 5. 图像设置
    plt.figure(figsize=(16, 8))
    
    # 6. 遍历
    for i, targ_sample in enumerate(random_samples_idx):
        targ_image, targ_label = dataset[targ_sample][0], dataset[targ_sample][1]

        # 7. 维度调整
        targ_image_adjust = targ_image.permute(1, 2, 0)

        # 8.画图
        plt.subplot(1, n, i+1)
        plt.imshow(targ_image_adjust)
        plt.axis("off")

        if classes:
            title = f"class: {classes[targ_label]}"
            if display_shape:
                title = title + f"\nshape: {targ_image_adjust.shape}"
        plt.title(title)

分别查看方法一和方法二的结果

display_random_images(train_data,
                      n=5,
                      classes=class_names,
                      seed=None)

![[04-20240605023916226.webp]]

display_random_images(train_data_custom,
                      n=12,
                      classes=class_names,
                      seed=None) 

![[04-20240605023942561.webp]]

创建DatasetLoader

和方法一是同样的方式,区别是数据集参数的不同

from torch.utils.data import DataLoader
train_dataloader_custom = DataLoader(dataset=train_data_custom, # 使用自定义的dataset
                                     batch_size=1,
                                     num_workers=0,
                                     shuffle=True) 

test_dataloader_custom = DataLoader(dataset=test_data_custom, # 使用自定义的dataset
                                    batch_size=1, 
                                    num_workers=0, 
                                    shuffle=False) 

train_dataloader_custom, test_dataloader_custom

![[04-20240605024136913.webp]]


5. 模型搭建(TinyVGG)

以下部分代码和上一篇笔记一样,点击链接跳转。

TinyVGG 是一个简单的卷积神经网络(CNN)模型,通常用于初学者的图像分类任务。它通常包含两个卷积层和两个全连接层,结构紧凑,适合于小规模数据集的实验和学习基础的深度学习概念。

class TinyVGG(nn.Module):
    """
    Model architecture copying TinyVGG from:
    https://poloclub.github.io/cnn-explainer/
    """

    def __init__(self, input_shape: int, hidden_units: int, output_shape: int) -> None:
        super().__init__()
        self.conv_block_1 = nn.Sequential(
            nn.Conv2d(in_channels=input_shape,
                      out_channels=hidden_units,
                      kernel_size=3, 
                      stride=1, 
                      padding=1), 
            nn.ReLU(),
            nn.Conv2d(in_channels=hidden_units,
                      out_channels=hidden_units,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2,
                         stride=2) 
        )

        self.conv_block_2 = nn.Sequential(
            nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=hidden_units*16*16,
                      out_features=output_shape)
        )

    def forward(self, x: torch.Tensor):
        x = self.conv_block_1(x)
        # print(x.shape)
        x = self.conv_block_2(x)
        # print(x.shape)
        x = self.classifier(x)
        # print(x.shape)
        return x

torch.manual_seed(42)

model_0 = TinyVGG(input_shape=3, 
                  hidden_units=10,
                  output_shape=len(train_data.classes)).to(device)

model_0

![[04-20240605025012948.webp]]


6. 模型训练

定义训练和测试循环

def train_step(model: torch.nn.Module, 
               dataloader: torch.utils.data.DataLoader, 
               loss_fn: torch.nn.Module, 
               optimizer: torch.optim.Optimizer):
    # 将模型设置为训练模式
    model.train()
    
    # 初始化训练损失和训练准确度
    train_loss, train_acc = 0, 0
    
    # 遍历数据加载器的数据批次
    for batch, (X, y) in enumerate(dataloader):
        # 将数据发送到目标设备
        X, y = X.to(device), y.to(device)

        # 1. 前向传播
        y_pred = model(X)

        # 2. 计算并累积损失
        loss = loss_fn(y_pred, y)
        train_loss += loss.item() 

        # 3. 优化器梯度清零
        optimizer.zero_grad()

        # 4. 反向传播
        loss.backward()

        # 5. 优化器步骤
        optimizer.step()

        # 计算并累积所有批次的准确度指标
        y_pred_class = torch.argmax(torch.softmax(y_pred, dim=1), dim=1)
        train_acc += (y_pred_class == y).sum().item()/len(y_pred)

    # 计算平均损失和准确度
    train_loss = train_loss / len(dataloader)
    train_acc = train_acc / len(dataloader)
    return train_loss, train_acc

def test_step(model: torch.nn.Module, 
              dataloader: torch.utils.data.DataLoader, 
              loss_fn: torch.nn.Module):
    # 将模型设置为评估模式
    model.eval() 
    
    # 初始化测试损失和测试准确度
    test_loss, test_acc = 0, 0
    
    # 启用推理上下文管理器
    with torch.inference_mode():
        # 遍历数据加载器的批次
        for batch, (X, y) in enumerate(dataloader):
            # 将数据发送到目标设备
            X, y = X.to(device), y.to(device)
    
            # 1. 前向传播
            test_pred_logits = model(X)

            # 2. 计算并累积损失
            loss = loss_fn(test_pred_logits, y)
            test_loss += loss.item()
            
            # 计算并累积准确度
            test_pred_labels = test_pred_logits.argmax(dim=1)
            test_acc += ((test_pred_labels == y).sum().item()/len(test_pred_labels))
            
    # 计算平均损失和准确度
    test_loss = test_loss / len(dataloader)
    test_acc = test_acc / len(dataloader)
    return test_loss, test_acc

再创建一个函数,循环执行模型的训练和测试步骤。

```python
from tqdm.auto import tqdm

# 1. 接收训练和测试步骤所需的各种参数
def train(model: torch.nn.Module, 
          train_dataloader: torch.utils.data.DataLoader, 
          test_dataloader: torch.utils.data.DataLoader, 
          optimizer: torch.optim.Optimizer,
          loss_fn: torch.nn.Module = nn.CrossEntropyLoss(),
          epochs: int = 5):
    
    # 2. 创建空的结果字典
    results = {"train_loss": [],
        "train_acc": [],
        "test_loss": [],
        "test_acc": []
    }
    
    # 3. 循环执行指定次数的训练和测试步骤
    for epoch in tqdm(range(epochs)):
        train_loss, train_acc = train_step(model=model,
                                           dataloader=train_dataloader,
                                           loss_fn=loss_fn,
                                           optimizer=optimizer)
        test_loss, test_acc = test_step(model=model,
            dataloader=test_dataloader,
            loss_fn=loss_fn)
        
        # 4. 打印当前状态
        print(
            f"Epoch: {epoch+1} | "
            f"train_loss: {train_loss:.4f} | "
            f"train_acc: {train_acc:.4f} | "
            f"test_loss: {test_loss:.4f} | "
            f"test_acc: {test_acc:.4f}"
        )

        # 5. 更新结果字典
        results["train_loss"].append(train_loss)
        results["train_acc"].append(train_acc)
        results["test_loss"].append(test_loss)
        results["test_acc"].append(test_acc)

    # 6. 返回填充结果字典
    return results

模型训练

torch.manual_seed(42)
torch.cuda.manual_seed(42)

# 设置迭代次数
NUM_EPOCHS = 5

# Recreate an instance of TinyVGG
model_0 = TinyVGG(input_shape=3, # number of color channels (3 for RGB)
                  hidden_units=10,
                  output_shape=len(train_data.classes)).to(device)

# 设置损失函数和优化器
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model_0.parameters(), lr=0.001)

# 引入计时器并开始计时
from timeit import default_timer as timer
start_time = timer()

# 模型训练
model_0_results = train(model=model_0,
                        train_dataloader=train_dataloader_custom,
                        test_dataloader=test_dataloader_custom,
                        optimizer=optimizer,
                        loss_fn=loss_fn,
                        epochs=NUM_EPOCHS)

# 结束计时
end_time = timer()
print(f"Total training time: {end_time-start_time:.3f} seconds")

![[04-20240605030434455.webp]]


7. 模型评估

model_0_results.keys()

![[04-20240605030459676.webp]]

def plot_loss_curves(results: Dict[str, List[float]]):
    loss = results['train_loss']
    test_loss = results['test_loss']
	accuracy = results['train_acc']
    test_accuracy = results['test_acc']
    
    epochs = range(len(results['train_loss']))

    # 图片设置
    plt.figure(figsize=(15, 7))

    # 损失
    plt.subplot(1, 2, 1)
    plt.plot(epochs, loss, label='train_loss')
    plt.plot(epochs, test_loss, label='test_loss')
    plt.title('Loss')
    plt.xlabel('Epochs')
    plt.legend()

    # 准确度
    plt.subplot(1, 2, 2)
    plt.plot(epochs, accuracy, label='train_accuracy')
    plt.plot(epochs, test_accuracy, label='test_accuracy')
    plt.title('Accuracy')
    plt.xlabel('Epochs')
    plt.legend();

plot_loss_curves(model_0_results)

迭代5次结果如下:
![[04-20240605030616816.webp]]
把迭代次数增加到30次,可以明显看出模型有过拟合的趋势,毕竟我们的数据集非常小。
![[04-20240605030850944.webp]]


8. 模型预测(暂略)


四、补充说明:自定义数据集代码

这段代码来自:Datasets & DataLoaders — PyTorch Tutorials 2.3.0+cu121 documentation

代码定义了一个名为 CustomImageDataset 的类,它继承自 PyTorch 的 Dataset 类,主要用于处理图像数据集。这个类实现了自定义的数据集加载方式,特别是从一个带有注释文件(通常是CSV文件)和一个图像目录中读取数据。

自定义 Dataset 类必须包含三个函数:__init____len____getitem__

例如内置数据集FashionMNIST,其图像存储在目录img_dir中,其标签单独存储在 CSV 文件annotations_file中。

import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

以下是对代码的详细解释:

1. 类的定义和初始化

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)  # 读取CSV文件,保存图像文件名和对应的标签
        self.img_dir = img_dir  # 图像文件所在的目录
        self.transform = transform  # 用于图像的转换操作
        self.target_transform = target_transform  # 用于标签的转换操作
  • __init__ 方法是类的构造函数,用于初始化类的实例。
  • annotations_file 是注释文件的路径,包含图像文件名及其对应的标签。
  • img_dir 是图像存放的目录路径。
  • transform 是一个可选的图像转换操作(如数据增强)。
  • target_transform 是一个可选的标签转换操作(如标签编码)。

2. 获取数据集的大小

    def __len__(self):
        return len(self.img_labels)  # 返回数据集的大小,即图像的数量
  • __len__ 方法返回数据集的大小,即注释文件中记录的图像数量。

3. 获取特定索引的数据项

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])  # 获取图像文件的完整路径
        image = read_image(img_path)  # 读取图像文件
        label = self.img_labels.iloc[idx, 1]  # 获取图像对应的标签
        if self.transform:
            image = self.transform(image)  # 对图像进行转换
        if self.target_transform:
            label = self.target_transform(label)  # 对标签进行转换
        return image, label  # 返回图像和标签
  • __getitem__ 方法用于获取指定索引 idx 处的数据项(图像和标签)。
  • img_path 拼接图像目录路径和图像文件名,得到图像的完整路径。
  • image 使用 read_image 函数读取图像文件。
  • label 获取图像对应的标签。
  • 如果提供了 transform,则对图像进行相应的转换。
  • 如果提供了 target_transform,则对标签进行相应的转换。
  • 最后返回图像和标签。

小结

这个类的设计使得我们可以将一个带有注释文件和图像目录的数据集加载到 PyTorch 中,方便进行训练和测试。通过支持可选的图像和标签转换,能够灵活地对数据进行预处理和增强。

而在4.2中,因为数据集是另外创建的,没有自带的注释文件。因此,需要另外创建类名辅助函数,从文件路径中获取类名标签。

从结构功能上看,4.2和补充例子的代码都是为了实现同样的功能。


参考资料

  • 04. PyTorch Custom Datasets - Zero to Mastery Learn PyTorch for Deep Learning
  • Datasets & DataLoaders — PyTorch Tutorials 2.3.0+cu121 documentation

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1791375.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ChatGPT-4o抢先体验

速度很快,结果很智能,支持多模态输入输出,感兴趣联系作者

Unity开发Cosmos使用BNG Framework获取按键信息

Unity开发Cosmos使用BNG Framework获取按键信息 1、新建一个脚本&#xff0c;复制下面代码 using BNG;[Header("Input")]//[Tooltip("The key(s) to use to toggle locomotion type")]public List<ControllerBinding> locomotionToggleInput new …

SpringBoot+Vue实现前后端分离基本的环境搭建

目录 一、Vue项目的搭建 &#xff08;1&#xff09;基于vite创建vue项目 &#xff08;2&#xff09;引入elementplus &#xff08;3&#xff09;启动后端服务&#xff0c;并测试 二、SpringBoot项目的搭建 &#xff08;1&#xff09;通过idea创建SpringBoot项目 &#x…

每天五分钟深度学习PyTorch:Tensor张量的索引和切片

本文重点 有时候当我们拥有一个Tensor张量的时候,我们可能需要获取它某一维度的信息,那么此时我们就需要索引和切片的技术,它们可以帮助我们解决这些问题。 切片操作 a是四维的,然后默认是从第一维开始取,逗号表示取不同的维度 a[:2]表示第一维取0,1,后面三维取所有 …

一、大模型推理

https://github.com/hiyouga/LLaMA-Factory/blob/main/README_zh.md https://github.com/hiyouga/LLaMA-Factory/blob/main/examples/README_zh.md 安装 v7.1 https://github.com/hiyouga/LLaMA-Factory/releases/tag/v0.7.1 git clone --depth 1 https://github.com/hiyoug…

门面模式Api网关(SpringCloudGateway)

1. 前言 当前通过Eureka、Nacos解决了服务注册和服务发现问题&#xff0c;使用Spring Cloud LoadBalance解决了负载均衡的需求&#xff0c;同时借助OpenFeign实现了远程调用。然而&#xff0c;现有的微服务接口都直接对外暴露&#xff0c;容易被外部访问。为保障对外服务的安全…

问答机器人

怎样做自己的问答机器人&#xff1f; 根据我们提供的数据分析出问题的答案&#xff0c;我们并不需要训练自己的模型 微调模型 finetune&#xff0c;将语言模型调成另外的语言模型&#xff0c;更适合不同类型数据&#xff0c;运用finetune方法将模型变化 知识库模型 embedd…

alist配合onlyoffice 实现在线预览

alist配合onlyoffice 实现在线预览 文章目录 alist配合onlyoffice 实现在线预览一、安装onlyoffice二、增加view.html文件三、安装nginx&#xff0c;并增加conf配置文件四、alist预览配置增加 一、安装onlyoffice 我是采用docker安装&#xff0c;采用的版本是7.2&#xff0c; …

【因果推断python】16_工具变量2

目录 出生季度和教育对工资的影响 第一阶段 出生季度和教育对工资的影响 到目前为止&#xff0c;我们一直将这些工具视为一些神奇的变量 Z&#xff0c;它们具有仅通过干预变量影响结果的神奇特性。老实说&#xff0c;好的工具变量来之不易&#xff0c;我们不妨将它们视为奇迹…

Leetcode - 周赛400

目录 一&#xff0c;3168. 候诊室中的最少椅子数 二&#xff0c;3169. 无需开会的工作日 三&#xff0c;3170. 删除星号以后字典序最小的字符串 四&#xff0c;3171. 找到按位与最接近 K 的子数组 一&#xff0c;3168. 候诊室中的最少椅子数 本题是一道模拟题&#xff0c;直…

排序方法——《选择排序》

P. S.&#xff1a;以下代码均在VS2019环境下测试&#xff0c;不代表所有编译器均可通过。 P. S.&#xff1a;测试代码均未展示头文件stdio.h的声明&#xff0c;使用时请自行添加。 博主主页&#xff1a;Yan. yan.                        …

HCIP-Datacom-ARST自选题库_10_多种协议多选【24道题】

1.如图所示&#xff0c;PE1和PE2之间通过LoopbackO接口建立MP-BGP邻居关系&#xff0c;在配完成之后&#xff0c;发现CE1和CE2之间无法互相学习路由&#xff0c;下列哪些选项会造成该问题的出现? PE1或PE2未在BGP-VPNV4单播地址族视图使能邻居A PE1或PE2上的VPN实例参数配置错…

htb_solarlab

端口扫描 80,445 子域名扫描 木有 尝试使用smbclient连接445端口 Documents目录可查看 将Documents底下的文件下载到本地看看 xlsx文件里有一大串用户信息&#xff0c;包括username和password 先弄下来 不知道在哪登录&#xff0c;也没有子域名&#xff0c;于是返回进行全端…

第N4周:中文文本分类

&#x1f368; 本文为&#x1f517;365天深度学习训练营 中的学习记录博客&#x1f356; 原作者&#xff1a;K同学啊 一、预备知识 中文文本分类和英文文本分类都是文本分类&#xff0c;为什么要单独拎出来个中文文本分类呢&#xff1f; 在自然语言处理&#xff08;NLP&#x…

SickOS1.1 - Shellshock原理和利用过程精讲

SickOS1.1的另一种思路&#xff1b;用另一种方法打透这台机器 Nikto扫描 正常都是-h扫描&#xff1b;有代理就用-useproxy 指向的代理ip:端口 nikto -h 192.168.218.157 -useproxy 192.168.218.157:3128apache版本&#xff0c;有点低&#xff0c;现在都是2.4.54版本了&#x…

PDF批量加水印 与 去除水印实践

本文主要目标是尝试去除水印&#xff0c;但是为了准备测试数据&#xff0c;我们需要先准备好有水印的pdf测试文件。 注意&#xff1a;本文的去水印只针对文字悬浮图片悬浮两种特殊情况&#xff0c;即使是这两种情况也不代表一定都可以去除水印。 文章目录 批量添加透明图片水印…

OpenStreetMap部署(OSM)

参考&#xff1a;https://github.com/openstreetmap/openstreetmap-website/blob/master/DOCKER.md OpenStreeMap 部署 操作系统建议使用 Ubuntu 22 版本 安装 Docker # 更新软件包索引&#xff1a; sudo apt-get update # 允许APT使用HTTPS&#xff1a; sudo apt-get inst…

TypeScript的never类型的妙用

never类型介绍 在 TypeScript 中&#xff0c;"never" 是一个表示永远不会发生的值类型。 使用场景 "never" 类型通常用于以下几种情况&#xff1a; 1、函数返回类型&#xff1a;当一个函数永远不会返回任何值&#xff08;比如抛出异常或者无限循环&…

使用 MDC 实现日志链路跟踪,包教包会!

在微服务环境中&#xff0c;我们经常使用 Skywalking、Spring Cloud Sleut 等去实现整体请求链路的追踪&#xff0c;但是这个整体运维成本高&#xff0c;架构复杂&#xff0c;本次我们来使用 MDC 通过 Log 来实现一个轻量级的会话事务跟踪功能&#xff0c;需要的朋友可以参考一…

数据库与缓存⼀致性⽅案

数据库与缓存⼀致性⽅案 1、背景2、数据⼀致性⽅案设计3、数据⼀致性⽅案流程图4、关键代码4.1、 处理数据⼀致性的消息队列⼊⼝4.2、数据⼀致性配置的常量信息 1、背景 现有的业务场景下&#xff0c;都会涉及到数据库以及缓存双写的问题&#xff0c;⽆论是先删除缓存&#xf…