Preface
This post records how to build a simple custom Dataset in PyTorch, following the official tutorial. The file hierarchy is as follows:
images/
annotations_file.csv
Data description
The images folder contains the images to train on, and annotations_file.csv has two columns, image_id and label, i.e. each image's name and its corresponding label.
image_id  label
1         风景
2         风景
3         风景
4         星空
5         星空
6         星空
7         人物
8         人物
9         人物
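The three class labels are 风景 (landscape), 星空 (starry sky), and 人物 (person). As a quick hedged sketch, not part of the original post, and assuming the nine images are named 1.jpg through 9.jpg, an annotations file like this could be generated with pandas:

import pandas as pd

# Hypothetical construction of annotations_file.csv for illustration;
# the post only shows the finished file.
labels = ['风景'] * 3 + ['星空'] * 3 + ['人物'] * 3
df = pd.DataFrame({'image_id': range(1, 10), 'label': labels})
df.to_csv('annotations_file.csv', index=False, encoding='utf-8')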
Code walkthrough
Import the necessary packages
import os
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import Dataset
from torchvision import datasets
import torchvision.transforms as T
Defining a custom Dataset
Before writing the custom Dataset, we first need two pieces of information: the annotations file (annotations_file.csv) and the directory holding the images (images).
class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file, encoding="utf-8")
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        # Column 0 is image_id; the file on disk is '<image_id>.jpg'.
        img_path = os.path.join(self.img_dir, str(self.img_labels.iloc[idx, 0]) + '.jpg')
        image = Image.open(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
Define the image preprocessing transforms
transform = {'train': T.Compose([
    T.Resize((224, 224)),
    # AutoAugment runs on the PIL image, before ToTensor.
    T.AutoAugment(T.AutoAugmentPolicy.CIFAR10),
    T.ToTensor(),
])}
annotations_file = '/kaggle/input/datasets-test/annotations_file.csv'
img_dir = '/kaggle/input/datasets-test/images'
train_data = CustomImageDataset(annotations_file=annotations_file, img_dir=img_dir, transform=transform['train'])
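As a quick sanity check (a sketch; the printed label assumes the sample data above), indexing the Dataset directly exercises __getitem__:

image, label = train_data[0]
print(image.shape)  # torch.Size([3, 224, 224]) after Resize + ToTensor
print(label)        # '风景'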
We can use len(train_data) to check that every sample is accounted for and that the Dataset is defined correctly. Here it outputs 9, and there are indeed exactly 9 images, so everything checks out.
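The check itself is a one-liner:

print(len(train_data))  # 9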
Loading data with DataLoaders
Since the dataset here is small, we set batch_size=2, shuffle the data with shuffle=True, and keep every sample with drop_last=False. For more DataLoader options, see the official API.
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_data, batch_size=2, shuffle=True, drop_last=False)
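With 9 samples, batch_size=2, and drop_last=False, one pass over the loader yields 5 batches, the last holding a single leftover sample. A minimal sketch of a full pass (the labels are still strings at this point):

for images, labels in train_dataloader:
    print(images.shape, labels)
# Four batches of shape torch.Size([2, 3, 224, 224]),
# then one of torch.Size([1, 3, 224, 224]).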
Use the iter() and next() functions to take one batch and inspect the data:
train_features, train_labels = next(iter(train_dataloader))
Take the first image of the first batch and visualize it. train_features has shape (2, 3, 224, 224), so train_features[0] is a (3, 224, 224) tensor; squeeze() removes any remaining singleton dimensions (a no-op here), and permute() rearranges the dimensions to (224, 224, 3) so that plt can plot the image.
img = train_features[0].squeeze()
img = img.permute(1, 2, 0)
plt.imshow(np.asarray(img))
plt.axis('off')
plt.show()
Printing the image label with print(train_labels[0]) outputs '人物'.
Converting string labels to numbers inside the Dataset
The label printed above is a string. If we want to turn it into a number, we only need a small change in the Dataset's __getitem__:
class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file, encoding="utf-8")
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, str(self.img_labels.iloc[idx, 0]) + '.jpg')
        image = Image.open(img_path)
        # factorize() encodes each unique string label as an integer code.
        # (Factorizing once in __init__ and caching the result would avoid
        # re-encoding the whole column on every access; this per-call
        # version matches the original post and still works.)
        data_category, data_class = pd.factorize(self.img_labels.iloc[:, 1])
        label = data_category[idx]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
That's all it takes. This also shows that array-format labels can be read just as well, as long as the idx is consistent and each entry lines up with its image.
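For reference, pd.factorize returns both the integer codes and the unique values, so the mapping back to the class names is preserved:

codes, uniques = pd.factorize(pd.Series(['风景', '风景', '星空', '人物']))
print(codes)    # [0 0 1 2]
print(uniques)  # Index(['风景', '星空', '人物'], dtype='object')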
Splitting into training and validation sets
Given the above, splitting into training and validation sets becomes very simple. We only need four lists/arrays, train_path, train_label, valid_path, and valid_label, holding the training image ids, training labels, validation image ids, and validation labels respectively. The Dataset can then be written like this:
class CustomImageDataset(Dataset):
    def __init__(self, image_id, image_label, img_dir, transform=None, target_transform=None):
        self.image_id = image_id
        self.image_label = image_label
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        # This Dataset has no img_labels attribute; its length is the number of ids.
        return len(self.image_id)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, str(self.image_id[idx]) + '.jpg')
        image = Image.open(img_path)
        label = self.image_label[idx]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
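The post doesn't show how the four lists are produced. As one hedged sketch, assuming scikit-learn is available and reusing the numeric labels from the factorize step, they could come straight from the annotations file:

from sklearn.model_selection import train_test_split

# Hypothetical split (an assumption, not from the post); stratify keeps
# the class balance equal across the two subsets.
df = pd.read_csv(annotations_file, encoding='utf-8')
labels, classes = pd.factorize(df['label'])
train_path, valid_path, train_label, valid_label = train_test_split(
    df['image_id'].to_numpy(), labels, test_size=3, stratify=labels, random_state=0)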
transform = {
    'train': T.Compose([
        T.Resize((224, 224)),
        T.AutoAugment(T.AutoAugmentPolicy.CIFAR10),
        T.ToTensor(),
    ]),
    'valid': T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
    ]),
}
train_data = CustomImageDataset(train_path, train_label, img_dir='./', transform=transform['train'])
valid_data = CustomImageDataset(valid_path, valid_label, img_dir='./', transform=transform['valid'])
train_dataloader = DataLoader(train_data, batch_size=2, shuffle=True, drop_last=False)
valid_dataloader = DataLoader(valid_data, batch_size=2, shuffle=True, drop_last=False)
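Finally, a minimal sketch of how these loaders are typically consumed; the model, loss, and optimizer below are placeholders (not part of the post), and the labels are assumed to be numeric as in the factorize step:

import torch.nn as nn

model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 224 * 224, 3))  # toy 3-class model
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

model.train()
for images, labels in train_dataloader:  # one training epoch
    optimizer.zero_grad()
    loss = criterion(model(images), labels)
    loss.backward()
    optimizer.step()

model.eval()
with torch.no_grad():  # evaluate on the validation loader
    correct = sum((model(x).argmax(1) == y).sum().item() for x, y in valid_dataloader)
print(f'validation accuracy: {correct / len(valid_data):.2f}')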