Dataset类: 如何获取数据及标签。
Dataloader类:为之后的网络提供不同的数据形式。
1. 数据文件夹表示:
from torch.utils.data import Dataset
from PIL import Image
import os
class MyData(Dataset):
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir, self.label_dir)
# os.listdir——用于返回指定的文件夹包含的文件
self.img_path_list = os.listdir(self.path)
def __getitem__(self, idx):
img_name = self.img_path_list[idx]
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
img = Image.open(img_item_path)
label = self.label_dir
return img, label
def __len__(self):
return len(self.img_path_list)
root_dir = 'dataset/train'
ants_label_dir = 'ants'
bees_label_dir = 'bees'
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)
train_dataset = ants_dataset + bees_dataset
img, label = train_dataset[200]
img.show()
2. 若分别有标签文件、图片文件:
例:标签文件放txt文件,文件名为图片编号,内容为图片的标签。
图片文件:
生成内容为标签的文件
import os
root_dir = "dataset/train"
target_dir = "ants_image"
img_path = os.listdir(os.path.join(root_dir, target_dir))
label = target_dir.split('_')[0]
out_dir = "ants_label"
for i in img_path:
file_name = i.split('.')[0]
# 生成含有标签内容的文件(文件名为图片编号)
with open(os.path.join(root_dir, out_dir, "{}.txt".format(file_name)), 'w') as f:
f.write(label) # 在新创建的.txt文件中写入标签
Result: