在train文件中:其中dataset是dataloader的方法,而dataloader等于CreateDataLoader。
所以我们跳到CreateDataLoader:
在CreateDataLoader中返回的是dataset_loader,是来自于CustomDatasetDataLoader。切调用了initialize。因为CustomDatasetDataLoader是一个类,所以initialize是调用了类里的方法。
我们到CustomDatasetDataLoader中:
首先我们定义了name方法,调用的时候返回:‘CustomDatasetDataLoader’。接着定义了initialize:
首先定义了数据集dataset:
dataset来自于AlignedDataset,且调用了initialize方法:
AlignedDataset继承了BaseDataset,BaseDataset是一个抽象类,方法需要被AlignedDataset重写:
在AlignedDataset中我们加载数据集,就是len,getitem,ini三要素。
首先根据标签的通道判断输入的是否是label。如果不是label,self.opt.label_nc == 0,dir_A=‘_A’,否则dir_A=‘_label’。然后获得dir_A的路径。
根据make_dataset函数将train_label的图片放在一个列表中:
train_img,train_instance也是同理。
接着在getitem中随机输入一个索引,根据索引去列表A里面取值,取出的路径通过image打开,打开后是Image格式。然后输入到get_params函数中。
首先获得图片的大小,接着判断裁剪的方式,默认是不裁剪,缩放到1024x512(read me有讲):
接着原始图片是2048x1024,(x,y)=rand((2048-512),(1024-512)),返回一个字典,'flip’对应的True或者False。
根据标签通道进行transfromer变换:
def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
transform_list = []
if 'resize' in opt.resize_or_crop:
osize = [opt.loadSize, opt.loadSize]
transform_list.append(transforms.Scale(osize, method))
elif 'scale_width' in opt.resize_or_crop:
transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))
if 'crop' in opt.resize_or_crop:
transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))
if opt.resize_or_crop == 'none':
base = float(2 ** opt.n_downsample_global)
if opt.netG == 'local':
base *= (2 ** opt.n_local_enhancers)
transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
if opt.isTrain and not opt.no_flip:
transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
transform_list += [transforms.ToTensor()]
if normalize:
transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))]
return transforms.Compose(transform_list)
首先定义一个空列表,根据条件往里面添加:
首先添加__scale_width操作:transforms.Lambda将__scale_width封装到transform中
三个参数,图像尺寸,target_w=1024,method=bicubic。
首先ow=2048,oh=1024,w=1024,h=(1024*1024/2048)=512,然后将img插值到(1024,512)。剩下的变换可以根据情况添加。
最后转化为tensor并归一化用compose串起来。
经过一系列变换后乘以255,标签乘以255转化为灰度图。
对image进行相同的变化,但是不乘以255.如果使用实例图片,和标签进行相同的操作。
最后用一个字典储存下来:
这样CreateDataset结束,输出的dataset传入到dataloader中。通过调用CustomDatasetDataLoader.load_data即可获得进过dataloader之后的值。
最后回到train中: