import os

# Training split: ImageFolder assigns one class per sub-directory of train/.
ori_train = torchvision.datasets.ImageFolder(
    root=args.datadir + '/tiny-imagenet-200/train/', transform=transform)
# Mapping from class name (the train/ sub-directory name) to integer index.
class_idx = ori_train.class_to_idx

# val_annotations.txt stores the class of every validation image as one
# tab-separated record per line; field 1 is the class name looked up in
# class_idx. Field 0 is taken to be the image file name (the standard
# Tiny ImageNet layout) -- TODO confirm for non-standard copies.
# The path is derived from args.datadir so that the annotations are read from
# the same dataset copy as the images loaded below.
test_data_dir = args.datadir + "/tiny-imagenet-200/val"
name_to_target = {}
with open(test_data_dir + "/val_annotations.txt", 'r') as file:
    for line in file:
        if not line.strip():
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            continue
        content = line.strip().split("\t")
        name_to_target[content[0]] = class_idx[content[1]]

# Load the validation images. The class index ImageFolder assigns here comes
# from the val/ directory layout and is not used; only the images are kept.
ori_test_o = torchvision.datasets.ImageFolder(
    root=args.datadir + '/tiny-imagenet-200/val/', transform=transform)

# Validation labels, aligned with ori_test_o by file name rather than by line
# number: ImageFolder enumerates files in sorted (lexicographic) path order,
# so e.g. val_10.JPEG precedes val_2.JPEG, which need not match the order of
# the lines in val_annotations.txt. A KeyError here means an image has no
# annotation record.
test_target = [
    name_to_target[os.path.basename(path)] for path, _ in ori_test_o.samples
]

# Custom Dataset pairing each validation image with its annotated label.
# NOTE(review): Imagenet_dataset is defined further down in this file; the
# definition has to be executed before this statement runs.
ori_test = Imagenet_dataset(ori_test_o, test_target)
class Imagenet_dataset(torch.utils.data.Dataset):
    """Pair the images of an existing dataset with an external label list.

    Item ``idx`` is ``(dataset[idx][0], targets[idx])``: the first element of
    the wrapped dataset's item, combined with the entry at the same position
    in ``targets`` instead of whatever label the wrapped dataset carries.
    """

    def __init__(self, dataset, targets):
        """Store the wrapped dataset and its replacement labels.

        Args:
            dataset: Sized, indexable collection whose items hold the image
                at position 0.
            targets: Sized, indexable collection of labels, aligned
                position-by-position with ``dataset``.

        Raises:
            ValueError: If ``dataset`` and ``targets`` differ in length, since
                the labels could not then be aligned with the images.
        """
        # Fail at construction time instead of with an IndexError (or silently
        # ignored labels) in the middle of data loading.
        if len(dataset) != len(targets):
            raise ValueError(
                "dataset and targets must have the same length, got "
                f"{len(dataset)} and {len(targets)}"
            )
        self.dataset = dataset
        self.targets = targets

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx``."""
        img = self.dataset[idx][0]
        label = self.targets[idx]
        return (img, label)

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)