📚博客主页:knighthood2001
✨公众号:认知up吧 (目前正在带领大家一起提升认知,感兴趣可以来围观一下)
🎃知识星球:【认知up吧|成长|副业】介绍
❤️如遇文章付费,可先看看我公众号中是否发布免费文章❤️
🙏笔者水平有限,欢迎各位大佬指点,相互学习进步!
今天来看一下ImageFolder
,官方代码如下:
class ImageFolder(DatasetFolder):
"""A generic data loader where the images are arranged in this way by default: ::
root/dog/xxx.png
root/dog/xxy.png
root/dog/[...]/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/[...]/asd932_.png
This class inherits from :class:`~torchvision.datasets.DatasetFolder` so
the same methods can be overridden to customize the dataset.
Args:
root (string): Root directory path.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
loader (callable, optional): A function to load an image given its path.
is_valid_file (callable, optional): A function that takes path of an Image file
and check if the file is a valid file (used to check of corrupt files)
Attributes:
classes (list): List of the class names sorted alphabetically.
class_to_idx (dict): Dict with items (class_name, class_index).
imgs (list): List of (image path, class_index) tuples
"""
def __init__(
self,
root: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
loader: Callable[[str], Any] = default_loader,
is_valid_file: Optional[Callable[[str], bool]] = None,
):
super().__init__(
root,
loader,
IMG_EXTENSIONS if is_valid_file is None else None,
transform=transform,
target_transform=target_transform,
is_valid_file=is_valid_file,
)
self.imgs = self.samples
描述
-
数据组织方式:数据集默认按照这种方式组织图像:在根目录下,每个类别(如"dog"和"cat")都有一个子目录,子目录下包含了该类的所有图像文件。
-
继承关系:这个类继承了torchvision.datasets.DatasetFolder,因此可以覆盖其中的方法来定制数据集。
参数
-
root (string): 数据集的根目录路径。
-
transform (callable, optional): 一个函数或变换,它接受一个PIL图像作为输入,并返回其变换后的版本。例如,可以使用
transforms.RandomCrop
来随机裁剪图像。 -
target_transform (callable, optional): 一个函数或变换,它接受目标(通常是类别标签)并对其进行变换。
-
loader (callable, optional): 一个函数,给定图像文件的路径,用于加载图像。
-
is_valid_file (callable, optional): 一个函数,接受一个图像文件的路径,并检查该文件是否是一个有效的文件(通常用于检查损坏的文件)。
属性
-
classes (list): 按字母顺序排列的类名列表。
-
class_to_idx (dict): 一个字典,其中键是类名,值是对应的类索引(通常是整数)。
-
imgs (list): 一个列表,其中包含(图像路径,类索引)的元组。
示例
假设你有如下的目录结构:
root/
dog/
xxx.png
xxy.png
cat/
123.png
456.png
使用此数据加载器时,你可以指定root为上述root目录的路径,然后数据加载器会读取每个类别下的图像文件,并为每个图像文件提供一个类别标签(基于其在哪个子目录下)。同时,你可以通过transform和target_transform参数来预处理图像和标签。
最后,数据加载器会将所有图像的路径和对应的类别索引存储在一个列表中,以便在后续的数据加载过程中使用。
因此,对于自己的数据集,你也需要把他变成这种架构,才能放进去,否则,就会报错。
transform和target_transform的区别
这两个参数在机器学习和深度学习中常用于数据预处理和增强,主要针对输入数据和目标数据(通常是标签)的处理。
-
transform (callable, optional):
- 作用:
transform
参数用于处理输入数据,通常是图像数据。它接受一个PIL图像(或其他格式的图像数据)作为输入,并返回对图像进行了某种变换之后的版本。 - 示例: 可以使用
transforms.RandomCrop
来随机裁剪图像,或者transforms.Resize
来调整图像大小。这些变换可以增加数据的多样性,提升模型的鲁棒性和泛化能力。
- 作用:
-
target_transform (callable, optional):
- 作用:
target_transform
参数用于处理目标数据,通常是类别标签。它接受一个目标数据(比如一个类别标签)作为输入,并对其进行某种变换。 - 示例: 在分类任务中,可以使用
torch.tensor
将类别标签转换为 PyTorch 的 Tensor 格式,或者进行其他必要的数据处理,以适应模型的输入要求。
- 作用:
区别:
transform
主要应用于输入数据,例如图像,目的是通过多样性增强数据集,以改善模型训练的效果。target_transform
则主要应用于目标数据,例如类别标签,目的是对标签进行必要的预处理或转换,使其适合模型的输入要求。
综上所述,transform
和 target_transform
在数据预处理中扮演着不同的角色,分别处理输入数据和目标数据,以确保模型能够有效地学习和泛化。